添加rag置信度判断
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m31s

This commit is contained in:
2026-05-06 01:15:52 +08:00
parent 3ae9daa01a
commit 1260bef5cb
35 changed files with 335 additions and 221 deletions

View File

@@ -1481,7 +1481,7 @@ mkdir backend/app/subgraphs/my_subgraph
2. **创建状态定义 (state.py)** 2. **创建状态定义 (state.py)**
```python ```python
from typing_extensions import TypedDict from typing_extensions import TypedDict
from ..core.state_base import BaseSubgraphState from backend.app.core.state_base import BaseSubgraphState
class MySubgraphState(BaseSubgraphState): class MySubgraphState(BaseSubgraphState):
\"\"\" \"\"\"

View File

@@ -16,8 +16,8 @@ from ..main_graph.main_graph_builder import build_react_main_graph
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from ..main_graph.config import set_stream_writer from ..main_graph.config import set_stream_writer
from ..main_graph.utils.rag_initializer import init_rag_tool from ..main_graph.utils.rag_initializer import init_rag_tool
from ..core.intent_classifier import get_intent_classifier from backend.app.core.intent_classifier import get_intent_classifier
from ..logger import debug, info, warning, error from backend.app.logger import debug, info, warning, error
from ..main_graph.state import MainGraphState, CurrentAction from ..main_graph.state import MainGraphState, CurrentAction

View File

@@ -4,7 +4,7 @@
""" """
from typing import List, Dict, Any from typing import List, Dict, Any
from ..logger import error # 保持兼容,或者替换为 logger from backend.app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService: class ThreadHistoryService:
"""线程历史查询服务""" """线程历史查询服务"""

View File

@@ -3,9 +3,10 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def create_system_prompt(tools: list = None) -> ChatPromptTemplate: def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
""" """
创建系统提示模板,可选择动态注入工具描述 创建系统提示模板,整合多子系统能力、检索策略与回答规范
""" """
tools_section = "" # 构造工具描述
tools_section = "无可用工具"
if tools: if tools:
tool_descs = [] tool_descs = []
for tool in tools: for tool in tools:
@@ -14,25 +15,44 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
tool_descs.append(f"- {name}: {desc}") tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs) tools_section = "\n".join(tool_descs)
system_template = ( # 使用 f-string 将 tools_section 直接嵌入,而 memory_context 用双花括号转义保留为变量
"你是一个智能助手,具有三个专业子系统和RAG检索能力请使用中文交流。\n\n" system_template = f'''你是一个智能助手,具备以下专业子系统和检索能力请使用中文交流。
"【核心功能】\n"
"1. 📚 词典/翻译子系统 - 查询单词、翻译文本、提取术语、每日一词\n" ## 核心功能
"2. 📰 资讯分析子系统 - 查询新闻、分析URL、提取关键词、生成报告\n" 1. 📚 词典/翻译子系统 查询单词、翻译文本、提取术语、每日一词
"3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n" 2. 📰 资讯分析子系统 查询新闻、分析URL、提取关键词、生成报告
"4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n" 3. 📇 通讯录子系统 查询联系人、添加联系人、管理通讯录
"【用户背景信息】\n" 4. 🔍 RAG检索 从知识库中检索相关信息回答问题
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n"
"{memory_context}\n" ## 检索与信息获取策略
"【可用工具与使用规则】\n" 当收到用户问题时,请按以下优先级处理:
f"{tools_section}\n" 1. **RAG 检索第1次**:首先尝试从知识库中查找答案。
"工具调用时请直接返回所需参数,无需额外说明。\n\n" 2. **ReRAG第2次优化检索**:如果第一次检索结果不相关或不充分,可以优化查询后再次进行 RAG 检索。
"【回答要求(必须遵守)】\n" 3. **联网搜索**:如果两次 RAG 检索后仍无法获得满意答案,必须使用联网搜索获取最新信息。
"1. 回答必须简洁、直接。\n" **重要约束**
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n" - 最多进行 **2 次** RAG 检索尝试。
"3. 优先利用已知用户信息进行个性化回复。\n" - 第3次决定获取信息时必须选择**联网搜索**,禁止无休止的本地检索。
"4. 若无信息可依,礼貌询问或提供通用帮助。" - 如果已经明确知识库不包含该信息(例如用户询问实时新闻),可以直接进入联网搜索。
)
## 可用工具
{tools_section}
工具调用时请直接返回所需参数,无需额外说明。
## 用户背景信息
以下是当前用户的已知信息和长期记忆,你应在回答中优先利用这些信息进行个性化回复:
{{memory_context}}
若无相关信息,可礼貌询问或提供通用帮助。
## 回答要求(必须严格遵守)
1. **来源标注**:回答开头必须明确标注信息来源,格式如下:
- 使用知识库时:`【知识库:来源描述】`
- 使用联网搜索时:`【联网搜索:来源描述】`
- 若同时用到多个来源,按实际使用顺序标注,例如:`【知识库:三国演义】【联网搜索:百度百科】`
2. **思维链**:如果问题需要深入推理或复杂思考,请将推理过程用 `<think>...</think>` 标签包裹,放在回答最前面(来源标注之前)。
3. **简洁直接**:回答应重点突出、条理清晰,避免冗长。
4. **个性化**:结合用户信息进行针对性回复。
5. **无依据时**:若既无知识库支撑也无联网搜索结果,请如实说明无法回答,并建议用户提供更多信息或尝试其他方式。
'''
return ChatPromptTemplate.from_messages([ return ChatPromptTemplate.from_messages([
("system", system_template), ("system", system_template),

View File

@@ -9,7 +9,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="websocket
warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets") warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets")
import os import os
from ..config import DB_URI, BACKEND_PORT from backend.app.config import DB_URI, BACKEND_PORT
import uuid import uuid
import json import json
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@@ -22,18 +22,18 @@ from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from .agent.agent_service import AIAgentService, create_serde from .agent.agent_service import AIAgentService, create_serde
from .agent.history import ThreadHistoryService from .agent.history import ThreadHistoryService
from ..core.human_review import ( from backend.app.core.human_review import (
ReviewManager, ReviewManager,
InMemoryReviewStore, InMemoryReviewStore,
ReviewStatus, ReviewStatus,
HumanReview HumanReview
) )
from ..subgraphs.contact.api_client import ContactAPIClient from backend.app.subgraphs.contact.api_client import ContactAPIClient
from ..subgraphs.dictionary.api_client import DictionaryAPIClient from backend.app.subgraphs.dictionary.api_client import DictionaryAPIClient
from ..subgraphs.news_analysis.api_client import NewsAPIClient from backend.app.subgraphs.news_analysis.api_client import NewsAPIClient
from .db.init_db import init_subgraph_tables from .db.init_db import init_subgraph_tables
from .db.models import ContactRepository, DictionaryRepository, NewsRepository from .db.models import ContactRepository, DictionaryRepository, NewsRepository
from ..logger import info, error from backend.app.logger import info, error
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):

View File

@@ -21,15 +21,15 @@ from .nodes.fast_paths import (
fast_tool_node, fast_tool_node,
) )
from .nodes.llm_call import create_dynamic_llm_call_node from .nodes.llm_call import create_dynamic_llm_call_node
from .nodes.rag_nodes import rag_retrieve_node from .nodes.rag_nodes import rag_retrieve_node, check_rag_confidence
from .nodes.retrieve_memory import create_retrieve_memory_node from .nodes.retrieve_memory import create_retrieve_memory_node
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
from .nodes.summarize import create_summarize_node from .nodes.summarize import create_summarize_node
from .nodes.finalize import finalize_node from .nodes.finalize import finalize_node
from ..subgraphs.contact import build_contact_subgraph from backend.app.subgraphs.contact import build_contact_subgraph
from ..subgraphs.dictionary import build_dictionary_subgraph from backend.app.subgraphs.dictionary import build_dictionary_subgraph
from ..subgraphs.news_analysis import build_news_analysis_subgraph from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph
from ..logger import info from backend.app.logger import info
from .subgraph_wrapper import create_subgraph_nodes from .subgraph_wrapper import create_subgraph_nodes
@@ -198,8 +198,20 @@ def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) ->
} }
) )
# RAG 检索后的置信度判断分支
graph.add_conditional_edges(
"rag_retrieve",
check_rag_confidence,
{
"high_confidence": "llm_call", # 高置信度 → 直接生成回答
"retry_rag": "rag_retrieve", # 低置信度 → 再次检索
"low_confidence": "web_search", # 两次RAG后仍低 → 联网搜索
"no_rag": "web_search", # 无结果 → 联网搜索
}
)
# 循环边(回到 react_reason # 循环边(回到 react_reason
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names loop_back_nodes = ["web_search", "handle_error"] + subgraph_names
for node_name in loop_back_nodes: for node_name in loop_back_nodes:
graph.add_edge(node_name, "react_reason") graph.add_edge(node_name, "react_reason")

View File

@@ -8,7 +8,7 @@ from .web_search import web_search_node
from .error_handling import error_handling_node from .error_handling import error_handling_node
from .routing import init_state_node, route_by_reasoning, should_summarize from .routing import init_state_node, route_by_reasoning, should_summarize
from .llm_call import create_dynamic_llm_call_node from .llm_call import create_dynamic_llm_call_node
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node from .rag_nodes import rag_retrieve_node
# 记忆节点 # 记忆节点
from .retrieve_memory import create_retrieve_memory_node from .retrieve_memory import create_retrieve_memory_node

View File

@@ -3,7 +3,7 @@
""" """
from ...main_graph.state import MainGraphState, ErrorSeverity from ...main_graph.state import MainGraphState, ErrorSeverity
from ...logger import info from backend.app.logger import info
def error_handling_node(state: MainGraphState) -> MainGraphState: def error_handling_node(state: MainGraphState) -> MainGraphState:

View File

@@ -7,7 +7,7 @@ from typing import Optional
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState from ..state import MainGraphState
from ...logger import info, debug from backend.app.logger import info, debug
from ...model_services.chat_services import get_small_llm_service, get_chat_service from ...model_services.chat_services import get_small_llm_service, get_chat_service
from .rag_nodes import rag_retrieve_node from .rag_nodes import rag_retrieve_node
from ._utils import dispatch_custom_event from ._utils import dispatch_custom_event
@@ -113,10 +113,18 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig]
def _has_valid_rag_results(state: MainGraphState) -> bool: def _has_valid_rag_results(state: MainGraphState) -> bool:
"""检查 RAG 结果是否有效""" """检查 RAG 结果是否有效(基于置信度)"""
rag_docs = getattr(state, "rag_docs", []) from .rag_nodes import RAG_CONFIDENCE_THRESHOLD
rag_context = getattr(state, "rag_context", "") rag_context = getattr(state, "rag_context", "")
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10) rag_confidence = getattr(state, "rag_confidence", 0.0)
# 有结果且置信度足够
has_content = rag_context and len(rag_context) > 0
has_confidence = rag_confidence >= RAG_CONFIDENCE_THRESHOLD
info(f"[Fast RAG Check] has_content={has_content}, rag_confidence={rag_confidence:.2f}, threshold={RAG_CONFIDENCE_THRESHOLD}")
return has_content and has_confidence
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState: async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:

View File

@@ -8,7 +8,7 @@ from typing import Any, Dict
# 本地模块 # 本地模块
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change from ...utils.logging import log_state_change
from ...logger import info, warning from backend.app.logger import info, warning
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig

View File

@@ -11,7 +11,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState from ..state import MainGraphState
from ...logger import info, debug from backend.app.logger import info, debug
from ...model_services.chat_services import get_small_llm_service from ...model_services.chat_services import get_small_llm_service
from ._utils import dispatch_custom_event from ._utils import dispatch_custom_event

View File

@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...agent.prompts import create_system_prompt from ...agent.prompts import create_system_prompt
from ...utils.logging import log_state_change from ...utils.logging import log_state_change
from ...logger import debug, info, error from backend.app.logger import debug, info, error
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list): def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
@@ -115,24 +115,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
): ):
chunks.append(chunk) chunks.append(chunk)
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks") info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}")
for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多
chunk_type = type(chunk).__name__
chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk)
# 打印更多属性
additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {}
response_metadata = getattr(chunk, 'response_metadata', {}) or {}
# 打印所有属性
info(f"[llm_call] chunk[{i}] type={chunk_type}")
info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}")
info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}")
info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}")
info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}")
# 检查是否有其他可能存储内容的属性
if hasattr(chunk, 'tool_call_chunks'):
info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}")
if hasattr(chunk, 'usage_metadata'):
info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}")
# 将所有 chunk 合并成最终的 AIMessage # 将所有 chunk 合并成最终的 AIMessage
if chunks: if chunks:

View File

@@ -2,7 +2,7 @@ from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client from ...memory.mem0_client import Mem0Client
from ...logger import info from backend.app.logger import info
# 全局变量,在 GraphBuilder 中注入 # 全局变量,在 GraphBuilder 中注入

View File

@@ -1,6 +1,6 @@
""" """
RAG 检索节点模块 RAG 检索节点模块
使用模块级变量管理 RAG 工具 包含RAG 检索、置信度判断、重检索等节点
""" """
import time import time
@@ -11,10 +11,15 @@ from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from ...logger import info from backend.app.logger import info, debug
from ...model_services import get_small_llm_service
from ._utils import dispatch_custom_event, make_react_event from ._utils import dispatch_custom_event, make_react_event
# 置信度阈值配置
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
def _get_rag_tool() -> Optional[callable]: def _get_rag_tool() -> Optional[callable]:
"""获取 RAG 工具""" """获取 RAG 工具"""
from backend.app.main_graph.utils.rag_initializer import get_rag_tool from backend.app.main_graph.utils.rag_initializer import get_rag_tool
@@ -36,43 +41,27 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
# 调用 RAG 工具 # 调用 RAG 工具
rag_context = await rag_tool.ainvoke(retrieval_query) rag_context = await rag_tool.ainvoke(retrieval_query)
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}") info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
info(f"[RAG Core] ========================================")
# 更新状态
state.rag_context = rag_context state.rag_context = rag_context
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
state.rag_retrieved = True state.rag_retrieved = True
state.success = True state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
state.debug_info["rag_source"] = "tool" state.debug_info["rag_source"] = "tool"
info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
return state return state
# ========== RAG 检索节点 ========== # ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""RAG 检索节点:带超时和重试""" """RAG 检索节点:检索 + 置信度评估"""
state.current_phase = "rag_retrieving" state.current_phase = "rag_retrieving"
start_time = time.time() start_time = time.time()
last_error = None
# 获取 RAG 工具
rag_tool = _get_rag_tool() rag_tool = _get_rag_tool()
if not rag_tool: if not rag_tool:
error_record = ErrorRecord( info("[RAG] RAG 工具未初始化")
error_type="RAGRetrievalError", state.rag_confidence = 0.0
error_message="RAG 工具未初始化", state.rag_retrieved = False
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=0,
max_retries=RAG_RETRY_CONFIG.max_retries,
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
return state return state
await dispatch_custom_event( await dispatch_custom_event(
@@ -81,99 +70,184 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
config config
) )
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1): try:
try: state = await _rag_retrieve_core(state, rag_tool)
result = await _rag_retrieve_core(state, rag_tool)
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符") # 评估置信度
confidence = await _evaluate_rag_confidence(state)
state.rag_confidence = confidence
state.debug_info["rag_retrieval"] = { info(f"[RAG] 检索完成,置信度={confidence:.2f}RAG尝试次数={state.rag_attempts}")
"attempt": attempt + 1,
"success": True,
"time": time.time() - start_time
}
state.reasoning_history.append({ state.reasoning_history.append({
"step": state.reasoning_step, "step": state.reasoning_step,
"action": "RETRIEVE_RAG", "action": "RETRIEVE_RAG",
"confidence": 1.0, "confidence": confidence,
"reasoning": "RAG 检索完成", "reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
"timestamp": datetime.now().isoformat() "timestamp": datetime.now().isoformat()
}) })
doc_count = len(result.rag_docs) if result.rag_docs else 0 await dispatch_custom_event(
await dispatch_custom_event( "react_reasoning",
"react_reasoning", make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0, f"RAG 检索完成,置信度={confidence:.2f}"),
f"RAG 检索完成,找到 {doc_count} 条相关文档"), config
config )
)
return result except Exception as e:
info(f"[RAG] 检索失败: {e}")
except Exception as e: state.rag_confidence = 0.0
last_error = e state.rag_retrieved = False
if attempt >= RAG_RETRY_CONFIG.max_retries:
break
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
config
)
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 失败记录
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "RETRIEVE_RAG",
"confidence": 0.0,
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
"timestamp": datetime.now().isoformat()
})
error_record = ErrorRecord(
error_type="RAGRetrievalError",
error_message=str(last_error) if last_error else "RAG 检索超时",
severity=ErrorSeverity.WARNING,
source="rag_retrieve_node",
timestamp=datetime.now().isoformat(),
retry_count=RAG_RETRY_CONFIG.max_retries,
max_retries=RAG_RETRY_CONFIG.max_retries,
)
state.errors.append(error_record)
state.current_error = error_record
state.current_phase = "error_handling"
await dispatch_custom_event(
"react_reasoning",
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
f"RAG 检索失败: {str(last_error)}"),
config
)
return state return state
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: async def _evaluate_rag_confidence(state: MainGraphState) -> float:
"""重新检索节点""" """评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)"""
state.current_phase = "rag_re_retrieving" query = state.user_query or ""
rag_context = state.rag_context or ""
state.debug_info["rag_re_retrieve"] = { if not rag_context:
"original_retrieved": state.rag_retrieved, return 0.0
"original_docs_count": len(state.rag_docs)
}
return await rag_retrieve_node(state, config) # 方式1: 向量相似度(从 rag_docs 中获取)
embedding_score = _get_embedding_similarity(state, query)
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
# 方式2: 重排序分数(从 rag_docs 中获取)
rerank_score = _get_rerank_score(state)
info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")
# 方式3: 小模型判断
llm_score = await _get_llm_score(state)
info(f"[RAG Confidence] LLM评估={llm_score:.3f}")
# 综合得分(加权平均)
# 向量相似度权重 0.3,重排权重 0.3LLM 权重 0.4
final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")
return final_score
def _get_embedding_similarity(state: MainGraphState, query: str = "") -> float:
    """Return the best vector-similarity score found in ``state.rag_docs``.

    Every entry of ``state.rag_docs`` is inspected for a ``score`` value:
    dict entries are read with ``doc.get("score")``, objects exposing a
    ``metadata`` mapping are read with ``doc.metadata.get("score")``, and
    any other entry is ignored.

    Args:
        state: Graph state; only its ``rag_docs`` attribute is read.
        query: Unused. Accepted because ``_evaluate_rag_confidence`` calls this
            helper as ``_get_embedding_similarity(state, query)``; without the
            parameter that call raises ``TypeError``.

    Returns:
        The highest normalised score in ``[0.0, 1.0]``, or ``0.0`` when
        ``rag_docs`` is empty or carries no usable score.
    """
    best = 0.0
    for doc in getattr(state, "rag_docs", None) or []:
        if isinstance(doc, dict):
            raw = doc.get("score", 0.0)
        elif hasattr(doc, "metadata"):
            # ``metadata`` may be None on hand-built documents.
            raw = (doc.metadata or {}).get("score", 0.0)
        else:
            continue

        try:
            value = float(raw)
        except (TypeError, ValueError):
            # None / non-numeric score: treat as "no evidence" instead of crashing.
            continue

        if not value > 0.0:
            # Zero, negative and NaN scores cannot raise the maximum.
            continue
        if value > 1.0:
            # Similarities are normally within 0-1; larger values (e.g. fused
            # ranking scores) are scaled down assuming a maximum of about 10.
            value = value / 10.0
        best = max(best, min(value, 1.0))
    return best
def _get_rerank_score(state: MainGraphState) -> float:
    """Return the best rerank score found in ``state.rag_docs``.

    Dict entries are read with ``doc.get("rerank_score")``; objects exposing a
    ``metadata`` mapping are read with ``doc.metadata.get("rerank_score")``;
    any other entry contributes nothing.

    Args:
        state: Graph state; only its ``rag_docs`` attribute is read.

    Returns:
        The highest positive rerank score, capped at ``1.0`` so the value stays
        inside the 0-1 range used for ``rag_confidence``. ``0.0`` is returned
        when no document carries a positive, numeric rerank score.
    """
    best = 0.0
    for doc in getattr(state, "rag_docs", None) or []:
        if isinstance(doc, dict):
            raw = doc.get("rerank_score", 0.0)
        elif hasattr(doc, "metadata"):
            # ``metadata`` may be None on hand-built documents.
            raw = (doc.metadata or {}).get("rerank_score", 0.0)
        else:
            continue

        try:
            value = float(raw)
        except (TypeError, ValueError):
            # None / non-numeric score: skip instead of failing the comparison.
            continue

        if value > 0:
            best = max(best, min(value, 1.0))
    return best
async def _get_llm_score(state: MainGraphState) -> float:
    """Ask the small LLM how relevant the retrieved context is to the query.

    The prompt embeds ``state.user_query`` and the first 1500 characters of
    ``state.rag_context`` and asks for a single number between 0.0 and 1.0.
    The first number found in the reply is clamped to ``[0.0, 1.0]``.

    Args:
        state: Graph state; ``user_query`` and ``rag_context`` are read.

    Returns:
        The clamped score, or ``0.5`` when the reply contains no number or the
        model call raises.
    """
    question = state.user_query or ""
    context = state.rag_context or ""

    try:
        import re

        snippet = context[:1500]
        # The continuation lines sit at column 0 on purpose: they are part of
        # the prompt text sent to the model.
        prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
- 1.0 = 完全相关,能直接回答问题
- 0.5 = 部分相关,有一定参考价值
- 0.0 = 完全不相关,无法回答问题
用户问题:{question}
检索结果:{snippet}
只返回一个数字:"""
        scorer = get_small_llm_service()
        reply = await scorer.ainvoke(prompt)
        reply_text = reply.content.strip()

        found = re.search(r'(\d+\.?\d*)', reply_text)
        if found:
            parsed = float(found.group(1))
            # Clamp to the 0-1 range expected by the caller.
            return min(max(parsed, 0.0), 1.0)
    except Exception as exc:
        # Best effort: a failed evaluation falls through to the neutral score.
        info(f"[RAG Confidence] LLM评估失败: {exc}")

    # Neutral default when no score could be obtained.
    return 0.5
# ========== 置信度判断节点 ==========
def check_rag_confidence(state: MainGraphState) -> str:
    """Pick the branch to follow after a RAG retrieval.

    Args:
        state: Graph state; ``rag_attempts``, ``rag_confidence``,
            ``rag_retrieved`` and ``rag_context`` are read.

    Returns:
        One of the routing keys:

        * ``"no_rag"`` - nothing was retrieved (``rag_retrieved`` is falsy or
          ``rag_context`` is empty).
        * ``"retry_rag"`` - confidence is below ``RAG_CONFIDENCE_THRESHOLD``
          and fewer than two retrieval attempts have been made.
        * ``"low_confidence"`` - confidence is below the threshold after two
          or more attempts.
        * ``"high_confidence"`` - confidence reaches the threshold.
    """
    attempts = getattr(state, 'rag_attempts', 0)
    confidence = getattr(state, 'rag_confidence', 0.0)
    info(f"[Confidence Check] rag_attempts={attempts}, rag_confidence={confidence:.2f}")

    # Case 1: retrieval produced nothing usable.
    has_result = getattr(state, 'rag_retrieved', False) and state.rag_context
    if not has_result:
        info("[Confidence Check] 无检索结果,走联网")
        return "no_rag"

    # Case 2: below the threshold - retry once, then give up on local retrieval.
    if confidence < RAG_CONFIDENCE_THRESHOLD:
        attempts_exhausted = attempts >= 2
        if not attempts_exhausted:
            info(f"[Confidence Check] 置信度={confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}可再尝试RAG一次")
            return "retry_rag"
        info(f"[Confidence Check] 置信度={confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD}且RAG尝试{attempts}次,走联网")
        return "low_confidence"

    # Case 3: confident enough to answer directly.
    info(f"[Confidence Check] 高置信度={confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
    return "high_confidence"
# ========== 导出 ========== # ========== 导出 ==========
__all__ = [ __all__ = [
"rag_retrieve_node", "rag_retrieve_node",
"rag_re_retrieve_node", "check_rag_confidence",
"RAG_CONFIDENCE_THRESHOLD",
] ]

View File

@@ -7,9 +7,9 @@ from typing import Optional
from datetime import datetime from datetime import datetime
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from ...core.intent import react_reason_async, ReasoningResult from backend.app.core.intent import react_reason_async, ReasoningResult
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...logger import info from backend.app.logger import info
from ._utils import dispatch_custom_event, make_react_event from ._utils import dispatch_custom_event, make_react_event

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change from ...utils.logging import log_state_change
from ...logger import debug from backend.app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client): def create_retrieve_memory_node(mem0_client: Mem0Client):

View File

@@ -10,9 +10,9 @@
from datetime import datetime from datetime import datetime
from ...core.intent import get_route_by_reasoning, ReasoningAction from backend.app.core.intent import get_route_by_reasoning, ReasoningAction
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...logger import info from backend.app.logger import info
# ========== 初始化状态节点 ========== # ========== 初始化状态节点 ==========
@@ -97,6 +97,12 @@ def route_by_reasoning(state: MainGraphState) -> str:
} }
target = route_mapping.get(route, "llm_call") target = route_mapping.get(route, "llm_call")
# ========== RAG 次数硬限制 ==========
rag_attempts = getattr(state, 'rag_attempts', 0)
if target == "rag_retrieve" and rag_attempts >= 2:
info(f"[条件路由] RAG已尝试{rag_attempts}次,强制走联网搜索")
target = "web_search"
# ========== 循环防护检测 ========== # ========== 循环防护检测 ==========
# 1. 路由模式检测A→B→A→B 交替) # 1. 路由模式检测A→B→A→B 交替)
if len(previous_actions) >= 4: if len(previous_actions) >= 4:

View File

@@ -9,7 +9,7 @@ from typing import Any, Dict
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change from ...utils.logging import log_state_change
from ...logger import debug, info, error, warning from backend.app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client): def create_summarize_node(mem0_client: Mem0Client):

View File

@@ -11,7 +11,7 @@ from ...main_graph.config import get_stream_writer
# 本地模块 # 本地模块
from ...main_graph.state import MainGraphState from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change from ...utils.logging import log_state_change
from ...logger import debug, info from backend.app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]): def create_tool_call_node(tools_by_name: Dict[str, Any]):
""" """

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...logger import info from backend.app.logger import info
async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:

View File

@@ -75,6 +75,8 @@ class MainGraphState:
rag_context: str = "" rag_context: str = ""
rag_retrieved: bool = False rag_retrieved: bool = False
rag_docs: List[Dict[str, Any]] = field(default_factory=list) rag_docs: List[Dict[str, Any]] = field(default_factory=list)
rag_confidence: float = 0.0 # RAG 检索置信度 (0.0-1.0)
rag_attempts: int = 0 # RAG 检索次数统计
# ========== 联网搜索相关字段 ========== # ========== 联网搜索相关字段 ==========
web_search_results: List[str] = field(default_factory=list) web_search_results: List[str] = field(default_factory=list)

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from langchain_core.runnables.config import RunnableConfig from langchain_core.runnables.config import RunnableConfig
from .state import MainGraphState, ErrorRecord, ErrorSeverity from .state import MainGraphState, ErrorRecord, ErrorSeverity
from ..logger import info from backend.app.logger import info
def wrap_subgraph_for_error_handling(subgraph, name: str): def wrap_subgraph_for_error_handling(subgraph, name: str):

View File

@@ -2,7 +2,7 @@
from ...rag.tools import create_rag_tool from ...rag.tools import create_rag_tool
from ...rag.retriever import create_parent_hybrid_retriever from ...rag.retriever import create_parent_hybrid_retriever
from ...model_services import get_embedding_service from ...model_services import get_embedding_service
from ...logger import info, warning from backend.app.logger import info, warning
import sys import sys
# 全局 RAG 工具 # 全局 RAG 工具

View File

@@ -9,7 +9,7 @@ from typing import Optional, List
from mem0 import AsyncMemory from mem0 import AsyncMemory
from ..config import ( from backend.app.config import (
LLM_API_KEY, LLM_API_KEY,
ZHIPUAI_API_KEY, ZHIPUAI_API_KEY,
VLLM_BASE_URL, VLLM_BASE_URL,
@@ -21,7 +21,7 @@ from ..config import (
ZHIPU_EMBEDDING_MODEL, ZHIPU_EMBEDDING_MODEL,
ZHIPU_API_BASE, ZHIPU_API_BASE,
) )
from ..logger import info, warning, error from backend.app.logger import info, warning, error
from ..model_services import get_embedding_service from ..model_services import get_embedding_service
from ..model_services.chat_services import get_chat_service from ..model_services.chat_services import get_chat_service

View File

@@ -23,7 +23,7 @@ from .base import (
FallbackServiceChain, FallbackServiceChain,
SingletonServiceManager SingletonServiceManager
) )
from ..config import ( from backend.app.config import (
VLLM_BASE_URL, VLLM_BASE_URL,
LLM_API_KEY, LLM_API_KEY,
ZHIPUAI_API_KEY, ZHIPUAI_API_KEY,
@@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
""" """
def __init__(self, model: str = None): def __init__(self, model: str = None):
from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY from backend.app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
super().__init__("local_small") super().__init__("local_small")
self._model = model or SMALL_LOCAL_MODEL_NAME self._model = model or SMALL_LOCAL_MODEL_NAME
self._base_url = SMALL_VLLM_BASE_URL self._base_url = SMALL_VLLM_BASE_URL
@@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
""" """
def __init__(self, model: str = None): def __init__(self, model: str = None):
from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE from backend.app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
super().__init__("deepseek_small") super().__init__("deepseek_small")
self._model = model or SMALL_DEEPSEEK_MODEL self._model = model or SMALL_DEEPSEEK_MODEL
self._api_key = SMALL_DEEPSEEK_API_KEY self._api_key = SMALL_DEEPSEEK_API_KEY

View File

@@ -21,7 +21,7 @@ from .base import (
FallbackServiceChain, FallbackServiceChain,
SingletonServiceManager SingletonServiceManager
) )
from ..config import ( from backend.app.config import (
LLAMACPP_EMBEDDING_URL, LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY, LLAMACPP_API_KEY,
ZHIPUAI_API_KEY, ZHIPUAI_API_KEY,

View File

@@ -27,7 +27,7 @@ from .base import (
FallbackServiceChain, FallbackServiceChain,
SingletonServiceManager SingletonServiceManager
) )
from ..config import ( from backend.app.config import (
LLAMACPP_RERANKER_URL, LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY, LLAMACPP_API_KEY,
ZHIPUAI_API_KEY, ZHIPUAI_API_KEY,

View File

@@ -81,11 +81,17 @@ class RAGPipeline:
return await self.retriever.ainvoke(query) return await self.retriever.ainvoke(query)
async def _get_parents(self, child_docs: List[Document]) -> List[Document]: async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
parent_map = {} # 收集 parent_id 和对应的分数
parent_map = {} # parent_id -> (embedding_score, rerank_score)
for doc in child_docs: for doc in child_docs:
pid = doc.metadata.get("parent_id") pid = doc.metadata.get("parent_id")
if pid and pid not in parent_map: if pid and pid not in parent_map:
parent_map[pid] = doc.metadata.get("score", 0.0) # embedding 分数
embedding_score = doc.metadata.get("score", 0.0)
# rerank 分数(如果有的话)
rerank_score = doc.metadata.get("rerank_score", 0.0)
parent_map[pid] = (embedding_score, rerank_score)
if not parent_map: if not parent_map:
logger.warning("[Pipeline] 未找到 parent_id返回子文档") logger.warning("[Pipeline] 未找到 parent_id返回子文档")
@@ -94,10 +100,19 @@ class RAGPipeline:
try: try:
from backend.rag_core import create_docstore from backend.rag_core import create_docstore
docstore, _ = create_docstore() docstore, _ = create_docstore()
# 同步获取(异步版本不存在)
parent_docs = docstore.mget(list(parent_map.keys())) parent_docs = docstore.mget(list(parent_map.keys()))
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2] # 构建结果,保持分数信息
result = []
for doc in parent_docs:
if doc:
pid = doc.metadata.get("id")
scores = parent_map.get(pid, (0.0, 0.0))
# 将分数添加到 metadata 中
doc.metadata["embedding_score"] = scores[0]
doc.metadata["rerank_score"] = scores[1]
result.append((doc, scores[0] + scores[1] * 2)) # 综合分数rerank 权重更高
result.sort(key=lambda x: x[1], reverse=True) result.sort(key=lambda x: x[1], reverse=True)
docs = [d for d, _ in result] docs = [d for d, _ in result]
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档") logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")

View File

@@ -49,44 +49,38 @@ class DocumentReranker:
top_n: 返回前 N 个结果 top_n: 返回前 N 个结果
Returns: Returns:
List[Document]: 排序后的文档列表 List[Document]: 排序后的文档列表,每个文档的 metadata 中包含 rerank_score
""" """
if not documents: if not documents:
return [] return []
try: try:
# 1. 从 Document 提取内容(业务逻辑) # 1. 从 Document 提取内容
doc_contents = [doc.page_content for doc in documents] doc_contents = [doc.page_content for doc in documents]
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}") logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排")
total_chars = sum(len(c) for c in doc_contents)
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
# 粗略估算 tokens (中文约 0.75 tokens/字符)
estimated_tokens = int(total_chars * 0.75)
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
# 2. 调用服务计算得分 # 2. 调用重排服务计算得分
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
scores = self._rerank_service.compute_scores(query, doc_contents) scores = self._rerank_service.compute_scores(query, doc_contents)
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}") logger.info(f"[Rerank] 获取到 {len(scores)} 个得分")
# 3. 根据得分排序(业务逻辑) # 3. 构建 (文档, 分数) 对并排序
doc_score_pairs = list(zip(documents, scores)) doc_score_pairs = list(zip(documents, scores))
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True) doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
logger.info(f"[Rerank] 排序后的结果:") # 4. 取 top_n并添加 rerank_score 到 metadata
for i, (doc, score) in enumerate(doc_score_pairs_sorted): top_docs = []
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...") for doc, score in doc_score_pairs_sorted[:top_n]:
# 创建新文档,添加 rerank_score
# 4. 取 top_n new_doc = Document(
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]] page_content=doc.page_content,
metadata={**doc.metadata, "rerank_score": score}
)
top_docs.append(new_doc)
return top_docs return top_docs
except Exception as e: except Exception as e:
logger.warning(f"重排过程出错,返回原始前 {top_n}结果: {e}") logger.warning(f"[Rerank] 重排失败,返回原始结果: {e}")
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
import traceback
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
return documents[:top_n] return documents[:top_n]

View File

@@ -22,7 +22,7 @@ from pydantic import Field, PrivateAttr
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from backend.rag_core.client import create_async_qdrant_client from backend.rag_core.client import create_async_qdrant_client
from ..model_services import get_embedding_service from ..model_services import get_embedding_service
from ..logger import info, warning, debug from backend.app.logger import info, warning, debug
# 模块级常量 # 模块级常量

View File

@@ -8,7 +8,7 @@ from typing import Dict, Any
from datetime import datetime from datetime import datetime
# 公共工具 # 公共工具
from ...core import MarkdownFormatter from backend.app.core import MarkdownFormatter
from .state import ContactState from .state import ContactState
from .api_client import ContactAPIClient from .api_client import ContactAPIClient

View File

@@ -8,7 +8,7 @@ from datetime import datetime
import random import random
# 公共工具 # 公共工具
from ...core import ( from backend.app.core import (
MarkdownFormatter MarkdownFormatter
) )

View File

@@ -7,7 +7,7 @@ from typing import Dict, Any
from datetime import datetime from datetime import datetime
# 公共工具 # 公共工具
from ...core import MarkdownFormatter from backend.app.core import MarkdownFormatter
from .state import ( from .state import (
NewsAnalysisState, NewsAnalysisState,

View File

@@ -3,8 +3,8 @@ LangGraph 节点日志工具模块
提供状态流转追踪和 LLM 输入输出打印功能 提供状态流转追踪和 LLM 输入输出打印功能
""" """
from ..config import ENABLE_GRAPH_TRACE from backend.app.config import ENABLE_GRAPH_TRACE
from ..logger import debug, info from backend.app.logger import debug, info
from ..main_graph.state import MainGraphState from ..main_graph.state import MainGraphState

View File

@@ -69,7 +69,7 @@ def cleanup(signum, frame):
for i, proc in enumerate(processes): for i, proc in enumerate(processes):
if proc.poll() is None: # 进程还在运行 if proc.poll() is None: # 进程还在运行
proc.terminate() proc.terminate()
proc.wait(timeout=5) proc.wait(timeout=1)
print(f"✓ 服务 {i+1} 已停止") print(f"✓ 服务 {i+1} 已停止")
sys.exit(0) sys.exit(0)