refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s

This commit is contained in:
2026-05-04 17:58:10 +08:00
parent a07e398739
commit 9841f47432
31 changed files with 578 additions and 1496 deletions

View File

@@ -37,9 +37,9 @@ def _get_bool(key: str) -> bool | None:
# ========== 第三方 API 密钥 ==========
ZHIPUAI_API_KEY=_get_str("ZHIPUAI_API_KEY")
DEEPSEEK_API_KEY=_get_str("DEEPSEEK_API_KEY")
SILICONFLOW_API_KEY=_get_str("SILICONFLOW_API_KEY")
ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY")
SILICONFLOW_API_KEY = _get_str("SILICONFLOW_API_KEY")
# ========== 智谱 API 配置 ==========
@@ -69,7 +69,7 @@ LOCAL_MODEL_NAME = _get_str("LOCAL_MODEL_NAME") or "gemma-4-E4B-it"
# ========== llama.cpp 服务配置URL + API密钥 配对) ==========
# 主 LLM 服务
VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
LLM_API_KEY = _get_str("LLAMACPP_API_KEY")
LLM_API_KEY = _get_str("LLM_API_KEY")
# Embedding 服务 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
@@ -78,6 +78,26 @@ LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
# Reranker 服务
LLAMACPP_RERANKER_URL = _get_str("LLAMACPP_RERANKER_URL")
# ========== 小模型配置(查询改写、意图分类等简单任务) ==========
# 默认复用大模型配置,后续可单独配置
# 本地小模型(默认复用 VLLM 配置)
SMALL_VLLM_BASE_URL = _get_str("SMALL_VLLM_BASE_URL")
SMALL_LLM_API_KEY = _get_str("SMALL_LLM_API_KEY")
SMALL_LOCAL_MODEL_NAME = _get_str("SMALL_LOCAL_MODEL_NAME") or LOCAL_MODEL_NAME
# 如果小模型没单独配置,用大模型的配置
if not SMALL_VLLM_BASE_URL:
SMALL_VLLM_BASE_URL = VLLM_BASE_URL
if not SMALL_LLM_API_KEY:
SMALL_LLM_API_KEY = LLM_API_KEY
# DeepSeek 小模型(默认复用 DeepSeek 配置)
SMALL_DEEPSEEK_API_KEY = _get_str("SMALL_DEEPSEEK_API_KEY")
SMALL_DEEPSEEK_MODEL = _get_str("SMALL_DEEPSEEK_MODEL") or "deepseek-chat"
SMALL_DEEPSEEK_API_BASE = _get_str("SMALL_DEEPSEEK_API_BASE") or "https://api.deepseek.com"
# 如果小模型没单独配置,用大模型的配置
if not SMALL_DEEPSEEK_API_KEY:
SMALL_DEEPSEEK_API_KEY = DEEPSEEK_API_KEY
# ========== Qdrant 向量数据库配置URL + API密钥 配对) ==========
QDRANT_URL = _get_str("QDRANT_URL")
@@ -114,4 +134,4 @@ ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE")
# ========== 日志配置 ==========
LOG_LEVEL = _get_str("LOG_LEVEL")
DEBUG = _get_bool("DEBUG")
DEBUG = _get_bool("DEBUG")

View File

@@ -6,7 +6,7 @@ from typing import Optional, Dict, Any
import sys
import os
from backend.app.model_services.chat_services import get_chat_service
from backend.app.model_services.chat_services import get_small_llm_service
class IntentType(Enum):
@@ -33,7 +33,7 @@ class IntentClassifier:
"""意图分类器"""
def __init__(self):
self.llm = get_chat_service()
self.llm = get_small_llm_service()
self._intent_examples = self._build_examples()
def _build_examples(self) -> str:

View File

@@ -19,7 +19,7 @@ from app.main_graph.utils.retry_utils import (
)
# 真正导入和利用已有 RAG 代码
from app.rag.tools import create_rag_tool_sync
from app.rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline

View File

@@ -1,6 +1,6 @@
# app/rag_initializer.py
from app.rag.tools import create_rag_tool_sync, create_rag_tool_async
from rag_core import create_parent_retriever
from app.rag.tools import create_rag_tool
from app.rag.retriever import create_parent_hybrid_retriever
from app.model_services import get_embedding_service
from app.logger import info, warning
@@ -10,18 +10,18 @@ async def init_rag_tool(local_llm_creator):
info("🔄 正在初始化 RAG 检索系统...")
# 使用统一的嵌入服务获取接口
embeddings = get_embedding_service()
retriever = create_parent_retriever(
retriever = create_parent_hybrid_retriever(
collection_name="rag_documents",
search_k=5,
embeddings=embeddings
)
rewrite_llm = local_llm_creator()
rag_tool = create_rag_tool_async(
rag_tool = create_rag_tool(
retriever, rewrite_llm,
num_queries=3, rerank_top_n=5
)
info("✅ RAG 检索工具初始化成功(异步版本)")
info("✅ RAG 检索工具初始化成功(异步版本)")
return rag_tool
except Exception as e:
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
return None
return None

View File

@@ -6,9 +6,11 @@
from .embedding_services import get_embedding_service
from .rerank_services import get_rerank_service, BaseRerankService
from .chat_services import get_small_llm_service
__all__ = [
"get_embedding_service",
"get_rerank_service",
"get_small_llm_service",
"BaseRerankService"
]

View File

@@ -219,51 +219,98 @@ class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]):
# ========== 轻量级模型 Provider ==========
class ZhipuSmallModelProvider(BaseServiceProvider[BaseChatModel]):
class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
"""
智谱 AI 轻量级模型服务提供者(用于意图分类等简单任务)
使用 glm-5.1-flash 或其他小模型
本地轻量级模型服务提供者(用于查询改写、意图分类等简单任务)
使用小模型独立配置
"""
def __init__(self, model: str = "glm-5.1-flash"):
super().__init__("zhipu_small")
self._model = model
def __init__(self, model: str = None):
from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
super().__init__("local_small")
self._model = model or SMALL_LOCAL_MODEL_NAME
self._base_url = SMALL_VLLM_BASE_URL
self._api_key = SMALL_LLM_API_KEY
def is_available(self) -> bool:
"""检查智谱轻量模型服务是否可用"""
if not ZHIPUAI_API_KEY:
logger.warning("ZHIPUAI_API_KEY 未配置,轻量模型不可用")
"""检查本地小模型服务是否可用"""
if not self._base_url:
logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
return False
try:
# 先测试主机名能否解析
import httpx
from urllib.parse import urlparse
parsed_url = urlparse(self._base_url)
host = parsed_url.hostname
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
# 测试能否建立 TCP 连接(快速失败)
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(2.0)
try:
sock.connect((host, port))
sock.close()
except Exception as e:
logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}")
return False
# 再尝试调用简单的 API
client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0)
headers = {}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
try:
response = client.get("/models", headers=headers)
if response.status_code == 200:
logger.info(f"本地小模型服务可用: {self._model}")
return True
except Exception:
pass
logger.warning(f"本地小模型服务响应异常")
return False
except Exception as e:
logger.warning(f"本地小模型服务不可用: {e}")
return False
logger.info(f"智谱轻量模型配置正确: {self._model}")
return True
def get_service(self) -> BaseChatModel:
"""获取智谱轻量模型服务"""
"""获取本地小模型服务"""
if self._service_instance is None:
from langchain_community.chat_models import ChatZhipuAI
self._service_instance = ChatZhipuAI(
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
self._service_instance = ChatOpenAI(
base_url=self._base_url,
api_key=SecretStr(self._api_key) if self._api_key else SecretStr(""),
model=self._model,
api_key=ZHIPUAI_API_KEY,
temperature=0.1,
max_tokens=2048,
timeout=30.0,
max_retries=2,
streaming=False
streaming=False,
)
return self._service_instance
class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
"""
DeepSeek 轻量级模型服务提供者(备选
DeepSeek 轻量级模型服务提供者(用于查询改写、意图分类等简单任务
使用小模型独立配置
"""
def __init__(self, model: str = "deepseek-chat"):
def __init__(self, model: str = None):
from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
super().__init__("deepseek_small")
self._model = model
self._model = model or SMALL_DEEPSEEK_MODEL
self._api_key = SMALL_DEEPSEEK_API_KEY
self._api_base = SMALL_DEEPSEEK_API_BASE
def is_available(self) -> bool:
if not DEEPSEEK_API_KEY:
logger.warning("DEEPSEEK_API_KEY 未配置")
if not self._api_key:
logger.warning("SMALL_DEEPSEEK_API_KEY 未配置")
return False
logger.info(f"DeepSeek 轻量模型配置正确: {self._model}")
return True
@@ -274,8 +321,8 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
from pydantic import SecretStr
self._service_instance = ChatOpenAI(
base_url="https://api.deepseek.com",
api_key=SecretStr(DEEPSEEK_API_KEY),
base_url=self._api_base,
api_key=SecretStr(self._api_key),
model=self._model,
temperature=0.1,
max_tokens=2048,
@@ -339,20 +386,17 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]:
def get_small_llm_service() -> BaseChatModel:
"""
获取轻量级大模型服务(用于意图分类等简单任务)
优先顺序: zhipu_small -> deepseek_small -> (降级到 get_chat_service)
获取轻量级大模型服务(用于查询改写、意图分类等简单任务)
优先顺序: 本地模型 -> DeepSeek 小模型
⚠️ 注意:小模型任务不降级到大模型,避免不必要的 token 消耗!
Returns:
BaseChatModel: LangChain 兼容的 ChatModel 实例
"""
def _create_small_chain():
primary = ZhipuSmallModelProvider()
primary = LocalSmallModelProvider()
fallbacks = [DeepSeekSmallModelProvider()]
return FallbackServiceChain(primary, fallbacks)
try:
chain = SingletonServiceManager.get_or_create("small_llm_chain", _create_small_chain)
return chain.get_available_service()
except Exception as e:
logger.warning(f"轻量模型初始化失败,降级到默认大模型: {e}")
return get_chat_service()
chain = SingletonServiceManager.get_or_create("small_llm_chain", _create_small_chain)
return chain.get_available_service()

View File

@@ -42,7 +42,7 @@ from .rerank import DocumentReranker, create_document_reranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline
from .tools import create_rag_tool_sync
from .tools import create_rag_tool
__all__ = [
@@ -64,5 +64,5 @@ __all__ = [
"RAGPipeline",
# 工具创建(供 Agent 使用)
"create_rag_tool_sync",
"create_rag_tool",
]

View File

@@ -13,7 +13,7 @@ from typing import List, Optional
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from app.model_services import get_rerank_service
from app.model_services import get_rerank_service, get_small_llm_service
from app.rag.rerank import create_document_reranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
@@ -31,7 +31,7 @@ class RAGPipeline:
def __init__(
self,
retriever=None,
llm: Optional[BaseLanguageModel] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
@@ -41,6 +41,9 @@ class RAGPipeline:
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
如果不提供,会自动创建默认的父子文档混合检索器。
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量。
rerank_top_n: 最终返回的文档数量。
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
@@ -53,13 +56,26 @@ class RAGPipeline:
)
else:
self.retriever = retriever
# 处理 llm 参数
if llm == "default_small":
try:
self.llm = get_small_llm_service()
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
self.llm = None
elif llm in (None, False):
self.llm = None
else:
self.llm = llm
self.llm = llm
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
# 初始化组件 - 使用统一的重排服务获取接口
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
self.reranker = create_document_reranker()
async def aretrieve(self, query: str) -> List[Document]:
@@ -102,11 +118,7 @@ class RAGPipeline:
final_docs = fused_docs[:self.rerank_top_n]
return final_docs
def retrieve(self, query: str) -> List[Document]:
"""同步检索入口(内部调用异步方法)"""
return asyncio.run(self.aretrieve(query))
def format_context(self, documents: List[Document]) -> str:
"""
将文档列表格式化为上下文字符串
@@ -129,7 +141,7 @@ class RAGPipeline:
def create_rag_pipeline(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
) -> RAGPipeline:
@@ -138,7 +150,10 @@ def create_rag_pipeline(
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型
llm: 用于生成多路查询的语言模型
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量

View File

@@ -33,16 +33,16 @@ DEFAULT_PARENT_SEARCH_K = 5
class HybridRetriever(BaseRetriever):
"""
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合(异步)
使用 Qdrant Universal Query API (query_points)
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
@@ -62,21 +62,39 @@ class HybridRetriever(BaseRetriever):
self._vector_store = vector_store
self._client = vector_store.get_async_qdrant_client()
self._sparse_embedder = get_sparse_embedder()
def _get_relevant_documents(
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
同步检索(不推荐使用,仅供兼容性)
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
"""
import asyncio
try:
loop = asyncio.get_running_loop()
# 已有事件循环,使用 create_task
task = loop.create_task(self._aget_relevant_documents(query))
return loop.run_until_complete(task)
except RuntimeError:
# 没有事件循环,创建新的
return asyncio.run(self._aget_relevant_documents(query))
async def _aget_relevant_documents(
self, query: str, **kwargs
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
异步混合检索相关文档
"""
# 1. 生成查询向量
dense_query = await self._vector_store._aembed_query(query)
dense_query = await self._vector_store.aembed_query(query)
sparse_query = self._sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 使用 Qdrant 的 query_points API
response = await self._client.query_points(
collection_name=self.collection_name,
@@ -96,7 +114,7 @@ class HybridRetriever(BaseRetriever):
limit=self.search_k,
with_payload=True
)
# 3. 转换结果
results = []
for point in response.points:
@@ -105,28 +123,28 @@ class HybridRetriever(BaseRetriever):
metadata=point.payload
)
results.append(doc)
debug(f"混合检索返回 %d 个文档", len(results))
debug(f"混合检索返回 {len(results)} 个文档")
return results
class ParentHybridRetriever(BaseRetriever):
"""
父子文档混合检索器(异步):
1. 先用混合检索找到相关子文档
2. 根据子文档的 parent_id 找到对应的父文档
3. 去重并返回父文档
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
_docstore: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
@@ -149,24 +167,40 @@ class ParentHybridRetriever(BaseRetriever):
self._client = vector_store.get_async_qdrant_client()
self._sparse_embedder = get_sparse_embedder()
self._docstore = docstore
def _get_relevant_documents(
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
同步检索(不推荐使用,仅供兼容性)
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
"""
import asyncio
try:
loop = asyncio.get_running_loop()
task = loop.create_task(self._aget_relevant_documents(query))
return loop.run_until_complete(task)
except RuntimeError:
return asyncio.run(self._aget_relevant_documents(query))
async def _aget_relevant_documents(
self, query: str, **kwargs
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
异步检索相关父文档
"""
# 1. 生成查询向量
dense_query = await self._vector_store._aembed_query(query)
dense_query = await self._vector_store.aembed_query(query)
sparse_query = self._sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 多取一些子文档,避免去重后数量不足
search_limit = self.search_k * 2
# 3. 使用 query_points API 进行混合检索
response = await self._client.query_points(
collection_name=self.collection_name,
@@ -186,30 +220,30 @@ class ParentHybridRetriever(BaseRetriever):
limit=search_limit,
with_payload=True
)
if not response.points:
debug("混合检索未找到任何文档")
return []
# 4. 收集 parent_id 和对应最高得分
parent_score_map = {}
parent_ids = set()
child_point_map = {} # 保存子文档点用于降级
for point in response.points:
payload_copy = point.payload.copy()
parent_id = payload_copy.get("parent_id", point.id)
score = point.score
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
parent_ids.add(parent_id)
child_point_map[parent_id] = point
# 5. 批量查询父文档
parent_docs = []
found_parent_ids = set()
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
try:
parent_points = await self._client.retrieve(
@@ -217,7 +251,7 @@ class ParentHybridRetriever(BaseRetriever):
ids=list(parent_ids),
with_payload=True
)
for point in parent_points:
payload_copy = point.payload.copy()
doc = Document(
@@ -226,10 +260,10 @@ class ParentHybridRetriever(BaseRetriever):
)
parent_docs.append(doc)
found_parent_ids.add(point.id)
except Exception as e:
warning(f"从 Qdrant 查询父文档失败: %s", e)
warning(f"从 Qdrant 查询父文档失败: {e}")
# 6. 如果有 docstore尝试从 docstore 查询剩余的父文档
if self._docstore and len(found_parent_ids) < len(parent_ids):
missing_parent_ids = parent_ids - found_parent_ids
@@ -240,12 +274,12 @@ class ParentHybridRetriever(BaseRetriever):
parent_docs.append(doc)
found_parent_ids.add(doc_id)
except Exception as e:
warning(f"从 docstore 查询父文档失败: %s", e)
warning(f"从 docstore 查询父文档失败: {e}")
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
missing_parent_ids = parent_ids - found_parent_ids
if missing_parent_ids:
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: %s", missing_parent_ids)
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
for parent_id in missing_parent_ids:
child_point = child_point_map.get(parent_id)
if child_point:
@@ -255,17 +289,17 @@ class ParentHybridRetriever(BaseRetriever):
metadata=payload_copy
)
parent_docs.append(doc)
# 8. 按照得分降序排序,返回前 k 个
parent_docs_with_scores = [
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
for doc in parent_docs
]
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
debug(f"父子文档混合检索返回 %d 个父文档", len(final_docs))
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
return final_docs
@@ -291,7 +325,7 @@ def create_hybrid_retriever(
embeddings = get_embedding_service()
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
vector_store = QdrantHybridStore(collection_name=collection_name)
try:
vector_store.get_client().get_collection(collection_name)
@@ -336,7 +370,7 @@ def create_parent_hybrid_retriever(
embeddings = get_embedding_service()
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
vector_store = QdrantHybridStore(collection_name=collection_name)
try:
vector_store.get_client().get_collection(collection_name)

View File

@@ -1,5 +1,5 @@
"""
RAG 工具模块
RAG 工具模块(完全异步)
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
@@ -13,78 +13,24 @@ from langchain_core.retrievers import BaseRetriever
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool_sync(
def create_rag_tool(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(同步版本)。
创建一个配置好的 RAG 检索工具(完全异步)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
def search_knowledge_base_sync(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = pipeline.retrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_sync
def create_rag_tool_async(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(异步版本)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型(可选)
llm: 用于生成多路查询的语言模型
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
@@ -101,9 +47,9 @@ def create_rag_tool_async(
)
@tool
async def search_knowledge_base_async(query: str) -> str:
async def search_knowledge_base(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段(异步版本)。
在知识库中搜索与查询相关的文档片段(完全异步)。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
@@ -124,30 +70,4 @@ def create_rag_tool_async(
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_async
def create_rag_tool(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
) -> Callable:
"""
创建 RAG 检索工具的便捷函数(同步版本)。
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
LangChain Tool 函数
"""
return create_rag_tool_sync(
collection_name=collection_name,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)
return search_knowledge_base