diff --git a/README.md b/README.md index 00d2818..488a354 100644 --- a/README.md +++ b/README.md @@ -1481,7 +1481,7 @@ mkdir backend/app/subgraphs/my_subgraph 2. **创建状态定义 (state.py)** ```python from typing_extensions import TypedDict -from ..core.state_base import BaseSubgraphState +from backend.app.core.state_base import BaseSubgraphState class MySubgraphState(BaseSubgraphState): \"\"\" diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 16ff75a..ad02087 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -16,8 +16,8 @@ from ..main_graph.main_graph_builder import build_react_main_graph from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME from ..main_graph.config import set_stream_writer from ..main_graph.utils.rag_initializer import init_rag_tool -from ..core.intent_classifier import get_intent_classifier -from ..logger import debug, info, warning, error +from backend.app.core.intent_classifier import get_intent_classifier +from backend.app.logger import debug, info, warning, error from ..main_graph.state import MainGraphState, CurrentAction diff --git a/backend/app/agent/history.py b/backend/app/agent/history.py index a87f772..cac1180 100644 --- a/backend/app/agent/history.py +++ b/backend/app/agent/history.py @@ -4,7 +4,7 @@ """ from typing import List, Dict, Any -from ..logger import error # 保持兼容,或者替换为 logger +from backend.app.logger import error # 保持兼容,或者替换为 logger class ThreadHistoryService: """线程历史查询服务""" diff --git a/backend/app/agent/prompts.py b/backend/app/agent/prompts.py index 8fc0e98..42024c5 100644 --- a/backend/app/agent/prompts.py +++ b/backend/app/agent/prompts.py @@ -3,9 +3,10 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder def create_system_prompt(tools: list = None) -> ChatPromptTemplate: """ - 创建系统提示模板,可选择动态注入工具描述。 + 创建系统提示模板,整合多子系统能力、检索策略与回答规范。 """ - tools_section = "" + # 构造工具描述 + tools_section = "无可用工具" if tools: tool_descs = [] for tool in tools: @@ -14,27 +15,46 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate: tool_descs.append(f"- {name}: {desc}") tools_section = "\n".join(tool_descs) - system_template = ( - "你是一个智能助手,具有三个专业子系统和RAG检索能力,请使用中文交流。\n\n" - "【核心功能】\n" - "1. 📚 词典/翻译子系统 - 查询单词、翻译文本、提取术语、每日一词\n" - "2. 📰 资讯分析子系统 - 查询新闻、分析URL、提取关键词、生成报告\n" - "3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n" - "4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n" - "【用户背景信息】\n" - "以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n" - "{memory_context}\n" - "【可用工具与使用规则】\n" - f"{tools_section}\n" - "工具调用时请直接返回所需参数,无需额外说明。\n\n" - "【回答要求(必须遵守)】\n" - "1. 回答必须简洁、直接。\n" - "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。\n" - "3. 优先利用已知用户信息进行个性化回复。\n" - "4. 若无信息可依,礼貌询问或提供通用帮助。" - ) - + # 使用 f-string 将 tools_section 直接嵌入,而 memory_context 用双花括号转义保留为变量 + system_template = f'''你是一个智能助手,具备以下专业子系统和检索能力。请使用中文交流。 + +## 核心功能 +1. 📚 词典/翻译子系统 – 查询单词、翻译文本、提取术语、每日一词 +2. 📰 资讯分析子系统 – 查询新闻、分析URL、提取关键词、生成报告 +3. 📇 通讯录子系统 – 查询联系人、添加联系人、管理通讯录 +4. 🔍 RAG检索 – 从知识库中检索相关信息回答问题 + +## 检索与信息获取策略 +当收到用户问题时,请按以下优先级处理: +1. **RAG 检索(第1次)**:首先尝试从知识库中查找答案。 +2. **ReRAG(第2次优化检索)**:如果第一次检索结果不相关或不充分,可以优化查询后再次进行 RAG 检索。 +3. **联网搜索**:如果两次 RAG 检索后仍无法获得满意答案,必须使用联网搜索获取最新信息。 +**重要约束**: +- 最多进行 **2 次** RAG 检索尝试。 +- 第3次决定获取信息时,必须选择**联网搜索**,禁止无休止的本地检索。 +- 如果已经明确知识库不包含该信息(例如用户询问实时新闻),可以直接进入联网搜索。 + +## 可用工具 +{tools_section} +工具调用时请直接返回所需参数,无需额外说明。 + +## 用户背景信息 +以下是当前用户的已知信息和长期记忆,你应在回答中优先利用这些信息进行个性化回复: +{{memory_context}} +若无相关信息,可礼貌询问或提供通用帮助。 + +## 回答要求(必须严格遵守) +1. **来源标注**:回答开头必须明确标注信息来源,格式如下: + - 使用知识库时:`【知识库:来源描述】` + - 使用联网搜索时:`【联网搜索:来源描述】` + - 若同时用到多个来源,按实际使用顺序标注,例如:`【知识库:三国演义】【联网搜索:百度百科】` +2. **思维链**:如果问题需要深入推理或复杂思考,请将推理过程用 `...` 标签包裹,放在回答最前面(来源标注之前)。 +3. **简洁直接**:回答应重点突出、条理清晰,避免冗长。 +4. **个性化**:结合用户信息进行针对性回复。 +5. **无依据时**:若既无知识库支撑也无联网搜索结果,请如实说明无法回答,并建议用户提供更多信息或尝试其他方式。 +''' + return ChatPromptTemplate.from_messages([ ("system", system_template), MessagesPlaceholder(variable_name="messages") - ]) + ]) \ No newline at end of file diff --git a/backend/app/backend.py b/backend/app/backend.py index c5e23e6..69a48ba 100644 --- a/backend/app/backend.py +++ b/backend/app/backend.py @@ -9,7 +9,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="websocket warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets") import os -from ..config import DB_URI, BACKEND_PORT +from backend.app.config import DB_URI, BACKEND_PORT import uuid import json from contextlib import asynccontextmanager @@ -22,18 +22,18 @@ from pydantic import BaseModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from .agent.agent_service import AIAgentService, create_serde from .agent.history import ThreadHistoryService -from ..core.human_review import ( +from backend.app.core.human_review import ( ReviewManager, InMemoryReviewStore, ReviewStatus, HumanReview ) -from ..subgraphs.contact.api_client import ContactAPIClient -from ..subgraphs.dictionary.api_client import DictionaryAPIClient -from ..subgraphs.news_analysis.api_client import NewsAPIClient +from backend.app.subgraphs.contact.api_client import ContactAPIClient +from backend.app.subgraphs.dictionary.api_client import DictionaryAPIClient +from backend.app.subgraphs.news_analysis.api_client import NewsAPIClient from .db.init_db import init_subgraph_tables from .db.models import ContactRepository, DictionaryRepository, NewsRepository -from ..logger import info, error +from backend.app.logger import info, error @asynccontextmanager async def lifespan(app: FastAPI): diff --git a/backend/app/main_graph/main_graph_builder.py b/backend/app/main_graph/main_graph_builder.py index bde6d0f..7937908 100644 --- a/backend/app/main_graph/main_graph_builder.py +++ b/backend/app/main_graph/main_graph_builder.py @@ -21,15 +21,15 @@ from .nodes.fast_paths import ( fast_tool_node, ) from .nodes.llm_call import create_dynamic_llm_call_node -from .nodes.rag_nodes import rag_retrieve_node +from .nodes.rag_nodes import rag_retrieve_node, check_rag_confidence from .nodes.retrieve_memory import create_retrieve_memory_node from .nodes.memory_trigger import memory_trigger_node, set_mem0_client from .nodes.summarize import create_summarize_node from .nodes.finalize import finalize_node -from ..subgraphs.contact import build_contact_subgraph -from ..subgraphs.dictionary import build_dictionary_subgraph -from ..subgraphs.news_analysis import build_news_analysis_subgraph -from ..logger import info +from backend.app.subgraphs.contact import build_contact_subgraph +from backend.app.subgraphs.dictionary import build_dictionary_subgraph +from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph +from backend.app.logger import info from .subgraph_wrapper import create_subgraph_nodes @@ -198,8 +198,20 @@ def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) -> } ) + # RAG 检索后的置信度判断分支 + graph.add_conditional_edges( + "rag_retrieve", + check_rag_confidence, + { + "high_confidence": "llm_call", # 高置信度 → 直接生成回答 + "retry_rag": "rag_retrieve", # 低置信度 → 再次检索 + "low_confidence": "web_search", # 两次RAG后仍低 → 联网搜索 + "no_rag": "web_search", # 无结果 → 联网搜索 + } + ) + # 循环边(回到 react_reason) - loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names + loop_back_nodes = ["web_search", "handle_error"] + subgraph_names for node_name in loop_back_nodes: graph.add_edge(node_name, "react_reason") diff --git a/backend/app/main_graph/nodes/__init__.py b/backend/app/main_graph/nodes/__init__.py index bd00186..0ff79ae 100644 --- a/backend/app/main_graph/nodes/__init__.py +++ b/backend/app/main_graph/nodes/__init__.py @@ -8,7 +8,7 @@ from .web_search import web_search_node from .error_handling import error_handling_node from .routing import init_state_node, route_by_reasoning, should_summarize from .llm_call import create_dynamic_llm_call_node -from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node +from .rag_nodes import rag_retrieve_node # 记忆节点 from .retrieve_memory import create_retrieve_memory_node diff --git a/backend/app/main_graph/nodes/error_handling.py b/backend/app/main_graph/nodes/error_handling.py index 646afb0..e077fbd 100644 --- a/backend/app/main_graph/nodes/error_handling.py +++ b/backend/app/main_graph/nodes/error_handling.py @@ -3,7 +3,7 @@ """ from ...main_graph.state import MainGraphState, ErrorSeverity -from ...logger import info +from backend.app.logger import info def error_handling_node(state: MainGraphState) -> MainGraphState: diff --git a/backend/app/main_graph/nodes/fast_paths.py b/backend/app/main_graph/nodes/fast_paths.py index 52357f5..6dc7fec 100644 --- a/backend/app/main_graph/nodes/fast_paths.py +++ b/backend/app/main_graph/nodes/fast_paths.py @@ -7,7 +7,7 @@ from typing import Optional from langchain_core.runnables.config import RunnableConfig from ..state import MainGraphState -from ...logger import info, debug +from backend.app.logger import info, debug from ...model_services.chat_services import get_small_llm_service, get_chat_service from .rag_nodes import rag_retrieve_node from ._utils import dispatch_custom_event @@ -113,10 +113,18 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig] def _has_valid_rag_results(state: MainGraphState) -> bool: - """检查 RAG 结果是否有效""" - rag_docs = getattr(state, "rag_docs", []) + """检查 RAG 结果是否有效(基于置信度)""" + from .rag_nodes import RAG_CONFIDENCE_THRESHOLD rag_context = getattr(state, "rag_context", "") - return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10) + rag_confidence = getattr(state, "rag_confidence", 0.0) + + # 有结果且置信度足够 + has_content = rag_context and len(rag_context) > 0 + has_confidence = rag_confidence >= RAG_CONFIDENCE_THRESHOLD + + info(f"[Fast RAG Check] has_content={has_content}, rag_confidence={rag_confidence:.2f}, threshold={RAG_CONFIDENCE_THRESHOLD}") + + return has_content and has_confidence async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState: diff --git a/backend/app/main_graph/nodes/finalize.py b/backend/app/main_graph/nodes/finalize.py index 4dc859d..784c340 100644 --- a/backend/app/main_graph/nodes/finalize.py +++ b/backend/app/main_graph/nodes/finalize.py @@ -8,7 +8,7 @@ from typing import Any, Dict # 本地模块 from ...main_graph.state import MainGraphState from ...utils.logging import log_state_change -from ...logger import info, warning +from backend.app.logger import info, warning from langchain_core.runnables.config import RunnableConfig diff --git a/backend/app/main_graph/nodes/hybrid_router.py b/backend/app/main_graph/nodes/hybrid_router.py index 729a07c..a18cba9 100644 --- a/backend/app/main_graph/nodes/hybrid_router.py +++ b/backend/app/main_graph/nodes/hybrid_router.py @@ -11,7 +11,7 @@ from datetime import datetime from langchain_core.runnables.config import RunnableConfig from ..state import MainGraphState -from ...logger import info, debug +from backend.app.logger import info, debug from ...model_services.chat_services import get_small_llm_service from ._utils import dispatch_custom_event diff --git a/backend/app/main_graph/nodes/llm_call.py b/backend/app/main_graph/nodes/llm_call.py index fc0bbdb..9ab9fd2 100644 --- a/backend/app/main_graph/nodes/llm_call.py +++ b/backend/app/main_graph/nodes/llm_call.py @@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage from ...main_graph.state import MainGraphState from ...agent.prompts import create_system_prompt from ...utils.logging import log_state_change -from ...logger import debug, info, error +from backend.app.logger import debug, info, error def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list): @@ -115,24 +115,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: ): chunks.append(chunk) - info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks") - for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多 - chunk_type = type(chunk).__name__ - chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk) - # 打印更多属性 - additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {} - response_metadata = getattr(chunk, 'response_metadata', {}) or {} - # 打印所有属性 - info(f"[llm_call] chunk[{i}] type={chunk_type}") - info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}") - info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}") - info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}") - info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}") - # 检查是否有其他可能存储内容的属性 - if hasattr(chunk, 'tool_call_chunks'): - info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}") - if hasattr(chunk, 'usage_metadata'): - info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}") + info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}") # 将所有 chunk 合并成最终的 AIMessage if chunks: diff --git a/backend/app/main_graph/nodes/memory_trigger.py b/backend/app/main_graph/nodes/memory_trigger.py index 3dd0d73..7d309af 100644 --- a/backend/app/main_graph/nodes/memory_trigger.py +++ b/backend/app/main_graph/nodes/memory_trigger.py @@ -2,7 +2,7 @@ from typing import Any, Dict from langchain_core.runnables.config import RunnableConfig from ...main_graph.state import MainGraphState from ...memory.mem0_client import Mem0Client -from ...logger import info +from backend.app.logger import info # 全局变量,在 GraphBuilder 中注入 diff --git a/backend/app/main_graph/nodes/rag_nodes.py b/backend/app/main_graph/nodes/rag_nodes.py index c04b664..d036cac 100644 --- a/backend/app/main_graph/nodes/rag_nodes.py +++ b/backend/app/main_graph/nodes/rag_nodes.py @@ -1,6 +1,6 @@ """ RAG 检索节点模块 -使用模块级变量管理 RAG 工具 +包含:RAG 检索、置信度判断、重检索等节点 """ import time @@ -11,10 +11,15 @@ from langchain_core.runnables.config import RunnableConfig from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG -from ...logger import info +from backend.app.logger import info, debug +from ...model_services import get_small_llm_service from ._utils import dispatch_custom_event, make_react_event +# 置信度阈值配置 +RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关 + + def _get_rag_tool() -> Optional[callable]: """获取 RAG 工具""" from backend.app.main_graph.utils.rag_initializer import get_rag_tool @@ -36,43 +41,27 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG # 调用 RAG 工具 rag_context = await rag_tool.ainvoke(retrieval_query) info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}") - info(f"[RAG Core] ========== RAG 返回的知识内容 ==========") - info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context) - info(f"[RAG Core] ========================================") + # 更新状态 state.rag_context = rag_context - state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}] state.rag_retrieved = True - state.success = True + state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1 state.debug_info["rag_source"] = "tool" - info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}") return state # ========== RAG 检索节点 ========== async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: - """RAG 检索节点:带超时和重试""" + """RAG 检索节点:检索 + 置信度评估""" state.current_phase = "rag_retrieving" start_time = time.time() - last_error = None - # 获取 RAG 工具 rag_tool = _get_rag_tool() - if not rag_tool: - error_record = ErrorRecord( - error_type="RAGRetrievalError", - error_message="RAG 工具未初始化", - severity=ErrorSeverity.WARNING, - source="rag_retrieve_node", - timestamp=datetime.now().isoformat(), - retry_count=0, - max_retries=RAG_RETRY_CONFIG.max_retries, - ) - state.errors.append(error_record) - state.current_error = error_record - state.current_phase = "error_handling" + info("[RAG] RAG 工具未初始化") + state.rag_confidence = 0.0 + state.rag_retrieved = False return state await dispatch_custom_event( @@ -81,99 +70,184 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf config ) - for attempt in range(RAG_RETRY_CONFIG.max_retries + 1): - try: - result = await _rag_retrieve_core(state, rag_tool) + try: + state = await _rag_retrieve_core(state, rag_tool) - info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符") + # 评估置信度 + confidence = await _evaluate_rag_confidence(state) + state.rag_confidence = confidence - state.debug_info["rag_retrieval"] = { - "attempt": attempt + 1, - "success": True, - "time": time.time() - start_time - } + info(f"[RAG] 检索完成,置信度={confidence:.2f},RAG尝试次数={state.rag_attempts}") - state.reasoning_history.append({ - "step": state.reasoning_step, - "action": "RETRIEVE_RAG", - "confidence": 1.0, - "reasoning": "RAG 检索完成", - "timestamp": datetime.now().isoformat() - }) + state.reasoning_history.append({ + "step": state.reasoning_step, + "action": "RETRIEVE_RAG", + "confidence": confidence, + "reasoning": f"RAG 检索完成,置信度={confidence:.2f}", + "timestamp": datetime.now().isoformat() + }) - doc_count = len(result.rag_docs) if result.rag_docs else 0 - await dispatch_custom_event( - "react_reasoning", - make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0, - f"RAG 检索完成,找到 {doc_count} 条相关文档"), - config - ) + await dispatch_custom_event( + "react_reasoning", + make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence, + f"RAG 检索完成,置信度={confidence:.2f}"), + config + ) - return result - - except Exception as e: - last_error = e - - if attempt >= RAG_RETRY_CONFIG.max_retries: - break - - await dispatch_custom_event( - "react_reasoning", - make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0, - f"RAG 检索失败,第 {attempt + 1} 次重试..."), - config - ) - - delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt) - await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay)) - - # 失败记录 - state.reasoning_history.append({ - "step": state.reasoning_step, - "action": "RETRIEVE_RAG", - "confidence": 0.0, - "reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}", - "timestamp": datetime.now().isoformat() - }) - - error_record = ErrorRecord( - error_type="RAGRetrievalError", - error_message=str(last_error) if last_error else "RAG 检索超时", - severity=ErrorSeverity.WARNING, - source="rag_retrieve_node", - timestamp=datetime.now().isoformat(), - retry_count=RAG_RETRY_CONFIG.max_retries, - max_retries=RAG_RETRY_CONFIG.max_retries, - ) - - state.errors.append(error_record) - state.current_error = error_record - state.current_phase = "error_handling" - - await dispatch_custom_event( - "react_reasoning", - make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0, - f"RAG 检索失败: {str(last_error)}"), - config - ) + except Exception as e: + info(f"[RAG] 检索失败: {e}") + state.rag_confidence = 0.0 + state.rag_retrieved = False return state -async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: - """重新检索节点""" - state.current_phase = "rag_re_retrieving" +async def _evaluate_rag_confidence(state: MainGraphState) -> float: + """评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)""" + query = state.user_query or "" + rag_context = state.rag_context or "" - state.debug_info["rag_re_retrieve"] = { - "original_retrieved": state.rag_retrieved, - "original_docs_count": len(state.rag_docs) - } + if not rag_context: + return 0.0 - return await rag_retrieve_node(state, config) + # 方式1: 向量相似度(从 rag_docs 中获取) + embedding_score = _get_embedding_similarity(state, query) + info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}") + + # 方式2: 重排序分数(从 rag_docs 中获取) + rerank_score = _get_rerank_score(state) + info(f"[RAG Confidence] 重排分数={rerank_score:.3f}") + + # 方式3: 小模型判断 + llm_score = await _get_llm_score(state) + info(f"[RAG Confidence] LLM评估={llm_score:.3f}") + + # 综合得分(加权平均) + # 向量相似度权重 0.3,重排权重 0.3,LLM 权重 0.4 + final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4 + info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)") + + return final_score + + +def _get_embedding_similarity(state: MainGraphState) -> float: + """从 rag_docs 中获取向量相似度分数""" + rag_docs = getattr(state, "rag_docs", []) + + # 如果有多个文档,取最高分 + scores = [] + for doc in rag_docs: + if isinstance(doc, dict): + score = doc.get("score", 0.0) + # 向量相似度通常在 0-1 之间,RRF 分数可能更高 + # 归一化到 0-1 + if score > 1.0: + score = min(score / 10.0, 1.0) # 假设 max 约 10 + scores.append(score) + elif hasattr(doc, "metadata"): + score = doc.metadata.get("score", 0.0) + if score > 1.0: + score = min(score / 10.0, 1.0) + scores.append(score) + + if scores: + # 取平均或最高分 + return max(scores) # 使用最高分更准确 + return 0.0 + + +def _get_rerank_score(state: MainGraphState) -> float: + """从 rag_docs 中获取重排序分数""" + rag_docs = getattr(state, "rag_docs", []) + + # 重排分数通常在 0-1 之间 + scores = [] + for doc in rag_docs: + if isinstance(doc, dict): + score = doc.get("rerank_score", 0.0) + elif hasattr(doc, "metadata"): + score = doc.metadata.get("rerank_score", 0.0) + else: + score = 0.0 + + if score > 0: + scores.append(score) + + if scores: + return max(scores) # 使用最高分 + return 0.0 + + +async def _get_llm_score(state: MainGraphState) -> float: + """使用小模型评估检索结果相关性""" + query = state.user_query or "" + rag_context = state.rag_context or "" + + try: + llm = get_small_llm_service() + prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数: +- 1.0 = 完全相关,能直接回答问题 +- 0.5 = 部分相关,有一定参考价值 +- 0.0 = 完全不相关,无法回答问题 + +用户问题:{query} + +检索结果:{rag_context[:1500]} + +只返回一个数字:""" + + response = await llm.ainvoke(prompt) + content = response.content.strip() + + import re + match = re.search(r'(\d+\.?\d*)', content) + if match: + score = float(match.group(1)) + return max(0.0, min(1.0, score)) + + except Exception as e: + info(f"[RAG Confidence] LLM评估失败: {e}") + + return 0.5 # 默认中等置信度 + + +# ========== 置信度判断节点 ========== +def check_rag_confidence(state: MainGraphState) -> str: + """ + 根据 RAG 置信度判断下一步 + + Returns: + "high_confidence" - 高置信度(>=0.6),可直接生成回答 + "low_confidence" - 低置信度(<0.6),需要联网搜索 + "no_rag" - 无检索结果,需要联网搜索 + """ + rag_attempts = getattr(state, 'rag_attempts', 0) + rag_confidence = getattr(state, 'rag_confidence', 0.0) + + info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}") + + # 情况1: 没有检索结果 + if not getattr(state, 'rag_retrieved', False) or not state.rag_context: + info("[Confidence Check] 无检索结果,走联网") + return "no_rag" + + # 情况2: 置信度低于阈值 + if rag_confidence < RAG_CONFIDENCE_THRESHOLD: + if rag_attempts >= 2: + info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},且RAG尝试{rag_attempts}次,走联网") + return "low_confidence" + else: + info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},可再尝试RAG一次") + return "retry_rag" + + # 情况3: 高置信度 + info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答") + return "high_confidence" # ========== 导出 ========== __all__ = [ "rag_retrieve_node", - "rag_re_retrieve_node", + "check_rag_confidence", + "RAG_CONFIDENCE_THRESHOLD", ] diff --git a/backend/app/main_graph/nodes/reasoning.py b/backend/app/main_graph/nodes/reasoning.py index 09ef0a3..b8c81e1 100644 --- a/backend/app/main_graph/nodes/reasoning.py +++ b/backend/app/main_graph/nodes/reasoning.py @@ -7,9 +7,9 @@ from typing import Optional from datetime import datetime from langchain_core.runnables.config import RunnableConfig -from ...core.intent import react_reason_async, ReasoningResult +from backend.app.core.intent import react_reason_async, ReasoningResult from ...main_graph.state import MainGraphState -from ...logger import info +from backend.app.logger import info from ._utils import dispatch_custom_event, make_react_event diff --git a/backend/app/main_graph/nodes/retrieve_memory.py b/backend/app/main_graph/nodes/retrieve_memory.py index 1a48855..3837796 100644 --- a/backend/app/main_graph/nodes/retrieve_memory.py +++ b/backend/app/main_graph/nodes/retrieve_memory.py @@ -9,7 +9,7 @@ from typing import Any, Dict from ...main_graph.state import MainGraphState from ...memory.mem0_client import Mem0Client from ...utils.logging import log_state_change -from ...logger import debug +from backend.app.logger import debug def create_retrieve_memory_node(mem0_client: Mem0Client): diff --git a/backend/app/main_graph/nodes/routing.py b/backend/app/main_graph/nodes/routing.py index d9b62dd..36d18b4 100644 --- a/backend/app/main_graph/nodes/routing.py +++ b/backend/app/main_graph/nodes/routing.py @@ -10,9 +10,9 @@ from datetime import datetime -from ...core.intent import get_route_by_reasoning, ReasoningAction +from backend.app.core.intent import get_route_by_reasoning, ReasoningAction from ...main_graph.state import MainGraphState -from ...logger import info +from backend.app.logger import info # ========== 初始化状态节点 ========== @@ -97,6 +97,12 @@ def route_by_reasoning(state: MainGraphState) -> str: } target = route_mapping.get(route, "llm_call") + # ========== RAG 次数硬限制 ========== + rag_attempts = getattr(state, 'rag_attempts', 0) + if target == "rag_retrieve" and rag_attempts >= 2: + info(f"[条件路由] RAG已尝试{rag_attempts}次,强制走联网搜索") + target = "web_search" + # ========== 循环防护检测 ========== # 1. 路由模式检测(A→B→A→B 交替) if len(previous_actions) >= 4: diff --git a/backend/app/main_graph/nodes/summarize.py b/backend/app/main_graph/nodes/summarize.py index 1969d10..3836de7 100644 --- a/backend/app/main_graph/nodes/summarize.py +++ b/backend/app/main_graph/nodes/summarize.py @@ -9,7 +9,7 @@ from typing import Any, Dict from ...main_graph.state import MainGraphState from ...memory.mem0_client import Mem0Client from ...utils.logging import log_state_change -from ...logger import debug, info, error, warning +from backend.app.logger import debug, info, error, warning def create_summarize_node(mem0_client: Mem0Client): diff --git a/backend/app/main_graph/nodes/tool_call.py b/backend/app/main_graph/nodes/tool_call.py index 315f119..b990615 100644 --- a/backend/app/main_graph/nodes/tool_call.py +++ b/backend/app/main_graph/nodes/tool_call.py @@ -11,7 +11,7 @@ from ...main_graph.config import get_stream_writer # 本地模块 from ...main_graph.state import MainGraphState from ...utils.logging import log_state_change -from ...logger import debug, info +from backend.app.logger import debug, info def create_tool_call_node(tools_by_name: Dict[str, Any]): """ diff --git a/backend/app/main_graph/nodes/web_search.py b/backend/app/main_graph/nodes/web_search.py index f9aa62e..d02d9da 100644 --- a/backend/app/main_graph/nodes/web_search.py +++ b/backend/app/main_graph/nodes/web_search.py @@ -7,7 +7,7 @@ from datetime import datetime from langchain_core.runnables.config import RunnableConfig from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity -from ...logger import info +from backend.app.logger import info async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: diff --git a/backend/app/main_graph/state.py b/backend/app/main_graph/state.py index 92267df..145cb14 100644 --- a/backend/app/main_graph/state.py +++ b/backend/app/main_graph/state.py @@ -75,6 +75,8 @@ class MainGraphState: rag_context: str = "" rag_retrieved: bool = False rag_docs: List[Dict[str, Any]] = field(default_factory=list) + rag_confidence: float = 0.0 # RAG 检索置信度 (0.0-1.0) + rag_attempts: int = 0 # RAG 检索次数统计 # ========== 联网搜索相关字段 ========== web_search_results: List[str] = field(default_factory=list) diff --git a/backend/app/main_graph/subgraph_wrapper.py b/backend/app/main_graph/subgraph_wrapper.py index f01cec1..91c08f5 100644 --- a/backend/app/main_graph/subgraph_wrapper.py +++ b/backend/app/main_graph/subgraph_wrapper.py @@ -8,7 +8,7 @@ from datetime import datetime from langchain_core.runnables.config import RunnableConfig from .state import MainGraphState, ErrorRecord, ErrorSeverity -from ..logger import info +from backend.app.logger import info def wrap_subgraph_for_error_handling(subgraph, name: str): diff --git a/backend/app/main_graph/utils/rag_initializer.py b/backend/app/main_graph/utils/rag_initializer.py index 62c8d7b..dbd2233 100644 --- a/backend/app/main_graph/utils/rag_initializer.py +++ b/backend/app/main_graph/utils/rag_initializer.py @@ -2,7 +2,7 @@ from ...rag.tools import create_rag_tool from ...rag.retriever import create_parent_hybrid_retriever from ...model_services import get_embedding_service -from ...logger import info, warning +from backend.app.logger import info, warning import sys # 全局 RAG 工具 diff --git a/backend/app/memory/mem0_client.py b/backend/app/memory/mem0_client.py index 948be98..c9960e7 100644 --- a/backend/app/memory/mem0_client.py +++ b/backend/app/memory/mem0_client.py @@ -9,7 +9,7 @@ from typing import Optional, List from mem0 import AsyncMemory -from ..config import ( +from backend.app.config import ( LLM_API_KEY, ZHIPUAI_API_KEY, VLLM_BASE_URL, @@ -21,7 +21,7 @@ from ..config import ( ZHIPU_EMBEDDING_MODEL, ZHIPU_API_BASE, ) -from ..logger import info, warning, error +from backend.app.logger import info, warning, error from ..model_services import get_embedding_service from ..model_services.chat_services import get_chat_service diff --git a/backend/app/model_services/chat_services.py b/backend/app/model_services/chat_services.py index 4843ce8..349712f 100644 --- a/backend/app/model_services/chat_services.py +++ b/backend/app/model_services/chat_services.py @@ -23,7 +23,7 @@ from .base import ( FallbackServiceChain, SingletonServiceManager ) -from ..config import ( +from backend.app.config import ( VLLM_BASE_URL, LLM_API_KEY, ZHIPUAI_API_KEY, @@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]): """ def __init__(self, model: str = None): - from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY + from backend.app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY super().__init__("local_small") self._model = model or SMALL_LOCAL_MODEL_NAME self._base_url = SMALL_VLLM_BASE_URL @@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]): """ def __init__(self, model: str = None): - from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE + from backend.app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE super().__init__("deepseek_small") self._model = model or SMALL_DEEPSEEK_MODEL self._api_key = SMALL_DEEPSEEK_API_KEY diff --git a/backend/app/model_services/embedding_services.py b/backend/app/model_services/embedding_services.py index f371571..d8541b8 100644 --- a/backend/app/model_services/embedding_services.py +++ b/backend/app/model_services/embedding_services.py @@ -21,7 +21,7 @@ from .base import ( FallbackServiceChain, SingletonServiceManager ) -from ..config import ( +from backend.app.config import ( LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY, ZHIPUAI_API_KEY, diff --git a/backend/app/model_services/rerank_services.py b/backend/app/model_services/rerank_services.py index 475995b..f170698 100644 --- a/backend/app/model_services/rerank_services.py +++ b/backend/app/model_services/rerank_services.py @@ -27,7 +27,7 @@ from .base import ( FallbackServiceChain, SingletonServiceManager ) -from ..config import ( +from backend.app.config import ( LLAMACPP_RERANKER_URL, LLAMACPP_API_KEY, ZHIPUAI_API_KEY, diff --git a/backend/app/rag/pipeline.py b/backend/app/rag/pipeline.py index 8fd67d2..0d15fea 100644 --- a/backend/app/rag/pipeline.py +++ b/backend/app/rag/pipeline.py @@ -81,11 +81,17 @@ class RAGPipeline: return await self.retriever.ainvoke(query) async def _get_parents(self, child_docs: List[Document]) -> List[Document]: - parent_map = {} + # 收集 parent_id 和对应的分数 + parent_map = {} # parent_id -> (embedding_score, rerank_score) + for doc in child_docs: pid = doc.metadata.get("parent_id") if pid and pid not in parent_map: - parent_map[pid] = doc.metadata.get("score", 0.0) + # embedding 分数 + embedding_score = doc.metadata.get("score", 0.0) + # rerank 分数(如果有的话) + rerank_score = doc.metadata.get("rerank_score", 0.0) + parent_map[pid] = (embedding_score, rerank_score) if not parent_map: logger.warning("[Pipeline] 未找到 parent_id,返回子文档") @@ -94,10 +100,19 @@ class RAGPipeline: try: from backend.rag_core import create_docstore docstore, _ = create_docstore() - # 同步获取(异步版本不存在) parent_docs = docstore.mget(list(parent_map.keys())) - parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d} - result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2] + + # 构建结果,保持分数信息 + result = [] + for doc in parent_docs: + if doc: + pid = doc.metadata.get("id") + scores = parent_map.get(pid, (0.0, 0.0)) + # 将分数添加到 metadata 中 + doc.metadata["embedding_score"] = scores[0] + doc.metadata["rerank_score"] = scores[1] + result.append((doc, scores[0] + scores[1] * 2)) # 综合分数,rerank 权重更高 + result.sort(key=lambda x: x[1], reverse=True) docs = [d for d, _ in result] logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档") diff --git a/backend/app/rag/rerank.py b/backend/app/rag/rerank.py index d63d303..dd0213c 100644 --- a/backend/app/rag/rerank.py +++ b/backend/app/rag/rerank.py @@ -49,44 +49,38 @@ class DocumentReranker: top_n: 返回前 N 个结果 Returns: - List[Document]: 排序后的文档列表 + List[Document]: 排序后的文档列表,每个文档的 metadata 中包含 rerank_score """ if not documents: return [] try: - # 1. 从 Document 提取内容(业务逻辑) + # 1. 从 Document 提取内容 doc_contents = [doc.page_content for doc in documents] - logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}") - total_chars = sum(len(c) for c in doc_contents) - logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}") - # 粗略估算 tokens (中文约 0.75 tokens/字符) - estimated_tokens = int(total_chars * 0.75) - logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)") + logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排") - # 2. 调用纯服务层计算得分 - logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}") + # 2. 调用重排服务计算得分 scores = self._rerank_service.compute_scores(query, doc_contents) - logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}") + logger.info(f"[Rerank] 获取到 {len(scores)} 个得分") - # 3. 根据得分排序(业务逻辑) + # 3. 构建 (文档, 分数) 对并排序 doc_score_pairs = list(zip(documents, scores)) doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True) - logger.info(f"[Rerank] 排序后的结果:") - for i, (doc, score) in enumerate(doc_score_pairs_sorted): - logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...") - - # 4. 取 top_n - top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]] + # 4. 取 top_n,并添加 rerank_score 到 metadata + top_docs = [] + for doc, score in doc_score_pairs_sorted[:top_n]: + # 创建新文档,添加 rerank_score + new_doc = Document( + page_content=doc.page_content, + metadata={**doc.metadata, "rerank_score": score} + ) + top_docs.append(new_doc) return top_docs except Exception as e: - logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}") - logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}") - import traceback - logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}") + logger.warning(f"[Rerank] 重排失败,返回原始结果: {e}") return documents[:top_n] diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 5f70b4c..b50a903 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -22,7 +22,7 @@ from pydantic import Field, PrivateAttr from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore from backend.rag_core.client import create_async_qdrant_client from ..model_services import get_embedding_service -from ..logger import info, warning, debug +from backend.app.logger import info, warning, debug # 模块级常量 diff --git a/backend/app/subgraphs/contact/nodes.py b/backend/app/subgraphs/contact/nodes.py index 19c7692..fcd0c29 100644 --- a/backend/app/subgraphs/contact/nodes.py +++ b/backend/app/subgraphs/contact/nodes.py @@ -8,7 +8,7 @@ from typing import Dict, Any from datetime import datetime # 公共工具 -from ...core import MarkdownFormatter +from backend.app.core import MarkdownFormatter from .state import ContactState from .api_client import ContactAPIClient diff --git a/backend/app/subgraphs/dictionary/nodes.py b/backend/app/subgraphs/dictionary/nodes.py index 3b094a6..b28756c 100644 --- a/backend/app/subgraphs/dictionary/nodes.py +++ b/backend/app/subgraphs/dictionary/nodes.py @@ -8,7 +8,7 @@ from datetime import datetime import random # 公共工具 -from ...core import ( +from backend.app.core import ( MarkdownFormatter ) diff --git a/backend/app/subgraphs/news_analysis/nodes.py b/backend/app/subgraphs/news_analysis/nodes.py index 99d6d17..44a9b48 100644 --- a/backend/app/subgraphs/news_analysis/nodes.py +++ b/backend/app/subgraphs/news_analysis/nodes.py @@ -7,7 +7,7 @@ from typing import Dict, Any from datetime import datetime # 公共工具 -from ...core import MarkdownFormatter +from backend.app.core import MarkdownFormatter from .state import ( NewsAnalysisState, diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py index 9cd48b8..4ce7bc4 100644 --- a/backend/app/utils/logging.py +++ b/backend/app/utils/logging.py @@ -3,8 +3,8 @@ LangGraph 节点日志工具模块 提供状态流转追踪和 LLM 输入输出打印功能 """ -from ..config import ENABLE_GRAPH_TRACE -from ..logger import debug, info +from backend.app.config import ENABLE_GRAPH_TRACE +from backend.app.logger import debug, info from ..main_graph.state import MainGraphState diff --git a/tools/start.py b/tools/start.py index 1e3cc43..db56a5a 100755 --- a/tools/start.py +++ b/tools/start.py @@ -69,7 +69,7 @@ def cleanup(signum, frame): for i, proc in enumerate(processes): if proc.poll() is None: # 进程还在运行 proc.terminate() - proc.wait(timeout=5) + proc.wait(timeout=1) print(f"✓ 服务 {i+1} 已停止") sys.exit(0)