Files
ailine/app/rag/tools.py

90 lines
2.5 KiB
Python
Raw Normal View History

2026-04-18 16:31:48 +08:00
"""
RAG 工具模块

将检索功能封装为 LangChain Tool，供 Agent 调用。
"""
2026-04-19 22:01:55 +08:00
import logging

from langchain_core.tools import tool
from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY

from .pipeline import RAGPipeline, RAGLevel
2026-04-18 16:31:48 +08:00
2026-04-19 22:01:55 +08:00
# Module-level logger so retrieval failures keep their traceback in the logs
# even though the agent itself only receives a short error string.
logger = logging.getLogger(__name__)

# Retrieval settings shared by both tools (previously duplicated inline).
_COLLECTION_NAME = "rag_documents"
_RERANK_TOP_N = 5


def _build_pipeline(rag_level: str) -> "RAGPipeline":
    """Build an embedder and a ``RAGPipeline`` configured for ``rag_level``.

    ``rag_level`` is forwarded unchanged into the pipeline config; any
    validation of its value is left to ``RAGPipeline``.
    """
    # NOTE(review): a new embedder and pipeline are created on every tool call.
    # If LlamaCppEmbedder/RAGPipeline construction turns out to be expensive,
    # consider caching them per rag_level -- confirm they are safe to share first.
    embeddings = LlamaCppEmbedder().as_langchain_embeddings()
    return RAGPipeline(
        embeddings=embeddings,
        config={
            "rag_level": rag_level,
            "collection_name": _COLLECTION_NAME,
            "rerank_top_n": _RERANK_TOP_N,
        },
    )


def _render_documents(pipeline: "RAGPipeline", documents) -> str:
    """Return the pipeline-formatted context, or a fixed notice when nothing was found.

    A falsy ``documents`` value yields "未找到相关信息。" ("no relevant
    information found").
    """
    if not documents:
        return "未找到相关信息。"
    return pipeline.format_context(documents)


def _report_failure(exc: Exception) -> str:
    """Log ``exc`` with its traceback and return the error text shown to the agent.

    Must be called from inside an ``except`` block so that ``logger.exception``
    can attach the active traceback.
    """
    logger.exception("Knowledge-base retrieval failed")
    return f"检索过程中发生错误: {exc}"


@tool
async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str:
    """在知识库中搜索与查询相关的文档片段。

    适用于事实性问题、背景知识查询。

    Args:
        query: 查询字符串
        rag_level: 检索级别，可选值：basic（基础向量检索）、rerank（基础检索+重排序）、fusion（RAG-Fusion）

    Returns:
        检索到的相关文档内容
    """
    # Async knowledge-base search tool.
    # The docstring above is the tool description that ``@tool`` hands to the
    # agent, so it is kept in Chinese; English notes live in these comments.
    # Construction and retrieval both sit inside the ``try`` so that every
    # failure reaches the agent as an error string (and is logged) instead of
    # aborting the agent run. ``asyncio.CancelledError`` derives from
    # ``BaseException``, so cancellation is not swallowed here.
    try:
        pipeline = _build_pipeline(rag_level)
        documents = await pipeline.aretrieve(query)
        return _render_documents(pipeline, documents)
    except Exception as exc:  # tool boundary: report to the agent, don't crash it
        return _report_failure(exc)


@tool
def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str:
    """同步版本的知识库搜索工具。

    适用于事实性问题、背景知识查询。

    Args:
        query: 查询字符串
        rag_level: 检索级别，可选值：basic（基础向量检索）、rerank（基础检索+重排序）、fusion（RAG-Fusion）

    Returns:
        检索到的相关文档内容
    """
    # Synchronous twin of ``search_knowledge_base``: same pipeline setup and
    # result/error handling, but retrieval goes through ``pipeline.retrieve``.
    # The docstring above is the agent-facing tool description (hence Chinese).
    try:
        pipeline = _build_pipeline(rag_level)
        documents = pipeline.retrieve(query)
        return _render_documents(pipeline, documents)
    except Exception as exc:  # tool boundary: report to the agent, don't crash it
        return _report_failure(exc)