修改日志用项目统一的 logger
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 10m46s

This commit is contained in:
2026-05-06 16:15:09 +08:00
parent 13499ecf2a
commit d09b0d16ce
2 changed files with 46 additions and 49 deletions

View File

@@ -1,10 +1,9 @@
# rag/fusion.py
import logging
from typing import List, Dict
from langchain_core.documents import Document
from backend.app.logger import info
logger = logging.getLogger(__name__)
def reciprocal_rank_fusion(
doc_lists: List[List[Document]],
@@ -12,22 +11,22 @@ def reciprocal_rank_fusion(
) -> List[Document]:
"""
对多个检索结果列表进行 RRF 融合。
Args:
doc_lists: 多个检索结果列表,每个列表来自一个查询
k: RRF 常数,通常设为 60
Returns:
融合后按 RRF 得分降序排列的文档列表
"""
logger.info(f"[RRF] reciprocal_rank_fusion 开始: {len(doc_lists)} 组文档")
info(f"[RRF] reciprocal_rank_fusion 开始: {len(doc_lists)} 组文档")
# 使用文档内容作为唯一标识(如果内容相同但 metadata 不同,视为同一文档)
# 更好的做法是用 docstore 的 ID这里简化处理用内容 hash
doc_to_score: Dict[str, float] = {}
doc_map: Dict[str, Document] = {}
for list_idx, docs in enumerate(doc_lists):
logger.info(f"[RRF] 处理第 {list_idx} 组: {len(docs)} 个文档")
info(f"[RRF] 处理第 {list_idx} 组: {len(docs)} 个文档")
for rank, doc in enumerate(docs, start=1):
# 生成唯一标识符(内容+来源组合,避免不同文件相同内容混淆)
doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}"
@@ -35,10 +34,10 @@ def reciprocal_rank_fusion(
doc_map[doc_id] = doc
score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank)
doc_to_score[doc_id] = score
logger.info(f"[RRF] 去重后共 {len(doc_map)} 个唯一文档")
info(f"[RRF] 去重后共 {len(doc_map)} 个唯一文档")
# 按得分排序
sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True)
result = [doc_map[doc_id] for doc_id in sorted_ids]
logger.info(f"[RRF] reciprocal_rank_fusion 结束: 返回 {len(result)} 个文档")
return result
info(f"[RRF] reciprocal_rank_fusion 结束: 返回 {len(result)} 个文档")
return result

View File

@@ -4,19 +4,17 @@ RAG 检索流水线
"""
import asyncio
import logging
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from backend.app.logger import info, warning
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
logger = logging.getLogger(__name__)
class RAGPipeline:
def __init__(
@@ -49,7 +47,7 @@ class RAGPipeline:
self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
self.reranker = create_document_reranker() if use_rerank else None
logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
@property
def last_docs(self) -> List[Document]:
@@ -62,40 +60,40 @@ class RAGPipeline:
return self._last_scores
async def aretrieve(self, query: str) -> List[Document]:
logger.info(f"[Pipeline] aretrieve 开始: query={query[:50]}...")
info(f"[Pipeline] aretrieve 开始: query={query[:50]}...")
# Step 1: 检索
logger.info(f"[Pipeline] Step 1: 调用 _retrieve")
info(f"[Pipeline] Step 1: 调用 _retrieve")
child_docs = await self._retrieve(query)
logger.info(f"[Pipeline] Step 1 完成: 检索到 {len(child_docs)} 个子文档")
info(f"[Pipeline] Step 1 完成: 检索到 {len(child_docs)} 个子文档")
# 调试:打印子文档长度
for i, doc in enumerate(child_docs[:5]):
content_len = len(doc.page_content)
logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")
info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")
# Step 2: 重排
logger.info(f"[Pipeline] Step 2: 开始重排")
info(f"[Pipeline] Step 2: 开始重排")
if self.reranker:
try:
child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
logger.info(f"[Pipeline] Step 2 完成: 重排后 {len(child_docs)}")
info(f"[Pipeline] Step 2 完成: 重排后 {len(child_docs)}")
except Exception as e:
logger.warning(f"[Pipeline] 重排失败: {e}")
warning(f"[Pipeline] 重排失败: {e}")
child_docs = child_docs[:self.rerank_top_n]
else:
logger.info(f"[Pipeline] Step 2 跳过: 未启用 reranker")
info(f"[Pipeline] Step 2 跳过: 未启用 reranker")
# Step 3: 获取父文档
logger.info(f"[Pipeline] Step 3: 开始获取父文档")
info(f"[Pipeline] Step 3: 开始获取父文档")
if self.return_parent_docs:
parent_docs = await self._get_parents(child_docs)
logger.info(f"[Pipeline] Step 3 完成: 获取到 {len(parent_docs)} 个父文档")
info(f"[Pipeline] Step 3 完成: 获取到 {len(parent_docs)} 个父文档")
# 保存分数信息到 last_scores 供外部访问
self._last_scores = self._extract_scores(parent_docs)
logger.info(f"[Pipeline] aretrieve 结束: 返回父文档")
info(f"[Pipeline] aretrieve 结束: 返回父文档")
return parent_docs
self._last_scores = self._extract_scores(child_docs)
logger.info(f"[Pipeline] aretrieve 结束: 返回子文档")
info(f"[Pipeline] aretrieve 结束: 返回子文档")
return child_docs
def _extract_scores(self, docs: List[Document]) -> List[dict]:
@@ -109,27 +107,27 @@ class RAGPipeline:
return scores
async def _retrieve(self, query: str) -> List[Document]:
logger.info(f"[Pipeline] _retrieve 开始: query={query[:50]}...")
info(f"[Pipeline] _retrieve 开始: query={query[:50]}...")
if self.query_generator:
logger.info(f"[Pipeline] _retrieve: 调用 query_generator.agenerate")
info(f"[Pipeline] _retrieve: 调用 query_generator.agenerate")
queries = await self.query_generator.agenerate(query)
queries = [query] + [q for q in queries if q != query]
logger.info(f"[Pipeline] _retrieve: 生成 {len(queries)} 个查询: {queries}")
logger.info(f"[Pipeline] _retrieve: 开始 asyncio.gather 并行检索")
info(f"[Pipeline] _retrieve: 生成 {len(queries)} 个查询: {queries}")
info(f"[Pipeline] _retrieve: 开始 asyncio.gather 并行检索")
doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries])
logger.info(f"[Pipeline] _retrieve: asyncio.gather 完成,得到 {len(doc_lists)} 组结果")
logger.info(f"[Pipeline] _retrieve: 开始 reciprocal_rank_fusion")
info(f"[Pipeline] _retrieve: asyncio.gather 完成,得到 {len(doc_lists)} 组结果")
info(f"[Pipeline] _retrieve: 开始 reciprocal_rank_fusion")
result = reciprocal_rank_fusion(doc_lists)
logger.info(f"[Pipeline] _retrieve: RRF 完成,得到 {len(result)} 个文档")
logger.info(f"[Pipeline] _retrieve 结束")
info(f"[Pipeline] _retrieve: RRF 完成,得到 {len(result)} 个文档")
info(f"[Pipeline] _retrieve 结束")
return result
logger.info(f"[Pipeline] _retrieve: query_generator 未启用,直接单次检索")
info(f"[Pipeline] _retrieve: query_generator 未启用,直接单次检索")
result = await self.retriever.ainvoke(query)
logger.info(f"[Pipeline] _retrieve 结束")
info(f"[Pipeline] _retrieve 结束")
return result
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
logger.info(f"[Pipeline] _get_parents 开始: {len(child_docs)} 个子文档")
info(f"[Pipeline] _get_parents 开始: {len(child_docs)} 个子文档")
# 收集 parent_id 和对应的分数
parent_map = {} # parent_id -> (embedding_score, rerank_score)
@@ -142,18 +140,18 @@ class RAGPipeline:
rerank_score = doc.metadata.get("rerank_score", 0.0)
parent_map[pid] = (embedding_score, rerank_score)
logger.info(f"[Pipeline] _get_parents: 收集到 {len(parent_map)} 个 unique parent_id")
info(f"[Pipeline] _get_parents: 收集到 {len(parent_map)} 个 unique parent_id")
if not parent_map:
logger.warning("[Pipeline] 未找到 parent_id返回子文档")
warning("[Pipeline] 未找到 parent_id返回子文档")
return child_docs
try:
logger.info(f"[Pipeline] _get_parents: 调用 create_docstore")
info(f"[Pipeline] _get_parents: 调用 create_docstore")
from backend.rag_core import create_docstore
docstore, _ = create_docstore()
logger.info(f"[Pipeline] _get_parents: 调用 docstore.amget")
info(f"[Pipeline] _get_parents: 调用 docstore.amget")
parent_docs =await docstore.amget(list(parent_map.keys()))
logger.info(f"[Pipeline] _get_parents: docstore.amget 返回 {len(parent_docs)} 个结果")
info(f"[Pipeline] _get_parents: docstore.amget 返回 {len(parent_docs)} 个结果")
# 构建结果,保持分数信息
result = []
@@ -168,24 +166,24 @@ class RAGPipeline:
result.sort(key=lambda x: x[1], reverse=True)
docs = [d for d, _ in result]
logger.info(f"[Pipeline] _get_parents: 最终得到 {len(docs)} 个父文档")
logger.info(f"[Pipeline] _get_parents 结束")
info(f"[Pipeline] _get_parents: 最终得到 {len(docs)} 个父文档")
info(f"[Pipeline] _get_parents 结束")
return docs
except Exception as e:
logger.warning(f"[Pipeline] 获取父文档失败: {e}", exc_info=True)
warning(f"[Pipeline] 获取父文档失败: {e}", exc_info=True)
return child_docs
def format_context(self, documents: List[Document]) -> str:
logger.info(f"[Pipeline] format_context 开始: {len(documents)} 个文档")
info(f"[Pipeline] format_context 开始: {len(documents)} 个文档")
if not documents:
logger.info(f"[Pipeline] format_context: 无文档,返回空字符串")
info(f"[Pipeline] format_context: 无文档,返回空字符串")
return ""
parts = []
for i, doc in enumerate(documents, 1):
source = doc.metadata.get("source", "未知来源")
parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n")
result = "\n".join(parts)
logger.info(f"[Pipeline] format_context 结束: 结果长度={len(result)} 字符")
info(f"[Pipeline] format_context 结束: 结果长度={len(result)} 字符")
return result