添加rag置信度判断
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m31s

This commit is contained in:
2026-05-06 01:15:52 +08:00
parent 3ae9daa01a
commit 1260bef5cb
35 changed files with 335 additions and 221 deletions

View File

@@ -1481,7 +1481,7 @@ mkdir backend/app/subgraphs/my_subgraph
2. **创建状态定义 (state.py)**
```python
from typing_extensions import TypedDict
from ..core.state_base import BaseSubgraphState
from backend.app.core.state_base import BaseSubgraphState
class MySubgraphState(BaseSubgraphState):
\"\"\"

View File

@@ -16,8 +16,8 @@ from ..main_graph.main_graph_builder import build_react_main_graph
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from ..main_graph.config import set_stream_writer
from ..main_graph.utils.rag_initializer import init_rag_tool
from ..core.intent_classifier import get_intent_classifier
from ..logger import debug, info, warning, error
from backend.app.core.intent_classifier import get_intent_classifier
from backend.app.logger import debug, info, warning, error
from ..main_graph.state import MainGraphState, CurrentAction

View File

@@ -4,7 +4,7 @@
"""
from typing import List, Dict, Any
from ..logger import error # 保持兼容,或者替换为 logger
from backend.app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
"""线程历史查询服务"""

View File

@@ -3,9 +3,10 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
"""
创建系统提示模板,可选择动态注入工具描述
创建系统提示模板,整合多子系统能力、检索策略与回答规范
"""
tools_section = ""
# 构造工具描述
tools_section = "无可用工具"
if tools:
tool_descs = []
for tool in tools:
@@ -14,25 +15,44 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs)
system_template = (
"你是一个智能助手,具有三个专业子系统和RAG检索能力请使用中文交流。\n\n"
"【核心功能】\n"
"1. 📚 词典/翻译子系统 - 查询单词、翻译文本、提取术语、每日一词\n"
"2. 📰 资讯分析子系统 - 查询新闻、分析URL、提取关键词、生成报告\n"
"3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n"
"4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n"
"【用户背景信息】\n"
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n"
"{memory_context}\n"
"【可用工具与使用规则】\n"
f"{tools_section}\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接。\n"
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
"3. 优先利用已知用户信息进行个性化回复。\n"
"4. 若无信息可依,礼貌询问或提供通用帮助。"
)
# 使用 f-string 将 tools_section 直接嵌入,而 memory_context 用双花括号转义保留为变量
system_template = f'''你是一个智能助手,具备以下专业子系统和检索能力请使用中文交流。
## 核心功能
1. 📚 词典/翻译子系统 查询单词、翻译文本、提取术语、每日一词
2. 📰 资讯分析子系统 查询新闻、分析URL、提取关键词、生成报告
3. 📇 通讯录子系统 查询联系人、添加联系人、管理通讯录
4. 🔍 RAG检索 从知识库中检索相关信息回答问题
## 检索与信息获取策略
当收到用户问题时,请按以下优先级处理:
1. **RAG 检索第1次**:首先尝试从知识库中查找答案。
2. **ReRAG第2次优化检索**:如果第一次检索结果不相关或不充分,可以优化查询后再次进行 RAG 检索。
3. **联网搜索**:如果两次 RAG 检索后仍无法获得满意答案,必须使用联网搜索获取最新信息。
**重要约束**
- 最多进行 **2 次** RAG 检索尝试。
- 第3次决定获取信息时必须选择**联网搜索**,禁止无休止的本地检索。
- 如果已经明确知识库不包含该信息(例如用户询问实时新闻),可以直接进入联网搜索。
## 可用工具
{tools_section}
工具调用时请直接返回所需参数,无需额外说明。
## 用户背景信息
以下是当前用户的已知信息和长期记忆,你应在回答中优先利用这些信息进行个性化回复:
{{memory_context}}
若无相关信息,可礼貌询问或提供通用帮助。
## 回答要求(必须严格遵守)
1. **来源标注**:回答开头必须明确标注信息来源,格式如下:
- 使用知识库时:`【知识库:来源描述】`
- 使用联网搜索时:`【联网搜索:来源描述】`
- 若同时用到多个来源,按实际使用顺序标注,例如:`【知识库:三国演义】【联网搜索:百度百科】`
2. **思维链**:如果问题需要深入推理或复杂思考,请将推理过程用 `<think>...</think>` 标签包裹,放在回答最前面(来源标注之前)。
3. **简洁直接**:回答应重点突出、条理清晰,避免冗长。
4. **个性化**:结合用户信息进行针对性回复。
5. **无依据时**:若既无知识库支撑也无联网搜索结果,请如实说明无法回答,并建议用户提供更多信息或尝试其他方式。
'''
return ChatPromptTemplate.from_messages([
("system", system_template),

View File

@@ -9,7 +9,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="websocket
warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets")
import os
from ..config import DB_URI, BACKEND_PORT
from backend.app.config import DB_URI, BACKEND_PORT
import uuid
import json
from contextlib import asynccontextmanager
@@ -22,18 +22,18 @@ from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from .agent.agent_service import AIAgentService, create_serde
from .agent.history import ThreadHistoryService
from ..core.human_review import (
from backend.app.core.human_review import (
ReviewManager,
InMemoryReviewStore,
ReviewStatus,
HumanReview
)
from ..subgraphs.contact.api_client import ContactAPIClient
from ..subgraphs.dictionary.api_client import DictionaryAPIClient
from ..subgraphs.news_analysis.api_client import NewsAPIClient
from backend.app.subgraphs.contact.api_client import ContactAPIClient
from backend.app.subgraphs.dictionary.api_client import DictionaryAPIClient
from backend.app.subgraphs.news_analysis.api_client import NewsAPIClient
from .db.init_db import init_subgraph_tables
from .db.models import ContactRepository, DictionaryRepository, NewsRepository
from ..logger import info, error
from backend.app.logger import info, error
@asynccontextmanager
async def lifespan(app: FastAPI):

View File

@@ -21,15 +21,15 @@ from .nodes.fast_paths import (
fast_tool_node,
)
from .nodes.llm_call import create_dynamic_llm_call_node
from .nodes.rag_nodes import rag_retrieve_node
from .nodes.rag_nodes import rag_retrieve_node, check_rag_confidence
from .nodes.retrieve_memory import create_retrieve_memory_node
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
from .nodes.summarize import create_summarize_node
from .nodes.finalize import finalize_node
from ..subgraphs.contact import build_contact_subgraph
from ..subgraphs.dictionary import build_dictionary_subgraph
from ..subgraphs.news_analysis import build_news_analysis_subgraph
from ..logger import info
from backend.app.subgraphs.contact import build_contact_subgraph
from backend.app.subgraphs.dictionary import build_dictionary_subgraph
from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph
from backend.app.logger import info
from .subgraph_wrapper import create_subgraph_nodes
@@ -198,8 +198,20 @@ def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) ->
}
)
# RAG 检索后的置信度判断分支
graph.add_conditional_edges(
"rag_retrieve",
check_rag_confidence,
{
"high_confidence": "llm_call", # 高置信度 → 直接生成回答
"retry_rag": "rag_retrieve", # 低置信度 → 再次检索
"low_confidence": "web_search", # 两次RAG后仍低 → 联网搜索
"no_rag": "web_search", # 无结果 → 联网搜索
}
)
# 循环边(回到 react_reason
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names
loop_back_nodes = ["web_search", "handle_error"] + subgraph_names
for node_name in loop_back_nodes:
graph.add_edge(node_name, "react_reason")

View File

@@ -8,7 +8,7 @@ from .web_search import web_search_node
from .error_handling import error_handling_node
from .routing import init_state_node, route_by_reasoning, should_summarize
from .llm_call import create_dynamic_llm_call_node
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
from .rag_nodes import rag_retrieve_node
# 记忆节点
from .retrieve_memory import create_retrieve_memory_node

View File

@@ -3,7 +3,7 @@
"""
from ...main_graph.state import MainGraphState, ErrorSeverity
from ...logger import info
from backend.app.logger import info
def error_handling_node(state: MainGraphState) -> MainGraphState:

View File

@@ -7,7 +7,7 @@ from typing import Optional
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
from backend.app.logger import info, debug
from ...model_services.chat_services import get_small_llm_service, get_chat_service
from .rag_nodes import rag_retrieve_node
from ._utils import dispatch_custom_event
@@ -113,10 +113,18 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig]
def _has_valid_rag_results(state: MainGraphState) -> bool:
"""检查 RAG 结果是否有效"""
rag_docs = getattr(state, "rag_docs", [])
"""检查 RAG 结果是否有效(基于置信度)"""
from .rag_nodes import RAG_CONFIDENCE_THRESHOLD
rag_context = getattr(state, "rag_context", "")
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
rag_confidence = getattr(state, "rag_confidence", 0.0)
# 有结果且置信度足够
has_content = rag_context and len(rag_context) > 0
has_confidence = rag_confidence >= RAG_CONFIDENCE_THRESHOLD
info(f"[Fast RAG Check] has_content={has_content}, rag_confidence={rag_confidence:.2f}, threshold={RAG_CONFIDENCE_THRESHOLD}")
return has_content and has_confidence
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:

View File

@@ -8,7 +8,7 @@ from typing import Any, Dict
# 本地模块
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from ...logger import info, warning
from backend.app.logger import info, warning
from langchain_core.runnables.config import RunnableConfig

View File

@@ -11,7 +11,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
from backend.app.logger import info, debug
from ...model_services.chat_services import get_small_llm_service
from ._utils import dispatch_custom_event

View File

@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage
from ...main_graph.state import MainGraphState
from ...agent.prompts import create_system_prompt
from ...utils.logging import log_state_change
from ...logger import debug, info, error
from backend.app.logger import debug, info, error
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
@@ -115,24 +115,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
):
chunks.append(chunk)
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks")
for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多
chunk_type = type(chunk).__name__
chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk)
# 打印更多属性
additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {}
response_metadata = getattr(chunk, 'response_metadata', {}) or {}
# 打印所有属性
info(f"[llm_call] chunk[{i}] type={chunk_type}")
info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}")
info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}")
info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}")
info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}")
# 检查是否有其他可能存储内容的属性
if hasattr(chunk, 'tool_call_chunks'):
info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}")
if hasattr(chunk, 'usage_metadata'):
info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}")
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}")
# 将所有 chunk 合并成最终的 AIMessage
if chunks:

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...logger import info
from backend.app.logger import info
# 全局变量,在 GraphBuilder 中注入

View File

@@ -1,6 +1,6 @@
"""
RAG 检索节点模块
使用模块级变量管理 RAG 工具
包含RAG 检索、置信度判断、重检索等节点
"""
import time
@@ -11,10 +11,15 @@ from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from ...logger import info
from backend.app.logger import info, debug
from ...model_services import get_small_llm_service
from ._utils import dispatch_custom_event, make_react_event
# 置信度阈值配置
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
def _get_rag_tool() -> Optional[callable]:
"""获取 RAG 工具"""
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
@@ -36,43 +41,27 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
# 调用 RAG 工具
rag_context = await rag_tool.ainvoke(retrieval_query)
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
info(f"[RAG Core] ========================================")
# 更新状态
state.rag_context = rag_context
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
state.rag_retrieved = True
state.success = True
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
state.debug_info["rag_source"] = "tool"
info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
return state
# ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""RAG 检索节点:带超时和重试"""
"""RAG 检索节点:检索 + 置信度评估"""
state.current_phase = "rag_retrieving"
start_time = time.time()
last_error = None
# 获取 RAG 工具
rag_tool = _get_rag_tool()
if not rag_tool:
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message="RAG 工具未初始化",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=RAG_RETRY_CONFIG.max_retries,
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
info("[RAG] RAG 工具未初始化")
state.rag_confidence = 0.0
state.rag_retrieved = False
return state
await dispatch_custom_event(
@@ -81,99 +70,184 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
config
)
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
result = await _rag_retrieve_core(state, rag_tool)
try:
state = await _rag_retrieve_core(state, rag_tool)
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
# 评估置信度
confidence = await _evaluate_rag_confidence(state)
state.rag_confidence = confidence
state.debug_info["rag_retrieval"] = {
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
info(f"[RAG] 检索完成,置信度={confidence:.2f}RAG尝试次数={state.rag_attempts}")
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": 1.0,
"reasoning": "RAG 检索完成",
"timestamp": datetime.now().isoformat()
})
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": confidence,
"reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
"timestamp": datetime.now().isoformat()
})
doc_count = len(result.rag_docs) if result.rag_docs else 0
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
f"RAG 检索完成,找到 {doc_count} 条相关文档"),
config
)
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
f"RAG 检索完成,置信度={confidence:.2f}"),
config
)
return result
except Exception as e:
last_error = e
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
config
)
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 失败记录
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": 0.0,
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
"timestamp": datetime.now().isoformat()
})
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=RAG_RETRY_CONFIG.max_retries,
max_retries=RAG_RETRY_CONFIG.max_retries,
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
f"RAG 检索失败: {str(last_error)}"),
config
)
except Exception as e:
info(f"[RAG] 检索失败: {e}")
state.rag_confidence = 0.0
state.rag_retrieved = False
return state
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""重新检索节点"""
state.current_phase = "rag_re_retrieving"
async def _evaluate_rag_confidence(state: MainGraphState) -> float:
"""评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)"""
query = state.user_query or ""
rag_context = state.rag_context or ""
state.debug_info["rag_re_retrieve"] = {
"original_retrieved": state.rag_retrieved,
"original_docs_count": len(state.rag_docs)
}
if not rag_context:
return 0.0
return await rag_retrieve_node(state, config)
# 方式1: 向量相似度(从 rag_docs 中获取)
embedding_score = _get_embedding_similarity(state, query)
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
# 方式2: 重排序分数(从 rag_docs 中获取)
rerank_score = _get_rerank_score(state)
info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")
# 方式3: 小模型判断
llm_score = await _get_llm_score(state)
info(f"[RAG Confidence] LLM评估={llm_score:.3f}")
# 综合得分(加权平均)
# 向量相似度权重 0.3,重排权重 0.3LLM 权重 0.4
final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")
return final_score
def _get_embedding_similarity(state: MainGraphState) -> float:
"""从 rag_docs 中获取向量相似度分数"""
rag_docs = getattr(state, "rag_docs", [])
# 如果有多个文档,取最高分
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("score", 0.0)
# 向量相似度通常在 0-1 之间RRF 分数可能更高
# 归一化到 0-1
if score > 1.0:
score = min(score / 10.0, 1.0) # 假设 max 约 10
scores.append(score)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("score", 0.0)
if score > 1.0:
score = min(score / 10.0, 1.0)
scores.append(score)
if scores:
# 取平均或最高分
return max(scores) # 使用最高分更准确
return 0.0
def _get_rerank_score(state: MainGraphState) -> float:
"""从 rag_docs 中获取重排序分数"""
rag_docs = getattr(state, "rag_docs", [])
# 重排分数通常在 0-1 之间
scores = []
for doc in rag_docs:
if isinstance(doc, dict):
score = doc.get("rerank_score", 0.0)
elif hasattr(doc, "metadata"):
score = doc.metadata.get("rerank_score", 0.0)
else:
score = 0.0
if score > 0:
scores.append(score)
if scores:
return max(scores) # 使用最高分
return 0.0
async def _get_llm_score(state: MainGraphState) -> float:
"""使用小模型评估检索结果相关性"""
query = state.user_query or ""
rag_context = state.rag_context or ""
try:
llm = get_small_llm_service()
prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题
用户问题:{query}
检索结果:{rag_context[:1500]}
只返回一个数字:"""
response = await llm.ainvoke(prompt)
content = response.content.strip()
import re
match = re.search(r'(\d+\.?\d*)', content)
if match:
score = float(match.group(1))
return max(0.0, min(1.0, score))
except Exception as e:
info(f"[RAG Confidence] LLM评估失败: {e}")
return 0.5 # 默认中等置信度
# ========== 置信度判断节点 ==========
def check_rag_confidence(state: MainGraphState) -> str:
"""
根据 RAG 置信度判断下一步
Returns:
"high_confidence" - 高置信度(>=0.6),可直接生成回答
"low_confidence" - 低置信度(<0.6),需要联网搜索
"no_rag" - 无检索结果,需要联网搜索
"""
rag_attempts = getattr(state, 'rag_attempts', 0)
rag_confidence = getattr(state, 'rag_confidence', 0.0)
info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")
# 情况1: 没有检索结果
if not getattr(state, 'rag_retrieved', False) or not state.rag_context:
info("[Confidence Check] 无检索结果,走联网")
return "no_rag"
# 情况2: 置信度低于阈值
if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
if rag_attempts >= 2:
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}且RAG尝试{rag_attempts}次,走联网")
return "low_confidence"
else:
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}可再尝试RAG一次")
return "retry_rag"
# 情况3: 高置信度
info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
return "high_confidence"
# ========== 导出 ==========
__all__ = [
"rag_retrieve_node",
"rag_re_retrieve_node",
"check_rag_confidence",
"RAG_CONFIDENCE_THRESHOLD",
]

View File

@@ -7,9 +7,9 @@ from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...core.intent import react_reason_async, ReasoningResult
from backend.app.core.intent import react_reason_async, ReasoningResult
from ...main_graph.state import MainGraphState
from ...logger import info
from backend.app.logger import info
from ._utils import dispatch_custom_event, make_react_event

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change
from ...logger import debug
from backend.app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):

View File

@@ -10,9 +10,9 @@
from datetime import datetime
from ...core.intent import get_route_by_reasoning, ReasoningAction
from backend.app.core.intent import get_route_by_reasoning, ReasoningAction
from ...main_graph.state import MainGraphState
from ...logger import info
from backend.app.logger import info
# ========== 初始化状态节点 ==========
@@ -97,6 +97,12 @@ def route_by_reasoning(state: MainGraphState) -> str:
}
target = route_mapping.get(route, "llm_call")
# ========== RAG 次数硬限制 ==========
rag_attempts = getattr(state, 'rag_attempts', 0)
if target == "rag_retrieve" and rag_attempts >= 2:
info(f"[条件路由] RAG已尝试{rag_attempts}次,强制走联网搜索")
target = "web_search"
# ========== 循环防护检测 ==========
# 1. 路由模式检测A→B→A→B 交替)
if len(previous_actions) >= 4:

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change
from ...logger import debug, info, error, warning
from backend.app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):

View File

@@ -11,7 +11,7 @@ from ...main_graph.config import get_stream_writer
# 本地模块
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from ...logger import debug, info
from backend.app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...logger import info
from backend.app.logger import info
async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:

View File

@@ -75,6 +75,8 @@ class MainGraphState:
rag_context: str = ""
rag_retrieved: bool = False
rag_docs: List[Dict[str, Any]] = field(default_factory=list)
rag_confidence: float = 0.0 # RAG 检索置信度 (0.0-1.0)
rag_attempts: int = 0 # RAG 检索次数统计
# ========== 联网搜索相关字段 ==========
web_search_results: List[str] = field(default_factory=list)

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from ..logger import info
from backend.app.logger import info
def wrap_subgraph_for_error_handling(subgraph, name: str):

View File

@@ -2,7 +2,7 @@
from ...rag.tools import create_rag_tool
from ...rag.retriever import create_parent_hybrid_retriever
from ...model_services import get_embedding_service
from ...logger import info, warning
from backend.app.logger import info, warning
import sys
# 全局 RAG 工具

View File

@@ -9,7 +9,7 @@ from typing import Optional, List
from mem0 import AsyncMemory
from ..config import (
from backend.app.config import (
LLM_API_KEY,
ZHIPUAI_API_KEY,
VLLM_BASE_URL,
@@ -21,7 +21,7 @@ from ..config import (
ZHIPU_EMBEDDING_MODEL,
ZHIPU_API_BASE,
)
from ..logger import info, warning, error
from backend.app.logger import info, warning, error
from ..model_services import get_embedding_service
from ..model_services.chat_services import get_chat_service

View File

@@ -23,7 +23,7 @@ from .base import (
FallbackServiceChain,
SingletonServiceManager
)
from ..config import (
from backend.app.config import (
VLLM_BASE_URL,
LLM_API_KEY,
ZHIPUAI_API_KEY,
@@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
"""
def __init__(self, model: str = None):
from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
from backend.app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
super().__init__("local_small")
self._model = model or SMALL_LOCAL_MODEL_NAME
self._base_url = SMALL_VLLM_BASE_URL
@@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
"""
def __init__(self, model: str = None):
from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
from backend.app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
super().__init__("deepseek_small")
self._model = model or SMALL_DEEPSEEK_MODEL
self._api_key = SMALL_DEEPSEEK_API_KEY

View File

@@ -21,7 +21,7 @@ from .base import (
FallbackServiceChain,
SingletonServiceManager
)
from ..config import (
from backend.app.config import (
LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,

View File

@@ -27,7 +27,7 @@ from .base import (
FallbackServiceChain,
SingletonServiceManager
)
from ..config import (
from backend.app.config import (
LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,

View File

@@ -81,11 +81,17 @@ class RAGPipeline:
return await self.retriever.ainvoke(query)
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
parent_map = {}
# 收集 parent_id 和对应的分数
parent_map = {} # parent_id -> (embedding_score, rerank_score)
for doc in child_docs:
pid = doc.metadata.get("parent_id")
if pid and pid not in parent_map:
parent_map[pid] = doc.metadata.get("score", 0.0)
# embedding 分数
embedding_score = doc.metadata.get("score", 0.0)
# rerank 分数(如果有的话)
rerank_score = doc.metadata.get("rerank_score", 0.0)
parent_map[pid] = (embedding_score, rerank_score)
if not parent_map:
logger.warning("[Pipeline] 未找到 parent_id返回子文档")
@@ -94,10 +100,19 @@ class RAGPipeline:
try:
from backend.rag_core import create_docstore
docstore, _ = create_docstore()
# 同步获取(异步版本不存在)
parent_docs = docstore.mget(list(parent_map.keys()))
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
# 构建结果,保持分数信息
result = []
for doc in parent_docs:
if doc:
pid = doc.metadata.get("id")
scores = parent_map.get(pid, (0.0, 0.0))
# 将分数添加到 metadata 中
doc.metadata["embedding_score"] = scores[0]
doc.metadata["rerank_score"] = scores[1]
result.append((doc, scores[0] + scores[1] * 2)) # 综合分数rerank 权重更高
result.sort(key=lambda x: x[1], reverse=True)
docs = [d for d, _ in result]
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")

View File

@@ -49,44 +49,38 @@ class DocumentReranker:
top_n: 返回前 N 个结果
Returns:
List[Document]: 排序后的文档列表
List[Document]: 排序后的文档列表,每个文档的 metadata 中包含 rerank_score
"""
if not documents:
return []
try:
# 1. 从 Document 提取内容(业务逻辑)
# 1. 从 Document 提取内容
doc_contents = [doc.page_content for doc in documents]
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
total_chars = sum(len(c) for c in doc_contents)
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
# 粗略估算 tokens (中文约 0.75 tokens/字符)
estimated_tokens = int(total_chars * 0.75)
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排")
# 2. 调用服务计算得分
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
# 2. 调用重排服务计算得分
scores = self._rerank_service.compute_scores(query, doc_contents)
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分")
# 3. 根据得分排序(业务逻辑)
# 3. 构建 (文档, 分数) 对并排序
doc_score_pairs = list(zip(documents, scores))
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
logger.info(f"[Rerank] 排序后的结果:")
for i, (doc, score) in enumerate(doc_score_pairs_sorted):
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...")
# 4. 取 top_n
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]]
# 4. 取 top_n并添加 rerank_score 到 metadata
top_docs = []
for doc, score in doc_score_pairs_sorted[:top_n]:
# 创建新文档,添加 rerank_score
new_doc = Document(
page_content=doc.page_content,
metadata={**doc.metadata, "rerank_score": score}
)
top_docs.append(new_doc)
return top_docs
except Exception as e:
logger.warning(f"重排过程出错,返回原始前 {top_n}结果: {e}")
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
import traceback
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
logger.warning(f"[Rerank] 重排失败,返回原始结果: {e}")
return documents[:top_n]

View File

@@ -22,7 +22,7 @@ from pydantic import Field, PrivateAttr
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from backend.rag_core.client import create_async_qdrant_client
from ..model_services import get_embedding_service
from ..logger import info, warning, debug
from backend.app.logger import info, warning, debug
# 模块级常量

View File

@@ -8,7 +8,7 @@ from typing import Dict, Any
from datetime import datetime
# 公共工具
from ...core import MarkdownFormatter
from backend.app.core import MarkdownFormatter
from .state import ContactState
from .api_client import ContactAPIClient

View File

@@ -8,7 +8,7 @@ from datetime import datetime
import random
# 公共工具
from ...core import (
from backend.app.core import (
MarkdownFormatter
)

View File

@@ -7,7 +7,7 @@ from typing import Dict, Any
from datetime import datetime
# 公共工具
from ...core import MarkdownFormatter
from backend.app.core import MarkdownFormatter
from .state import (
NewsAnalysisState,

View File

@@ -3,8 +3,8 @@ LangGraph 节点日志工具模块
提供状态流转追踪和 LLM 输入输出打印功能
"""
from ..config import ENABLE_GRAPH_TRACE
from ..logger import debug, info
from backend.app.config import ENABLE_GRAPH_TRACE
from backend.app.logger import debug, info
from ..main_graph.state import MainGraphState

View File

@@ -69,7 +69,7 @@ def cleanup(signum, frame):
for i, proc in enumerate(processes):
if proc.poll() is None: # 进程还在运行
proc.terminate()
proc.wait(timeout=5)
proc.wait(timeout=1)
print(f"✓ 服务 {i+1} 已停止")
sys.exit(0)