"""
RAG 工具模块

将检索功能封装为 LangChain Tool，供 Agent 调用。
"""

import logging

from langchain_core.tools import tool

from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY

from .pipeline import RAGPipeline, RAGLevel


@tool
async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str:
    """在知识库中搜索与查询相关的文档片段。

    适用于事实性问题、背景知识查询。

    Args:
        query: 查询字符串
        rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion)

    Returns:
        检索到的相关文档内容
    """
    # The docstring above is intentionally left verbatim: LangChain's ``@tool``
    # uses it as the tool description presented to the agent.
    #
    # Async knowledge-base search. Builds an embedder and a RAGPipeline for the
    # requested ``rag_level`` (collection "rag_documents", rerank_top_n=5),
    # awaits ``pipeline.aretrieve(query)`` and returns the formatted context.
    # Returns "未找到相关信息。" when nothing is retrieved. Any exception,
    # including one raised while constructing the embedder or pipeline, is
    # returned to the agent as an error string instead of being raised.
    try:
        # Construction lives inside the try so that setup failures are reported
        # to the agent the same way retrieval failures are, rather than
        # propagating out of the tool.
        embeddings = LlamaCppEmbedder().as_langchain_embeddings()
        pipeline = RAGPipeline(
            embeddings=embeddings,
            config={
                "rag_level": rag_level,
                "collection_name": "rag_documents",
                "rerank_top_n": 5,
            },
        )

        documents = await pipeline.aretrieve(query)
        if not documents:
            return "未找到相关信息。"

        return pipeline.format_context(documents)
    except Exception as e:
        # Tool boundary: the agent gets a readable message, and the traceback is
        # preserved in the logs instead of being discarded.
        logging.getLogger(__name__).exception(
            "knowledge base search failed (rag_level=%s)", rag_level
        )
        return f"检索过程中发生错误: {e}"


@tool
def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str:
    """同步版本的知识库搜索工具。

    适用于事实性问题、背景知识查询。

    Args:
        query: 查询字符串
        rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion)

    Returns:
        检索到的相关文档内容
    """
    # The docstring above is intentionally left verbatim: LangChain's ``@tool``
    # uses it as the tool description presented to the agent.
    #
    # Synchronous knowledge-base search. Builds an embedder and a RAGPipeline
    # for the requested ``rag_level`` (collection "rag_documents",
    # rerank_top_n=5), calls ``pipeline.retrieve(query)`` and returns the
    # formatted context. Returns "未找到相关信息。" when nothing is retrieved.
    # Any exception, including one raised while constructing the embedder or
    # pipeline, is returned to the agent as an error string instead of raised.
    try:
        # Construction lives inside the try so that setup failures are reported
        # to the agent the same way retrieval failures are, rather than
        # propagating out of the tool.
        embeddings = LlamaCppEmbedder().as_langchain_embeddings()
        pipeline = RAGPipeline(
            embeddings=embeddings,
            config={
                "rag_level": rag_level,
                "collection_name": "rag_documents",
                "rerank_top_n": 5,
            },
        )

        documents = pipeline.retrieve(query)
        if not documents:
            return "未找到相关信息。"

        return pipeline.format_context(documents)
    except Exception as e:
        # Tool boundary: the agent gets a readable message, and the traceback is
        # preserved in the logs instead of being discarded.
        logging.getLogger(__name__).exception(
            "knowledge base search failed (rag_level=%s)", rag_level
        )
        return f"检索过程中发生错误: {e}"