refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
This commit is contained in:
@@ -42,7 +42,7 @@ from .rerank import DocumentReranker, create_document_reranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
from .fusion import reciprocal_rank_fusion
|
||||
from .pipeline import RAGPipeline
|
||||
from .tools import create_rag_tool_sync
|
||||
from .tools import create_rag_tool
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -64,5 +64,5 @@ __all__ = [
|
||||
"RAGPipeline",
|
||||
|
||||
# 工具创建(供 Agent 使用)
|
||||
"create_rag_tool_sync",
|
||||
"create_rag_tool",
|
||||
]
|
||||
@@ -13,7 +13,7 @@ from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from app.model_services import get_rerank_service
|
||||
from app.model_services import get_rerank_service, get_small_llm_service
|
||||
from app.rag.rerank import create_document_reranker
|
||||
from app.rag.query_transform import MultiQueryGenerator
|
||||
from app.rag.fusion import reciprocal_rank_fusion
|
||||
@@ -31,7 +31,7 @@ class RAGPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
retriever=None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
@@ -41,6 +41,9 @@ class RAGPipeline:
|
||||
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
|
||||
如果不提供,会自动创建默认的父子文档混合检索器。
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量。
|
||||
rerank_top_n: 最终返回的文档数量。
|
||||
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
|
||||
@@ -53,13 +56,26 @@ class RAGPipeline:
|
||||
)
|
||||
else:
|
||||
self.retriever = retriever
|
||||
|
||||
# 处理 llm 参数
|
||||
if llm == "default_small":
|
||||
try:
|
||||
self.llm = get_small_llm_service()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
|
||||
self.llm = None
|
||||
elif llm in (None, False):
|
||||
self.llm = None
|
||||
else:
|
||||
self.llm = llm
|
||||
|
||||
self.llm = llm
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
|
||||
# 初始化组件 - 使用统一的重排服务获取接口
|
||||
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
|
||||
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
|
||||
self.reranker = create_document_reranker()
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
@@ -102,11 +118,7 @@ class RAGPipeline:
|
||||
final_docs = fused_docs[:self.rerank_top_n]
|
||||
|
||||
return final_docs
|
||||
|
||||
def retrieve(self, query: str) -> List[Document]:
|
||||
"""同步检索入口(内部调用异步方法)"""
|
||||
return asyncio.run(self.aretrieve(query))
|
||||
|
||||
|
||||
def format_context(self, documents: List[Document]) -> str:
|
||||
"""
|
||||
将文档列表格式化为上下文字符串
|
||||
@@ -129,7 +141,7 @@ class RAGPipeline:
|
||||
|
||||
def create_rag_pipeline(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
) -> RAGPipeline:
|
||||
@@ -138,7 +150,10 @@ def create_rag_pipeline(
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
llm: 用于生成多路查询的语言模型
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
|
||||
|
||||
@@ -33,16 +33,16 @@ DEFAULT_PARENT_SEARCH_K = 5
|
||||
class HybridRetriever(BaseRetriever):
|
||||
"""
|
||||
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合(异步)
|
||||
|
||||
|
||||
使用 Qdrant Universal Query API (query_points)
|
||||
"""
|
||||
collection_name: str = Field(description="Qdrant 集合名称")
|
||||
search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")
|
||||
|
||||
|
||||
_vector_store: Any = PrivateAttr()
|
||||
_client: Any = PrivateAttr()
|
||||
_sparse_embedder: Any = PrivateAttr()
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
@@ -62,21 +62,39 @@ class HybridRetriever(BaseRetriever):
|
||||
self._vector_store = vector_store
|
||||
self._client = vector_store.get_async_qdrant_client()
|
||||
self._sparse_embedder = get_sparse_embedder()
|
||||
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
同步检索(不推荐使用,仅供兼容性)
|
||||
|
||||
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
|
||||
"""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# 已有事件循环,使用 create_task
|
||||
task = loop.create_task(self._aget_relevant_documents(query))
|
||||
return loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建新的
|
||||
return asyncio.run(self._aget_relevant_documents(query))
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
异步混合检索相关文档
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store._aembed_query(query)
|
||||
dense_query = await self._vector_store.aembed_query(query)
|
||||
sparse_query = self._sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
|
||||
# 2. 使用 Qdrant 的 query_points API
|
||||
response = await self._client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
@@ -96,7 +114,7 @@ class HybridRetriever(BaseRetriever):
|
||||
limit=self.search_k,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
|
||||
# 3. 转换结果
|
||||
results = []
|
||||
for point in response.points:
|
||||
@@ -105,28 +123,28 @@ class HybridRetriever(BaseRetriever):
|
||||
metadata=point.payload
|
||||
)
|
||||
results.append(doc)
|
||||
|
||||
debug(f"混合检索返回 %d 个文档", len(results))
|
||||
|
||||
debug(f"混合检索返回 {len(results)} 个文档")
|
||||
return results
|
||||
|
||||
|
||||
class ParentHybridRetriever(BaseRetriever):
|
||||
"""
|
||||
父子文档混合检索器(异步):
|
||||
|
||||
|
||||
1. 先用混合检索找到相关子文档
|
||||
2. 根据子文档的 parent_id 找到对应的父文档
|
||||
3. 去重并返回父文档
|
||||
"""
|
||||
|
||||
|
||||
collection_name: str = Field(description="Qdrant 集合名称")
|
||||
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
|
||||
|
||||
|
||||
_vector_store: Any = PrivateAttr()
|
||||
_client: Any = PrivateAttr()
|
||||
_sparse_embedder: Any = PrivateAttr()
|
||||
_docstore: Any = PrivateAttr()
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
@@ -149,24 +167,40 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
self._client = vector_store.get_async_qdrant_client()
|
||||
self._sparse_embedder = get_sparse_embedder()
|
||||
self._docstore = docstore
|
||||
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
同步检索(不推荐使用,仅供兼容性)
|
||||
|
||||
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
|
||||
"""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(self._aget_relevant_documents(query))
|
||||
return loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
return asyncio.run(self._aget_relevant_documents(query))
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
异步检索相关父文档
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store._aembed_query(query)
|
||||
dense_query = await self._vector_store.aembed_query(query)
|
||||
sparse_query = self._sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
|
||||
# 2. 多取一些子文档,避免去重后数量不足
|
||||
search_limit = self.search_k * 2
|
||||
|
||||
|
||||
# 3. 使用 query_points API 进行混合检索
|
||||
response = await self._client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
@@ -186,30 +220,30 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
limit=search_limit,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
|
||||
if not response.points:
|
||||
debug("混合检索未找到任何文档")
|
||||
return []
|
||||
|
||||
|
||||
# 4. 收集 parent_id 和对应最高得分
|
||||
parent_score_map = {}
|
||||
parent_ids = set()
|
||||
child_point_map = {} # 保存子文档点用于降级
|
||||
|
||||
|
||||
for point in response.points:
|
||||
payload_copy = point.payload.copy()
|
||||
parent_id = payload_copy.get("parent_id", point.id)
|
||||
score = point.score
|
||||
|
||||
|
||||
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||
parent_score_map[parent_id] = score
|
||||
parent_ids.add(parent_id)
|
||||
child_point_map[parent_id] = point
|
||||
|
||||
|
||||
# 5. 批量查询父文档
|
||||
parent_docs = []
|
||||
found_parent_ids = set()
|
||||
|
||||
|
||||
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
|
||||
try:
|
||||
parent_points = await self._client.retrieve(
|
||||
@@ -217,7 +251,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
ids=list(parent_ids),
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
|
||||
for point in parent_points:
|
||||
payload_copy = point.payload.copy()
|
||||
doc = Document(
|
||||
@@ -226,10 +260,10 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
)
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(point.id)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
warning(f"从 Qdrant 查询父文档失败: %s", e)
|
||||
|
||||
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||
|
||||
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||
if self._docstore and len(found_parent_ids) < len(parent_ids):
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
@@ -240,12 +274,12 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(doc_id)
|
||||
except Exception as e:
|
||||
warning(f"从 docstore 查询父文档失败: %s", e)
|
||||
|
||||
warning(f"从 docstore 查询父文档失败: {e}")
|
||||
|
||||
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
if missing_parent_ids:
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: %s", missing_parent_ids)
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
|
||||
for parent_id in missing_parent_ids:
|
||||
child_point = child_point_map.get(parent_id)
|
||||
if child_point:
|
||||
@@ -255,17 +289,17 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
metadata=payload_copy
|
||||
)
|
||||
parent_docs.append(doc)
|
||||
|
||||
|
||||
# 8. 按照得分降序排序,返回前 k 个
|
||||
parent_docs_with_scores = [
|
||||
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
|
||||
for doc in parent_docs
|
||||
]
|
||||
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
|
||||
debug(f"父子文档混合检索返回 %d 个父文档", len(final_docs))
|
||||
|
||||
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
|
||||
|
||||
return final_docs
|
||||
|
||||
|
||||
@@ -291,7 +325,7 @@ def create_hybrid_retriever(
|
||||
embeddings = get_embedding_service()
|
||||
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name)
|
||||
|
||||
try:
|
||||
vector_store.get_client().get_collection(collection_name)
|
||||
@@ -336,7 +370,7 @@ def create_parent_hybrid_retriever(
|
||||
embeddings = get_embedding_service()
|
||||
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name)
|
||||
|
||||
try:
|
||||
vector_store.get_client().get_collection(collection_name)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
RAG 工具模块
|
||||
RAG 工具模块(完全异步)
|
||||
|
||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
||||
@@ -13,78 +13,24 @@ from langchain_core.retrievers import BaseRetriever
|
||||
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
|
||||
|
||||
def create_rag_tool_sync(
|
||||
def create_rag_tool(
|
||||
retriever: Optional[BaseRetriever] = None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
) -> Callable:
|
||||
"""
|
||||
创建一个配置好的 RAG 检索工具(同步版本)。
|
||||
创建一个配置好的 RAG 检索工具(完全异步)。
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
|
||||
Args:
|
||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
||||
llm: 用于生成多路查询的语言模型(可选)
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
collection_name: Qdrant 集合名称
|
||||
|
||||
Returns:
|
||||
LangChain Tool 函数
|
||||
"""
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@tool
|
||||
def search_knowledge_base_sync(query: str) -> str:
|
||||
"""
|
||||
在知识库中搜索与查询相关的文档片段。
|
||||
|
||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
||||
检索效果最优。
|
||||
|
||||
Args:
|
||||
query: 用户提出的问题或查询字符串
|
||||
|
||||
Returns:
|
||||
格式化后的相关文档内容
|
||||
"""
|
||||
try:
|
||||
documents = pipeline.retrieve(query)
|
||||
if not documents:
|
||||
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
||||
|
||||
context = pipeline.format_context(documents)
|
||||
return context
|
||||
except Exception as e:
|
||||
return f"检索过程中发生错误: {str(e)}"
|
||||
|
||||
return search_knowledge_base_sync
|
||||
|
||||
|
||||
def create_rag_tool_async(
|
||||
retriever: Optional[BaseRetriever] = None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
) -> Callable:
|
||||
"""
|
||||
创建一个配置好的 RAG 检索工具(异步版本)。
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
|
||||
Args:
|
||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
||||
llm: 用于生成多路查询的语言模型(可选)
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
collection_name: Qdrant 集合名称
|
||||
@@ -101,9 +47,9 @@ def create_rag_tool_async(
|
||||
)
|
||||
|
||||
@tool
|
||||
async def search_knowledge_base_async(query: str) -> str:
|
||||
async def search_knowledge_base(query: str) -> str:
|
||||
"""
|
||||
在知识库中搜索与查询相关的文档片段(异步版本)。
|
||||
在知识库中搜索与查询相关的文档片段(完全异步)。
|
||||
|
||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
||||
检索效果最优。
|
||||
@@ -124,30 +70,4 @@ def create_rag_tool_async(
|
||||
except Exception as e:
|
||||
return f"检索过程中发生错误: {str(e)}"
|
||||
|
||||
return search_knowledge_base_async
|
||||
|
||||
|
||||
def create_rag_tool(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
) -> Callable:
|
||||
"""
|
||||
创建 RAG 检索工具的便捷函数(同步版本)。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
llm: 用于生成多路查询的语言模型(可选)
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
|
||||
Returns:
|
||||
LangChain Tool 函数
|
||||
"""
|
||||
return create_rag_tool_sync(
|
||||
collection_name=collection_name,
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
)
|
||||
return search_knowledge_base
|
||||
|
||||
Reference in New Issue
Block a user