90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
"""
|
||
RAG 工具模块
|
||
|
||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||
"""
|
||
|
||
from langchain_core.tools import tool
|
||
from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY
|
||
from .pipeline import RAGPipeline, RAGLevel
|
||
|
||
|
||
@tool
|
||
async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str:
|
||
"""在知识库中搜索与查询相关的文档片段。
|
||
|
||
适用于事实性问题、背景知识查询。
|
||
|
||
Args:
|
||
query: 查询字符串
|
||
rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion)
|
||
|
||
Returns:
|
||
检索到的相关文档内容
|
||
"""
|
||
# 初始化嵌入模型
|
||
embedder = LlamaCppEmbedder()
|
||
embeddings = embedder.as_langchain_embeddings()
|
||
|
||
# 创建 RAG 流水线
|
||
pipeline = RAGPipeline(
|
||
embeddings=embeddings,
|
||
config={
|
||
"rag_level": rag_level,
|
||
"collection_name": "rag_documents",
|
||
"rerank_top_n": 5,
|
||
}
|
||
)
|
||
|
||
# 执行检索
|
||
try:
|
||
documents = await pipeline.aretrieve(query)
|
||
if not documents:
|
||
return "未找到相关信息。"
|
||
|
||
# 格式化结果
|
||
context = pipeline.format_context(documents)
|
||
return context
|
||
except Exception as e:
|
||
return f"检索过程中发生错误: {str(e)}"
|
||
|
||
|
||
@tool
|
||
def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str:
|
||
"""同步版本的知识库搜索工具。
|
||
|
||
适用于事实性问题、背景知识查询。
|
||
|
||
Args:
|
||
query: 查询字符串
|
||
rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion)
|
||
|
||
Returns:
|
||
检索到的相关文档内容
|
||
"""
|
||
# 初始化嵌入模型
|
||
embedder = LlamaCppEmbedder()
|
||
embeddings = embedder.as_langchain_embeddings()
|
||
|
||
# 创建 RAG 流水线
|
||
pipeline = RAGPipeline(
|
||
embeddings=embeddings,
|
||
config={
|
||
"rag_level": rag_level,
|
||
"collection_name": "rag_documents",
|
||
"rerank_top_n": 5,
|
||
}
|
||
)
|
||
|
||
# 执行检索
|
||
try:
|
||
documents = pipeline.retrieve(query)
|
||
if not documents:
|
||
return "未找到相关信息。"
|
||
|
||
# 格式化结果
|
||
context = pipeline.format_context(documents)
|
||
return context
|
||
except Exception as e:
|
||
return f"检索过程中发生错误: {str(e)}"
|