导入方式修改
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s

This commit is contained in:
2026-05-05 23:17:00 +08:00
parent b5c15ef445
commit 3ae9daa01a
51 changed files with 445 additions and 532 deletions

View File

@@ -13,7 +13,7 @@ RAG 检索与生成模块
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法:
>>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from backend.app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI
>>>

View File

@@ -1,333 +0,0 @@
"""
RAG 评估模块
用于计算 RAG 系统的召回率、相关性、准确率等指标
"""
import asyncio
import json
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from langchain_core.documents import Document
@dataclass
class RetrievalTestCase:
    """One labelled retrieval example: a query and its ground-truth documents."""

    # The user query that is sent to the retriever.
    query: str
    # IDs of the documents that count as relevant for this query.
    relevant_doc_ids: List[str]
    # Expected answer text; optional.
    expected_answer: Optional[str] = None
@dataclass
class RetrievalMetrics:
    """Aggregated retrieval-quality metrics."""

    # Recall@k keyed by k, e.g. {1: 0.8, 3: 0.9, 5: 1.0}.
    recall_at_k: Dict[int, float]
    # Precision@k keyed by k.
    precision_at_k: Dict[int, float]
    # F1@k keyed by k.
    f1_at_k: Dict[int, float]
    # Mean reciprocal rank.
    mrr: float
    # NDCG@k keyed by k.
    ndcg_at_k: Dict[int, float]
    # Relevance score of each test case.
    relevance_scores: List[float]
class RAGEvaluator:
    """Evaluate the retrieval quality of a RAG pipeline against labelled test cases."""

    def __init__(self, rag_pipeline, test_cases: "List[RetrievalTestCase]"):
        """
        Initialise the evaluator.

        Args:
            rag_pipeline: RAG pipeline object; it must provide an awaitable
                ``aretrieve(query)`` method returning a list of documents.
            test_cases: Labelled test cases to evaluate against.
        """
        self.rag_pipeline = rag_pipeline
        self.test_cases = test_cases

    async def evaluate_retrieval(self, k_list: Optional[List[int]] = None) -> "RetrievalMetrics":
        """
        Run every test case through the pipeline and aggregate the metrics.

        Args:
            k_list: Cut-off values to evaluate, e.g. ``[1, 3, 5]``.
                Defaults to ``[1, 3, 5, 10]``.

        Returns:
            Metrics averaged over all test cases (all zeros when there are
            no test cases).
        """
        if k_list is None:
            k_list = [1, 3, 5, 10]

        all_results = []
        all_mrr = []
        for test_case in self.test_cases:
            retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query)
            # Documents without an "id" in their metadata are identified by
            # the first 50 characters of their content.
            retrieved_ids = [
                doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs
            ]
            all_results.append(
                self._calculate_retrieval_metrics(
                    retrieved_ids,
                    test_case.relevant_doc_ids,
                    k_list,
                )
            )
            all_mrr.append(self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids))

        return self._aggregate_metrics(all_results, all_mrr, k_list)

    def _calculate_retrieval_metrics(
        self,
        retrieved_ids: List[str],
        relevant_ids: List[str],
        k_list: List[int]
    ) -> Dict[int, Dict[str, float]]:
        """
        Compute the per-cut-off metrics for a single test case.

        Returns:
            ``{k: {'recall': float, 'precision': float, 'f1': float, 'ndcg': float}}``
        """
        # De-duplicate the ground truth so repeated IDs cannot cap recall below 1.
        relevant = set(relevant_ids)
        results = {}
        for k in k_list:
            # A non-positive k selects nothing (a negative slice bound would
            # otherwise select "all but the last |k|" items).
            top_k = retrieved_ids[:max(k, 0)]
            hits = len(set(top_k) & relevant)

            # Recall = relevant documents in the top k / all relevant documents.
            recall = hits / len(relevant) if relevant else 0.0
            # Precision = relevant documents in the top k / k.
            precision = hits / k if k > 0 else 0.0
            # F1 = harmonic mean of recall and precision.
            f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0 else 0.0

            results[k] = {
                'recall': recall,
                'precision': precision,
                'f1': f1,
                'ndcg': self._calculate_ndcg(retrieved_ids, relevant_ids, k),
            }
        return results

    def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float:
        """
        Return the reciprocal rank for one query.

        This is ``1 / rank`` of the first retrieved document that is relevant
        (ranks start at 1), or 0.0 when no relevant document was retrieved.
        The mean over all queries (the MRR) is taken in ``_aggregate_metrics``.
        """
        relevant = set(relevant_ids)
        for rank, doc_id in enumerate(retrieved_ids, start=1):
            if doc_id in relevant:
                return 1.0 / rank
        return 0.0

    @staticmethod
    def _dcg(gains: List[float]) -> float:
        """Return ``sum(gain_i / log2(i + 1))`` over 1-based positions ``i``."""
        # Imported locally so this block does not depend on a module-level import.
        import math

        return sum(gain / math.log2(i + 1) for i, gain in enumerate(gains, start=1))

    def _calculate_ndcg(
        self,
        retrieved_ids: List[str],
        relevant_ids: List[str],
        k: int
    ) -> float:
        """
        Compute NDCG@k (Normalized Discounted Cumulative Gain) with binary relevance.

        DCG@k  = sum(relevance_i / log2(i + 1) for i = 1..k)
        NDCG@k = DCG@k / IDCG@k

        Returns 0.0 when ``k`` is not positive or there are no relevant documents.
        """
        relevant = set(relevant_ids)
        if k <= 0 or not relevant:
            return 0.0

        # Credit each relevant document only once, so a retriever that returns
        # duplicates cannot push the score above 1.
        credited = set()
        gains = []
        for doc_id in retrieved_ids[:k]:
            if doc_id in relevant and doc_id not in credited:
                credited.add(doc_id)
                gains.append(1.0)
            else:
                gains.append(0.0)

        # Ideal ranking: every relevant document (up to k) at the top.
        ideal_gains = [1.0] * min(len(relevant), k)

        dcg = self._dcg(gains)
        idcg = self._dcg(ideal_gains)
        return dcg / idcg if idcg > 0 else 0.0

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of ``values``, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def _aggregate_metrics(
        self,
        all_results: List[Dict[int, Dict[str, float]]],
        all_mrr: List[float],
        k_list: List[int]
    ) -> "RetrievalMetrics":
        """Average the per-test-case metrics into a single ``RetrievalMetrics``."""
        recall_at_k = {}
        precision_at_k = {}
        f1_at_k = {}
        ndcg_at_k = {}

        for k in k_list:
            recall_at_k[k] = self._mean([result[k]['recall'] for result in all_results])
            precision_at_k[k] = self._mean([result[k]['precision'] for result in all_results])
            f1_at_k[k] = self._mean([result[k]['f1'] for result in all_results])
            # Result dicts that carry no 'ndcg' entry fall back to F1 as an
            # approximation.
            ndcg_at_k[k] = self._mean(
                [result[k].get('ndcg', result[k]['f1']) for result in all_results]
            )

        return RetrievalMetrics(
            recall_at_k=recall_at_k,
            precision_at_k=precision_at_k,
            f1_at_k=f1_at_k,
            mrr=self._mean(all_mrr),
            ndcg_at_k=ndcg_at_k,
            # Placeholder: per-case relevance scoring is not computed here.
            relevance_scores=[1.0] * len(all_results)
        )
class RelevanceEvaluator:
    """Relevance evaluator that asks an LLM to grade a document against a query."""

    def __init__(self, llm):
        """
        Initialise the relevance evaluator.

        Args:
            llm: Language model used for grading; it must provide an awaitable
                ``ainvoke(prompt)`` method.
        """
        self.llm = llm

    async def evaluate_relevance(
        self,
        query: str,
        document: "Document"
    ) -> Tuple[float, str]:
        """
        Grade how relevant ``document`` is to ``query``.

        Only the first 500 characters of the document content are sent to the
        model. This is best-effort: any failure (model call, malformed reply,
        non-numeric score) yields a score of 0.0 with the error text as reason.

        Args:
            query: The user query.
            document: Document object exposing ``page_content``.

        Returns:
            ``(score, reason)`` where ``score`` is clamped to the range 0-5.
        """
        prompt = f"""请评估以下文档与用户查询的相关性,给出 0-5 的评分:
用户查询:{query}
文档内容:{document.page_content[:500]}
请按以下标准评分:
5 = 完全相关,文档直接回答了用户查询
4 = 高度相关,文档包含回答查询的关键信息
3 = 部分相关,文档有一些相关信息但不够直接
2 = 弱相关,文档有少量提及但不太相关
1 = 不相关,文档内容与查询基本无关
0 = 完全无关
请只返回 JSON 格式,例如:{{"score": 4, "reason": "文档详细解释了用户查询的概念"}}"""
        try:
            response = await self.llm.ainvoke(prompt)
            result_text = response.content if hasattr(response, 'content') else str(response)

            # Models frequently wrap the JSON in a Markdown fence or add prose
            # around it, so isolate the object before parsing.
            result = json.loads(self._extract_json_object(result_text))
            score = result.get('score', 0.0)
            reason = result.get('reason', '无理由')

            # Keep the score inside the documented 0-5 range.
            score = max(0.0, min(5.0, float(score)))
            return score, reason
        except Exception as e:
            return 0.0, f"评估失败:{str(e)}"

    @staticmethod
    def _extract_json_object(text):
        """Return the outermost ``{...}`` span of ``text``, or ``text`` unchanged if there is none."""
        if not isinstance(text, str):
            return text
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            return text[start:end + 1]
        return text
def generate_test_report(metrics: RetrievalMetrics) -> str:
    """
    Render retrieval metrics as a plain-text report.

    The report lists Recall@k, Precision@k and F1@k in ascending order of k,
    followed by the MRR and a short legend explaining each metric.

    Args:
        metrics: Metrics object exposing ``recall_at_k``, ``precision_at_k``,
            ``f1_at_k`` and ``mrr``.

    Returns:
        The report as a single newline-joined string.
    """
    divider = "=" * 80
    lines = [divider, "RAG 系统评估报告", divider, ""]

    # Recall section.
    lines.append("【召回率 Recall@k】")
    lines.extend(f" Recall@{k}: {v:.2%}" for k, v in sorted(metrics.recall_at_k.items()))
    lines.append("")

    # Precision section.
    lines.append("【精确率 Precision@k】")
    lines.extend(f" Precision@{k}: {v:.2%}" for k, v in sorted(metrics.precision_at_k.items()))
    lines.append("")

    # F1 section.
    lines.append("【F1 分数 F1@k】")
    lines.extend(f" F1@{k}: {v:.4f}" for k, v in sorted(metrics.f1_at_k.items()))
    lines.append("")

    # MRR, then the legend framed by divider lines.
    lines += [
        f"【平均倒数排名 MRR】: {metrics.mrr:.4f}",
        "",
        divider,
        "指标说明:",
        "- Recall@k: 前 k 个结果中包含多少比例的相关文档",
        "- Precision@k: 前 k 个结果中有多少比例是相关文档",
        "- F1@k: 召回率和精确率的调和平均数",
        "- MRR: 第一个相关文档的排名的倒数的平均值",
        divider,
    ]
    return "\n".join(lines)
# 示例使用
def create_sample_test_cases() -> List[RetrievalTestCase]:
    """Build a small set of illustrative test cases."""
    # (query, relevant document IDs, expected answer) for each sample.
    samples = (
        (
            "什么是 RAG 系统?",
            ["doc_rag_1", "doc_rag_2"],
            "RAG 是 Retrieval-Augmented Generation 的缩写...",
        ),
        (
            "如何使用 LangChain",
            ["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"],
            "LangChain 的使用步骤包括...",
        ),
    )
    return [
        RetrievalTestCase(
            query=question,
            relevant_doc_ids=doc_ids,
            expected_answer=answer,
        )
        for question, doc_ids, answer in samples
    ]
if __name__ == "__main__":
    # Running the module directly only prints a short usage walkthrough.
    usage_lines = (
        "RAG 评估模块已加载",
        "使用方法:",
        " 1. 创建测试用例",
        " 2. 初始化 RAGEvaluator",
        " 3. 调用 evaluate_retrieval()",
        " 4. 生成报告",
    )
    for usage_line in usage_lines:
        print(usage_line)

View File

@@ -1,137 +1,114 @@
"""
RAG 检索流水线模块
提供固定流程的 RAG 检索:
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 返回
"""
import asyncio
import os
from typing import List, Optional
import logging
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from app.model_services import get_rerank_service, get_small_llm_service
from app.rag.rerank import create_document_reranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.retriever import create_parent_hybrid_retriever
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
logger = logging.getLogger(__name__)
class RAGPipeline:
"""
固定流程的 RAG 检索流水线:
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
def __init__(
self,
retriever=None,
llm: Optional[BaseLanguageModel] = "default_small",
llm: BaseLanguageModel | str = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
use_rerank: bool = True,
return_parent_docs: bool = True,
):
"""
Args:
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
如果不提供,会自动创建默认的父子文档混合检索器。
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量。
rerank_top_n: 最终返回的文档数量。
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
"""
# 如果没有提供 retriever自动创建默认的混合检索器
if retriever is None:
self.retriever = create_parent_hybrid_retriever(
collection_name=collection_name,
search_k=rerank_top_n * 2 # 多取一些给重排序用
)
else:
self.retriever = retriever
# 处理 llm 参数
self.retriever = retriever or create_parent_hybrid_retriever(
collection_name=collection_name, search_k=rerank_top_n * 4
)
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
self.use_rerank = use_rerank
self.return_parent_docs = return_parent_docs
if llm == "default_small":
try:
self.llm = get_small_llm_service()
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
except Exception:
self.llm = None
elif llm in (None, False):
self.llm = None
else:
self.llm = llm
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
# 初始化组件 - 使用统一的重排服务获取接口
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
self.reranker = create_document_reranker()
self.llm = llm if llm else None
self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
self.reranker = create_document_reranker() if use_rerank else None
logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
async def aretrieve(self, query: str) -> List[Document]:
"""
异步执行完整检索流程
Args:
query: 用户查询
Returns:
检索到的相关文档列表
"""
# 如果有 query_generator做多路改写
if self.query_generator and self.llm:
# Step 1: 生成多路查询
# Step 1: 检索
child_docs = await self._retrieve(query)
logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")
# 调试:打印子文档长度
for i, doc in enumerate(child_docs[:5]):
content_len = len(doc.page_content)
logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")
# Step 2: 重排
if self.reranker:
try:
child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
logger.info(f"[Pipeline] 重排后 {len(child_docs)}")
except Exception as e:
logger.warning(f"[Pipeline] 重排失败: {e}")
child_docs = child_docs[:self.rerank_top_n]
# Step 3: 获取父文档
if self.return_parent_docs:
return await self._get_parents(child_docs)
return child_docs
async def _retrieve(self, query: str) -> List[Document]:
if self.query_generator:
queries = await self.query_generator.agenerate(query)
# 包含原始查询,确保至少有一条
if query not in queries:
queries.insert(0, query)
else:
# 如果原始查询已在列表中,将其移至首位
queries.remove(query)
queries.insert(0, query)
# Step 2: 并行检索(每个查询获取文档列表)
tasks = [self.retriever.ainvoke(q) for q in queries]
doc_lists = await asyncio.gather(*tasks)
# Step 3: RRF 融合
fused_docs = reciprocal_rank_fusion(doc_lists)
else:
# 没有 LLM 做查询改写,直接用原始查询检索
fused_docs = await self.retriever.ainvoke(query)
# Step 4: 重排序
queries = [query] + [q for q in queries if q != query]
doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries])
return reciprocal_rank_fusion(doc_lists)
return await self.retriever.ainvoke(query)
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
parent_map = {}
for doc in child_docs:
pid = doc.metadata.get("parent_id")
if pid and pid not in parent_map:
parent_map[pid] = doc.metadata.get("score", 0.0)
if not parent_map:
logger.warning("[Pipeline] 未找到 parent_id返回子文档")
return child_docs
try:
final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n)
except Exception:
# 若重排序器不可用,直接返回融合后的前 N 个结果
final_docs = fused_docs[:self.rerank_top_n]
return final_docs
from backend.rag_core import create_docstore
docstore, _ = create_docstore()
# 同步获取(异步版本不存在)
parent_docs = docstore.mget(list(parent_map.keys()))
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
result.sort(key=lambda x: x[1], reverse=True)
docs = [d for d, _ in result]
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
return docs
except Exception as e:
logger.warning(f"[Pipeline] 获取父文档失败: {e}")
return child_docs
def format_context(self, documents: List[Document]) -> str:
"""
将文档列表格式化为上下文字符串
Args:
documents: 文档列表
Returns:
格式化后的上下文字符串
"""
if not documents:
return ""
parts = []
for i, doc in enumerate(documents, 1):
source = doc.metadata.get("source", "未知来源")
@@ -139,30 +116,5 @@ class RAGPipeline:
return "\n".join(parts)
def create_rag_pipeline(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
) -> RAGPipeline:
"""
创建 RAG 检索流水线的便捷函数
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
RAGPipeline 实例
"""
return RAGPipeline(
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name
)
def create_rag_pipeline(**kwargs) -> RAGPipeline:
return RAGPipeline(**kwargs)

View File

@@ -57,14 +57,26 @@ class DocumentReranker:
try:
# 1. 从 Document 提取内容(业务逻辑)
doc_contents = [doc.page_content for doc in documents]
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
total_chars = sum(len(c) for c in doc_contents)
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
# 粗略估算 tokens (中文约 0.75 tokens/字符)
estimated_tokens = int(total_chars * 0.75)
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
# 2. 调用纯服务层计算得分
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
scores = self._rerank_service.compute_scores(query, doc_contents)
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")
# 3. 根据得分排序(业务逻辑)
doc_score_pairs = list(zip(documents, scores))
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
logger.info(f"[Rerank] 排序后的结果:")
for i, (doc, score) in enumerate(doc_score_pairs_sorted):
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...")
# 4. 取 top_n
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]]
@@ -72,6 +84,9 @@ class DocumentReranker:
except Exception as e:
logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}")
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
import traceback
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
return documents[:top_n]

View File

@@ -19,10 +19,10 @@ from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, PrivateAttr
from rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from rag_core.client import create_async_qdrant_client
from app.model_services import get_embedding_service
from app.logger import info, warning, debug
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from backend.rag_core.client import create_async_qdrant_client
from ..model_services import get_embedding_service
from ..logger import info, warning, debug
# 模块级常量
@@ -131,20 +131,20 @@ class HybridRetriever(BaseRetriever):
class ParentHybridRetriever(BaseRetriever):
"""
父子文档混合检索器(异步):
1. 先用混合检索找到相关子文档
2. 根据子文档的 parent_id 找到对应的父文档
3. 去重并返回父文档
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
_docstore: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
@@ -188,7 +188,7 @@ class ParentHybridRetriever(BaseRetriever):
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
异步检索相关文档
异步检索相关文档
"""
# 1. 生成查询向量
dense_query = await self._vector_store.aembed_query(query)
@@ -197,10 +197,10 @@ class ParentHybridRetriever(BaseRetriever):
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 多取一些子文档,避免去重后数量不足
search_limit = self.search_k * 2
# 3. 使用 query_points API 进行混合检索
response = await self._client.query_points(
collection_name=self.collection_name,
@@ -220,87 +220,27 @@ class ParentHybridRetriever(BaseRetriever):
limit=search_limit,
with_payload=True
)
if not response.points:
debug("混合检索未找到任何文档")
return []
# 4. 收集 parent_id 和对应最高得分
parent_score_map = {}
parent_ids = set()
child_point_map = {} # 保存子文档点用于降级
# 4. 构建子文档列表
child_docs = []
for point in response.points:
payload_copy = point.payload.copy()
parent_id = payload_copy.get("parent_id", point.id)
score = point.score
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
parent_ids.add(parent_id)
child_point_map[parent_id] = point
# 5. 批量查询父文档
parent_docs = []
found_parent_ids = set()
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
try:
parent_points = await self._client.retrieve(
collection_name=self.collection_name,
ids=list(parent_ids),
with_payload=True
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata={
**payload_copy,
"child_id": point.id,
"score": point.score
}
)
for point in parent_points:
payload_copy = point.payload.copy()
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata=payload_copy
)
parent_docs.append(doc)
found_parent_ids.add(point.id)
except Exception as e:
warning(f"从 Qdrant 查询父文档失败: {e}")
# 6. 如果有 docstore尝试从 docstore 查询剩余的父文档
if self._docstore and len(found_parent_ids) < len(parent_ids):
missing_parent_ids = parent_ids - found_parent_ids
try:
docstore_docs = await self._docstore.amget(missing_parent_ids)
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
if doc is not None:
parent_docs.append(doc)
found_parent_ids.add(doc_id)
except Exception as e:
warning(f"从 docstore 查询父文档失败: {e}")
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
missing_parent_ids = parent_ids - found_parent_ids
if missing_parent_ids:
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
for parent_id in missing_parent_ids:
child_point = child_point_map.get(parent_id)
if child_point:
payload_copy = child_point.payload.copy()
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata=payload_copy
)
parent_docs.append(doc)
# 8. 按照得分降序排序,返回前 k 个
parent_docs_with_scores = [
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
for doc in parent_docs
]
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
return final_docs
child_docs.append(doc)
debug(f"父子文档混合检索返回 {len(child_docs)} 个子文档")
return child_docs
def create_hybrid_retriever(

View File

@@ -10,7 +10,7 @@ from typing import Callable, Optional
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool(