This commit is contained in:
@@ -1481,7 +1481,7 @@ mkdir backend/app/subgraphs/my_subgraph
|
||||
2. **创建状态定义 (state.py)**
|
||||
```python
|
||||
from typing_extensions import TypedDict
|
||||
from ..core.state_base import BaseSubgraphState
|
||||
from backend.app.core.state_base import BaseSubgraphState
|
||||
|
||||
class MySubgraphState(BaseSubgraphState):
|
||||
\"\"\"
|
||||
|
||||
@@ -16,8 +16,8 @@ from ..main_graph.main_graph_builder import build_react_main_graph
|
||||
from ..main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from ..main_graph.config import set_stream_writer
|
||||
from ..main_graph.utils.rag_initializer import init_rag_tool
|
||||
from ..core.intent_classifier import get_intent_classifier
|
||||
from ..logger import debug, info, warning, error
|
||||
from backend.app.core.intent_classifier import get_intent_classifier
|
||||
from backend.app.logger import debug, info, warning, error
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from ..logger import error # 保持兼容,或者替换为 logger
|
||||
from backend.app.logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
class ThreadHistoryService:
|
||||
"""线程历史查询服务"""
|
||||
|
||||
@@ -3,9 +3,10 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
|
||||
def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
|
||||
"""
|
||||
创建系统提示模板,可选择动态注入工具描述。
|
||||
创建系统提示模板,整合多子系统能力、检索策略与回答规范。
|
||||
"""
|
||||
tools_section = ""
|
||||
# 构造工具描述
|
||||
tools_section = "无可用工具"
|
||||
if tools:
|
||||
tool_descs = []
|
||||
for tool in tools:
|
||||
@@ -14,25 +15,44 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
|
||||
tool_descs.append(f"- {name}: {desc}")
|
||||
tools_section = "\n".join(tool_descs)
|
||||
|
||||
system_template = (
|
||||
"你是一个智能助手,具有三个专业子系统和RAG检索能力,请使用中文交流。\n\n"
|
||||
"【核心功能】\n"
|
||||
"1. 📚 词典/翻译子系统 - 查询单词、翻译文本、提取术语、每日一词\n"
|
||||
"2. 📰 资讯分析子系统 - 查询新闻、分析URL、提取关键词、生成报告\n"
|
||||
"3. 📇 通讯录子系统 - 查询联系人、添加联系人、管理通讯录\n"
|
||||
"4. 🔍 RAG检索 - 从知识库中检索相关信息回答问题\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳:\n"
|
||||
"{memory_context}\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
f"{tools_section}\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接。\n"
|
||||
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
|
||||
"3. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"4. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)
|
||||
# 使用 f-string 将 tools_section 直接嵌入,而 memory_context 用双花括号转义保留为变量
|
||||
system_template = f'''你是一个智能助手,具备以下专业子系统和检索能力。请使用中文交流。
|
||||
|
||||
## 核心功能
|
||||
1. 📚 词典/翻译子系统 – 查询单词、翻译文本、提取术语、每日一词
|
||||
2. 📰 资讯分析子系统 – 查询新闻、分析URL、提取关键词、生成报告
|
||||
3. 📇 通讯录子系统 – 查询联系人、添加联系人、管理通讯录
|
||||
4. 🔍 RAG检索 – 从知识库中检索相关信息回答问题
|
||||
|
||||
## 检索与信息获取策略
|
||||
当收到用户问题时,请按以下优先级处理:
|
||||
1. **RAG 检索(第1次)**:首先尝试从知识库中查找答案。
|
||||
2. **ReRAG(第2次优化检索)**:如果第一次检索结果不相关或不充分,可以优化查询后再次进行 RAG 检索。
|
||||
3. **联网搜索**:如果两次 RAG 检索后仍无法获得满意答案,必须使用联网搜索获取最新信息。
|
||||
**重要约束**:
|
||||
- 最多进行 **2 次** RAG 检索尝试。
|
||||
- 第3次决定获取信息时,必须选择**联网搜索**,禁止无休止的本地检索。
|
||||
- 如果已经明确知识库不包含该信息(例如用户询问实时新闻),可以直接进入联网搜索。
|
||||
|
||||
## 可用工具
|
||||
{tools_section}
|
||||
工具调用时请直接返回所需参数,无需额外说明。
|
||||
|
||||
## 用户背景信息
|
||||
以下是当前用户的已知信息和长期记忆,你应在回答中优先利用这些信息进行个性化回复:
|
||||
{{memory_context}}
|
||||
若无相关信息,可礼貌询问或提供通用帮助。
|
||||
|
||||
## 回答要求(必须严格遵守)
|
||||
1. **来源标注**:回答开头必须明确标注信息来源,格式如下:
|
||||
- 使用知识库时:`【知识库:来源描述】`
|
||||
- 使用联网搜索时:`【联网搜索:来源描述】`
|
||||
- 若同时用到多个来源,按实际使用顺序标注,例如:`【知识库:三国演义】【联网搜索:百度百科】`
|
||||
2. **思维链**:如果问题需要深入推理或复杂思考,请将推理过程用 `<think>...</think>` 标签包裹,放在回答最前面(来源标注之前)。
|
||||
3. **简洁直接**:回答应重点突出、条理清晰,避免冗长。
|
||||
4. **个性化**:结合用户信息进行针对性回复。
|
||||
5. **无依据时**:若既无知识库支撑也无联网搜索结果,请如实说明无法回答,并建议用户提供更多信息或尝试其他方式。
|
||||
'''
|
||||
|
||||
return ChatPromptTemplate.from_messages([
|
||||
("system", system_template),
|
||||
|
||||
@@ -9,7 +9,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="websocket
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets")
|
||||
|
||||
import os
|
||||
from ..config import DB_URI, BACKEND_PORT
|
||||
from backend.app.config import DB_URI, BACKEND_PORT
|
||||
import uuid
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -22,18 +22,18 @@ from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from .agent.agent_service import AIAgentService, create_serde
|
||||
from .agent.history import ThreadHistoryService
|
||||
from ..core.human_review import (
|
||||
from backend.app.core.human_review import (
|
||||
ReviewManager,
|
||||
InMemoryReviewStore,
|
||||
ReviewStatus,
|
||||
HumanReview
|
||||
)
|
||||
from ..subgraphs.contact.api_client import ContactAPIClient
|
||||
from ..subgraphs.dictionary.api_client import DictionaryAPIClient
|
||||
from ..subgraphs.news_analysis.api_client import NewsAPIClient
|
||||
from backend.app.subgraphs.contact.api_client import ContactAPIClient
|
||||
from backend.app.subgraphs.dictionary.api_client import DictionaryAPIClient
|
||||
from backend.app.subgraphs.news_analysis.api_client import NewsAPIClient
|
||||
from .db.init_db import init_subgraph_tables
|
||||
from .db.models import ContactRepository, DictionaryRepository, NewsRepository
|
||||
from ..logger import info, error
|
||||
from backend.app.logger import info, error
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
|
||||
@@ -21,15 +21,15 @@ from .nodes.fast_paths import (
|
||||
fast_tool_node,
|
||||
)
|
||||
from .nodes.llm_call import create_dynamic_llm_call_node
|
||||
from .nodes.rag_nodes import rag_retrieve_node
|
||||
from .nodes.rag_nodes import rag_retrieve_node, check_rag_confidence
|
||||
from .nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from .nodes.memory_trigger import memory_trigger_node, set_mem0_client
|
||||
from .nodes.summarize import create_summarize_node
|
||||
from .nodes.finalize import finalize_node
|
||||
from ..subgraphs.contact import build_contact_subgraph
|
||||
from ..subgraphs.dictionary import build_dictionary_subgraph
|
||||
from ..subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from ..logger import info
|
||||
from backend.app.subgraphs.contact import build_contact_subgraph
|
||||
from backend.app.subgraphs.dictionary import build_dictionary_subgraph
|
||||
from backend.app.subgraphs.news_analysis import build_news_analysis_subgraph
|
||||
from backend.app.logger import info
|
||||
|
||||
from .subgraph_wrapper import create_subgraph_nodes
|
||||
|
||||
@@ -198,8 +198,20 @@ def _add_react_loop_edges(graph: StateGraph, subgraph_nodes: Dict[str, Any]) ->
|
||||
}
|
||||
)
|
||||
|
||||
# RAG 检索后的置信度判断分支
|
||||
graph.add_conditional_edges(
|
||||
"rag_retrieve",
|
||||
check_rag_confidence,
|
||||
{
|
||||
"high_confidence": "llm_call", # 高置信度 → 直接生成回答
|
||||
"retry_rag": "rag_retrieve", # 低置信度 → 再次检索
|
||||
"low_confidence": "web_search", # 两次RAG后仍低 → 联网搜索
|
||||
"no_rag": "web_search", # 无结果 → 联网搜索
|
||||
}
|
||||
)
|
||||
|
||||
# 循环边(回到 react_reason)
|
||||
loop_back_nodes = ["rag_retrieve", "web_search", "handle_error"] + subgraph_names
|
||||
loop_back_nodes = ["web_search", "handle_error"] + subgraph_names
|
||||
for node_name in loop_back_nodes:
|
||||
graph.add_edge(node_name, "react_reason")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from .web_search import web_search_node
|
||||
from .error_handling import error_handling_node
|
||||
from .routing import init_state_node, route_by_reasoning, should_summarize
|
||||
from .llm_call import create_dynamic_llm_call_node
|
||||
from .rag_nodes import rag_retrieve_node, rag_re_retrieve_node
|
||||
from .rag_nodes import rag_retrieve_node
|
||||
|
||||
# 记忆节点
|
||||
from .retrieve_memory import create_retrieve_memory_node
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""
|
||||
|
||||
from ...main_graph.state import MainGraphState, ErrorSeverity
|
||||
from ...logger import info
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
def error_handling_node(state: MainGraphState) -> MainGraphState:
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ...logger import info, debug
|
||||
from backend.app.logger import info, debug
|
||||
from ...model_services.chat_services import get_small_llm_service, get_chat_service
|
||||
from .rag_nodes import rag_retrieve_node
|
||||
from ._utils import dispatch_custom_event
|
||||
@@ -113,10 +113,18 @@ async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig]
|
||||
|
||||
|
||||
def _has_valid_rag_results(state: MainGraphState) -> bool:
|
||||
"""检查 RAG 结果是否有效"""
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
"""检查 RAG 结果是否有效(基于置信度)"""
|
||||
from .rag_nodes import RAG_CONFIDENCE_THRESHOLD
|
||||
rag_context = getattr(state, "rag_context", "")
|
||||
return (rag_docs and len(rag_docs) > 0) or (rag_context and len(rag_context) > 10)
|
||||
rag_confidence = getattr(state, "rag_confidence", 0.0)
|
||||
|
||||
# 有结果且置信度足够
|
||||
has_content = rag_context and len(rag_context) > 0
|
||||
has_confidence = rag_confidence >= RAG_CONFIDENCE_THRESHOLD
|
||||
|
||||
info(f"[Fast RAG Check] has_content={has_content}, rag_confidence={rag_confidence:.2f}, threshold={RAG_CONFIDENCE_THRESHOLD}")
|
||||
|
||||
return has_content and has_confidence
|
||||
|
||||
|
||||
async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphState:
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any, Dict
|
||||
# 本地模块
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import info, warning
|
||||
from backend.app.logger import info, warning
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ...logger import info, debug
|
||||
from backend.app.logger import info, debug
|
||||
from ...model_services.chat_services import get_small_llm_service
|
||||
from ._utils import dispatch_custom_event
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from langchain_core.messages import AIMessage
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...agent.prompts import create_system_prompt
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug, info, error
|
||||
from backend.app.logger import debug, info, error
|
||||
|
||||
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
@@ -115,24 +115,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks")
|
||||
for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多
|
||||
chunk_type = type(chunk).__name__
|
||||
chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk)
|
||||
# 打印更多属性
|
||||
additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {}
|
||||
response_metadata = getattr(chunk, 'response_metadata', {}) or {}
|
||||
# 打印所有属性
|
||||
info(f"[llm_call] chunk[{i}] type={chunk_type}")
|
||||
info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}")
|
||||
info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}")
|
||||
info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}")
|
||||
info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}")
|
||||
# 检查是否有其他可能存储内容的属性
|
||||
if hasattr(chunk, 'tool_call_chunks'):
|
||||
info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}")
|
||||
if hasattr(chunk, 'usage_metadata'):
|
||||
info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}")
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}")
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any, Dict
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...logger import info
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
# 全局变量,在 GraphBuilder 中注入
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
RAG 检索节点模块
|
||||
使用模块级变量管理 RAG 工具
|
||||
包含:RAG 检索、置信度判断、重检索等节点
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -11,10 +11,15 @@ from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from ...logger import info
|
||||
from backend.app.logger import info, debug
|
||||
from ...model_services import get_small_llm_service
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
# 置信度阈值配置
|
||||
RAG_CONFIDENCE_THRESHOLD = 0.6 # 低于此值认为检索不相关
|
||||
|
||||
|
||||
def _get_rag_tool() -> Optional[callable]:
|
||||
"""获取 RAG 工具"""
|
||||
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
@@ -36,43 +41,27 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
# 调用 RAG 工具
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
|
||||
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
|
||||
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
|
||||
info(f"[RAG Core] ========================================")
|
||||
|
||||
# 更新状态
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
state.rag_retrieved = True
|
||||
state.success = True
|
||||
state.rag_attempts = getattr(state, 'rag_attempts', 0) + 1
|
||||
state.debug_info["rag_source"] = "tool"
|
||||
|
||||
info(f"[RAG Core] state.rag_docs 长度: {len(state.rag_docs)}")
|
||||
return state
|
||||
|
||||
|
||||
# ========== RAG 检索节点 ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""RAG 检索节点:带超时和重试"""
|
||||
"""RAG 检索节点:检索 + 置信度评估"""
|
||||
state.current_phase = "rag_retrieving"
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
|
||||
# 获取 RAG 工具
|
||||
rag_tool = _get_rag_tool()
|
||||
|
||||
if not rag_tool:
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message="RAG 工具未初始化",
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source="rag_retrieve_node",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=0,
|
||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||
)
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
info("[RAG] RAG 工具未初始化")
|
||||
state.rag_confidence = 0.0
|
||||
state.rag_retrieved = False
|
||||
return state
|
||||
|
||||
await dispatch_custom_event(
|
||||
@@ -81,99 +70,184 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConf
|
||||
config
|
||||
)
|
||||
|
||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||
try:
|
||||
result = await _rag_retrieve_core(state, rag_tool)
|
||||
try:
|
||||
state = await _rag_retrieve_core(state, rag_tool)
|
||||
|
||||
info(f"[RAG] 检索成功,上下文长度: {len(result.rag_context)} 字符")
|
||||
# 评估置信度
|
||||
confidence = await _evaluate_rag_confidence(state)
|
||||
state.rag_confidence = confidence
|
||||
|
||||
state.debug_info["rag_retrieval"] = {
|
||||
"attempt": attempt + 1,
|
||||
"success": True,
|
||||
"time": time.time() - start_time
|
||||
}
|
||||
info(f"[RAG] 检索完成,置信度={confidence:.2f},RAG尝试次数={state.rag_attempts}")
|
||||
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "RAG 检索完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": confidence,
|
||||
"reasoning": f"RAG 检索完成,置信度={confidence:.2f}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
doc_count = len(result.rag_docs) if result.rag_docs else 0
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_complete", 1.0,
|
||||
f"RAG 检索完成,找到 {doc_count} 条相关文档"),
|
||||
config
|
||||
)
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_complete", confidence,
|
||||
f"RAG 检索完成,置信度={confidence:.2f}"),
|
||||
config
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= RAG_RETRY_CONFIG.max_retries:
|
||||
break
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_retry", 1.0,
|
||||
f"RAG 检索失败,第 {attempt + 1} 次重试..."),
|
||||
config
|
||||
)
|
||||
|
||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
|
||||
# 失败记录
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "RETRIEVE_RAG",
|
||||
"confidence": 0.0,
|
||||
"reasoning": f"RAG 检索失败: {str(last_error) if last_error else '超时'}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type="RAGRetrievalError",
|
||||
error_message=str(last_error) if last_error else "RAG 检索超时",
|
||||
severity=ErrorSeverity.WARNING,
|
||||
source="rag_retrieve_node",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
retry_count=RAG_RETRY_CONFIG.max_retries,
|
||||
max_retries=RAG_RETRY_CONFIG.max_retries,
|
||||
)
|
||||
|
||||
state.errors.append(error_record)
|
||||
state.current_error = error_record
|
||||
state.current_phase = "error_handling"
|
||||
|
||||
await dispatch_custom_event(
|
||||
"react_reasoning",
|
||||
make_react_event(state.reasoning_step, "rag_retrieve_error", 1.0,
|
||||
f"RAG 检索失败: {str(last_error)}"),
|
||||
config
|
||||
)
|
||||
except Exception as e:
|
||||
info(f"[RAG] 检索失败: {e}")
|
||||
state.rag_confidence = 0.0
|
||||
state.rag_retrieved = False
|
||||
|
||||
return state
|
||||
|
||||
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""重新检索节点"""
|
||||
state.current_phase = "rag_re_retrieving"
|
||||
async def _evaluate_rag_confidence(state: MainGraphState) -> float:
|
||||
"""评估 RAG 检索结果置信度(综合向量相似度 + 重排分数 + 小模型判断)"""
|
||||
query = state.user_query or ""
|
||||
rag_context = state.rag_context or ""
|
||||
|
||||
state.debug_info["rag_re_retrieve"] = {
|
||||
"original_retrieved": state.rag_retrieved,
|
||||
"original_docs_count": len(state.rag_docs)
|
||||
}
|
||||
if not rag_context:
|
||||
return 0.0
|
||||
|
||||
return await rag_retrieve_node(state, config)
|
||||
# 方式1: 向量相似度(从 rag_docs 中获取)
|
||||
embedding_score = _get_embedding_similarity(state, query)
|
||||
info(f"[RAG Confidence] 向量相似度={embedding_score:.3f}")
|
||||
|
||||
# 方式2: 重排序分数(从 rag_docs 中获取)
|
||||
rerank_score = _get_rerank_score(state)
|
||||
info(f"[RAG Confidence] 重排分数={rerank_score:.3f}")
|
||||
|
||||
# 方式3: 小模型判断
|
||||
llm_score = await _get_llm_score(state)
|
||||
info(f"[RAG Confidence] LLM评估={llm_score:.3f}")
|
||||
|
||||
# 综合得分(加权平均)
|
||||
# 向量相似度权重 0.3,重排权重 0.3,LLM 权重 0.4
|
||||
final_score = embedding_score * 0.3 + rerank_score * 0.3 + llm_score * 0.4
|
||||
info(f"[RAG Confidence] 综合置信度={final_score:.3f} (embedding={embedding_score:.3f}*0.3 + rerank={rerank_score:.3f}*0.3 + llm={llm_score:.3f}*0.4)")
|
||||
|
||||
return final_score
|
||||
|
||||
|
||||
def _get_embedding_similarity(state: MainGraphState) -> float:
|
||||
"""从 rag_docs 中获取向量相似度分数"""
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
|
||||
# 如果有多个文档,取最高分
|
||||
scores = []
|
||||
for doc in rag_docs:
|
||||
if isinstance(doc, dict):
|
||||
score = doc.get("score", 0.0)
|
||||
# 向量相似度通常在 0-1 之间,RRF 分数可能更高
|
||||
# 归一化到 0-1
|
||||
if score > 1.0:
|
||||
score = min(score / 10.0, 1.0) # 假设 max 约 10
|
||||
scores.append(score)
|
||||
elif hasattr(doc, "metadata"):
|
||||
score = doc.metadata.get("score", 0.0)
|
||||
if score > 1.0:
|
||||
score = min(score / 10.0, 1.0)
|
||||
scores.append(score)
|
||||
|
||||
if scores:
|
||||
# 取平均或最高分
|
||||
return max(scores) # 使用最高分更准确
|
||||
return 0.0
|
||||
|
||||
|
||||
def _get_rerank_score(state: MainGraphState) -> float:
|
||||
"""从 rag_docs 中获取重排序分数"""
|
||||
rag_docs = getattr(state, "rag_docs", [])
|
||||
|
||||
# 重排分数通常在 0-1 之间
|
||||
scores = []
|
||||
for doc in rag_docs:
|
||||
if isinstance(doc, dict):
|
||||
score = doc.get("rerank_score", 0.0)
|
||||
elif hasattr(doc, "metadata"):
|
||||
score = doc.metadata.get("rerank_score", 0.0)
|
||||
else:
|
||||
score = 0.0
|
||||
|
||||
if score > 0:
|
||||
scores.append(score)
|
||||
|
||||
if scores:
|
||||
return max(scores) # 使用最高分
|
||||
return 0.0
|
||||
|
||||
|
||||
async def _get_llm_score(state: MainGraphState) -> float:
|
||||
"""使用小模型评估检索结果相关性"""
|
||||
query = state.user_query or ""
|
||||
rag_context = state.rag_context or ""
|
||||
|
||||
try:
|
||||
llm = get_small_llm_service()
|
||||
prompt = f"""评估以下检索结果与用户问题的相关性,返回 0.0-1.0 的分数:
|
||||
- 1.0 = 完全相关,能直接回答问题
|
||||
- 0.5 = 部分相关,有一定参考价值
|
||||
- 0.0 = 完全不相关,无法回答问题
|
||||
|
||||
用户问题:{query}
|
||||
|
||||
检索结果:{rag_context[:1500]}
|
||||
|
||||
只返回一个数字:"""
|
||||
|
||||
response = await llm.ainvoke(prompt)
|
||||
content = response.content.strip()
|
||||
|
||||
import re
|
||||
match = re.search(r'(\d+\.?\d*)', content)
|
||||
if match:
|
||||
score = float(match.group(1))
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
except Exception as e:
|
||||
info(f"[RAG Confidence] LLM评估失败: {e}")
|
||||
|
||||
return 0.5 # 默认中等置信度
|
||||
|
||||
|
||||
# ========== 置信度判断节点 ==========
|
||||
def check_rag_confidence(state: MainGraphState) -> str:
|
||||
"""
|
||||
根据 RAG 置信度判断下一步
|
||||
|
||||
Returns:
|
||||
"high_confidence" - 高置信度(>=0.6),可直接生成回答
|
||||
"low_confidence" - 低置信度(<0.6),需要联网搜索
|
||||
"no_rag" - 无检索结果,需要联网搜索
|
||||
"""
|
||||
rag_attempts = getattr(state, 'rag_attempts', 0)
|
||||
rag_confidence = getattr(state, 'rag_confidence', 0.0)
|
||||
|
||||
info(f"[Confidence Check] rag_attempts={rag_attempts}, rag_confidence={rag_confidence:.2f}")
|
||||
|
||||
# 情况1: 没有检索结果
|
||||
if not getattr(state, 'rag_retrieved', False) or not state.rag_context:
|
||||
info("[Confidence Check] 无检索结果,走联网")
|
||||
return "no_rag"
|
||||
|
||||
# 情况2: 置信度低于阈值
|
||||
if rag_confidence < RAG_CONFIDENCE_THRESHOLD:
|
||||
if rag_attempts >= 2:
|
||||
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},且RAG尝试{rag_attempts}次,走联网")
|
||||
return "low_confidence"
|
||||
else:
|
||||
info(f"[Confidence Check] 置信度={rag_confidence:.2f}<{RAG_CONFIDENCE_THRESHOLD},可再尝试RAG一次")
|
||||
return "retry_rag"
|
||||
|
||||
# 情况3: 高置信度
|
||||
info(f"[Confidence Check] 高置信度={rag_confidence:.2f}>={RAG_CONFIDENCE_THRESHOLD},直接生成回答")
|
||||
return "high_confidence"
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
__all__ = [
|
||||
"rag_retrieve_node",
|
||||
"rag_re_retrieve_node",
|
||||
"check_rag_confidence",
|
||||
"RAG_CONFIDENCE_THRESHOLD",
|
||||
]
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Optional
|
||||
from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ...core.intent import react_reason_async, ReasoningResult
|
||||
from backend.app.core.intent import react_reason_async, ReasoningResult
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...logger import info
|
||||
from backend.app.logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Dict
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug
|
||||
from backend.app.logger import debug
|
||||
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from ...core.intent import get_route_by_reasoning, ReasoningAction
|
||||
from backend.app.core.intent import get_route_by_reasoning, ReasoningAction
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...logger import info
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
# ========== 初始化状态节点 ==========
|
||||
@@ -97,6 +97,12 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
}
|
||||
target = route_mapping.get(route, "llm_call")
|
||||
|
||||
# ========== RAG 次数硬限制 ==========
|
||||
rag_attempts = getattr(state, 'rag_attempts', 0)
|
||||
if target == "rag_retrieve" and rag_attempts >= 2:
|
||||
info(f"[条件路由] RAG已尝试{rag_attempts}次,强制走联网搜索")
|
||||
target = "web_search"
|
||||
|
||||
# ========== 循环防护检测 ==========
|
||||
# 1. 路由模式检测(A→B→A→B 交替)
|
||||
if len(previous_actions) >= 4:
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Dict
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug, info, error, warning
|
||||
from backend.app.logger import debug, info, error, warning
|
||||
|
||||
|
||||
def create_summarize_node(mem0_client: Mem0Client):
|
||||
|
||||
@@ -11,7 +11,7 @@ from ...main_graph.config import get_stream_writer
|
||||
# 本地模块
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug, info
|
||||
from backend.app.logger import debug, info
|
||||
|
||||
def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ...logger import info
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
|
||||
@@ -75,6 +75,8 @@ class MainGraphState:
|
||||
rag_context: str = ""
|
||||
rag_retrieved: bool = False
|
||||
rag_docs: List[Dict[str, Any]] = field(default_factory=list)
|
||||
rag_confidence: float = 0.0 # RAG 检索置信度 (0.0-1.0)
|
||||
rag_attempts: int = 0 # RAG 检索次数统计
|
||||
|
||||
# ========== 联网搜索相关字段 ==========
|
||||
web_search_results: List[str] = field(default_factory=list)
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ..logger import info
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from ...rag.tools import create_rag_tool
|
||||
from ...rag.retriever import create_parent_hybrid_retriever
|
||||
from ...model_services import get_embedding_service
|
||||
from ...logger import info, warning
|
||||
from backend.app.logger import info, warning
|
||||
import sys
|
||||
|
||||
# 全局 RAG 工具
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional, List
|
||||
|
||||
from mem0 import AsyncMemory
|
||||
|
||||
from ..config import (
|
||||
from backend.app.config import (
|
||||
LLM_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
VLLM_BASE_URL,
|
||||
@@ -21,7 +21,7 @@ from ..config import (
|
||||
ZHIPU_EMBEDDING_MODEL,
|
||||
ZHIPU_API_BASE,
|
||||
)
|
||||
from ..logger import info, warning, error
|
||||
from backend.app.logger import info, warning, error
|
||||
from ..model_services import get_embedding_service
|
||||
from ..model_services.chat_services import get_chat_service
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
from backend.app.config import (
|
||||
VLLM_BASE_URL,
|
||||
LLM_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
@@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = None):
|
||||
from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
|
||||
from backend.app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
|
||||
super().__init__("local_small")
|
||||
self._model = model or SMALL_LOCAL_MODEL_NAME
|
||||
self._base_url = SMALL_VLLM_BASE_URL
|
||||
@@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = None):
|
||||
from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
|
||||
from backend.app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
|
||||
super().__init__("deepseek_small")
|
||||
self._model = model or SMALL_DEEPSEEK_MODEL
|
||||
self._api_key = SMALL_DEEPSEEK_API_KEY
|
||||
|
||||
@@ -21,7 +21,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
from backend.app.config import (
|
||||
LLAMACPP_EMBEDDING_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
|
||||
@@ -27,7 +27,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from ..config import (
|
||||
from backend.app.config import (
|
||||
LLAMACPP_RERANKER_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
|
||||
@@ -81,11 +81,17 @@ class RAGPipeline:
|
||||
return await self.retriever.ainvoke(query)
|
||||
|
||||
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
|
||||
parent_map = {}
|
||||
# 收集 parent_id 和对应的分数
|
||||
parent_map = {} # parent_id -> (embedding_score, rerank_score)
|
||||
|
||||
for doc in child_docs:
|
||||
pid = doc.metadata.get("parent_id")
|
||||
if pid and pid not in parent_map:
|
||||
parent_map[pid] = doc.metadata.get("score", 0.0)
|
||||
# embedding 分数
|
||||
embedding_score = doc.metadata.get("score", 0.0)
|
||||
# rerank 分数(如果有的话)
|
||||
rerank_score = doc.metadata.get("rerank_score", 0.0)
|
||||
parent_map[pid] = (embedding_score, rerank_score)
|
||||
|
||||
if not parent_map:
|
||||
logger.warning("[Pipeline] 未找到 parent_id,返回子文档")
|
||||
@@ -94,10 +100,19 @@ class RAGPipeline:
|
||||
try:
|
||||
from backend.rag_core import create_docstore
|
||||
docstore, _ = create_docstore()
|
||||
# 同步获取(异步版本不存在)
|
||||
parent_docs = docstore.mget(list(parent_map.keys()))
|
||||
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
|
||||
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
|
||||
|
||||
# 构建结果,保持分数信息
|
||||
result = []
|
||||
for doc in parent_docs:
|
||||
if doc:
|
||||
pid = doc.metadata.get("id")
|
||||
scores = parent_map.get(pid, (0.0, 0.0))
|
||||
# 将分数添加到 metadata 中
|
||||
doc.metadata["embedding_score"] = scores[0]
|
||||
doc.metadata["rerank_score"] = scores[1]
|
||||
result.append((doc, scores[0] + scores[1] * 2)) # 综合分数,rerank 权重更高
|
||||
|
||||
result.sort(key=lambda x: x[1], reverse=True)
|
||||
docs = [d for d, _ in result]
|
||||
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
|
||||
|
||||
@@ -49,44 +49,38 @@ class DocumentReranker:
|
||||
top_n: 返回前 N 个结果
|
||||
|
||||
Returns:
|
||||
List[Document]: 排序后的文档列表
|
||||
List[Document]: 排序后的文档列表,每个文档的 metadata 中包含 rerank_score
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 1. 从 Document 提取内容(业务逻辑)
|
||||
# 1. 从 Document 提取内容
|
||||
doc_contents = [doc.page_content for doc in documents]
|
||||
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
|
||||
total_chars = sum(len(c) for c in doc_contents)
|
||||
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
|
||||
# 粗略估算 tokens (中文约 0.75 tokens/字符)
|
||||
estimated_tokens = int(total_chars * 0.75)
|
||||
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
|
||||
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排")
|
||||
|
||||
# 2. 调用纯服务层计算得分
|
||||
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
|
||||
# 2. 调用重排服务计算得分
|
||||
scores = self._rerank_service.compute_scores(query, doc_contents)
|
||||
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")
|
||||
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分")
|
||||
|
||||
# 3. 根据得分排序(业务逻辑)
|
||||
# 3. 构建 (文档, 分数) 对并排序
|
||||
doc_score_pairs = list(zip(documents, scores))
|
||||
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|
||||
|
||||
logger.info(f"[Rerank] 排序后的结果:")
|
||||
for i, (doc, score) in enumerate(doc_score_pairs_sorted):
|
||||
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...")
|
||||
|
||||
# 4. 取 top_n
|
||||
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]]
|
||||
# 4. 取 top_n,并添加 rerank_score 到 metadata
|
||||
top_docs = []
|
||||
for doc, score in doc_score_pairs_sorted[:top_n]:
|
||||
# 创建新文档,添加 rerank_score
|
||||
new_doc = Document(
|
||||
page_content=doc.page_content,
|
||||
metadata={**doc.metadata, "rerank_score": score}
|
||||
)
|
||||
top_docs.append(new_doc)
|
||||
|
||||
return top_docs
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}")
|
||||
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
|
||||
logger.warning(f"[Rerank] 重排失败,返回原始结果: {e}")
|
||||
return documents[:top_n]
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from pydantic import Field, PrivateAttr
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from backend.rag_core.client import create_async_qdrant_client
|
||||
from ..model_services import get_embedding_service
|
||||
from ..logger import info, warning, debug
|
||||
from backend.app.logger import info, warning, debug
|
||||
|
||||
|
||||
# 模块级常量
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 公共工具
|
||||
from ...core import MarkdownFormatter
|
||||
from backend.app.core import MarkdownFormatter
|
||||
|
||||
from .state import ContactState
|
||||
from .api_client import ContactAPIClient
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
import random
|
||||
|
||||
# 公共工具
|
||||
from ...core import (
|
||||
from backend.app.core import (
|
||||
MarkdownFormatter
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 公共工具
|
||||
from ...core import MarkdownFormatter
|
||||
from backend.app.core import MarkdownFormatter
|
||||
|
||||
from .state import (
|
||||
NewsAnalysisState,
|
||||
|
||||
@@ -3,8 +3,8 @@ LangGraph 节点日志工具模块
|
||||
提供状态流转追踪和 LLM 输入输出打印功能
|
||||
"""
|
||||
|
||||
from ..config import ENABLE_GRAPH_TRACE
|
||||
from ..logger import debug, info
|
||||
from backend.app.config import ENABLE_GRAPH_TRACE
|
||||
from backend.app.logger import debug, info
|
||||
from ..main_graph.state import MainGraphState
|
||||
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ def cleanup(signum, frame):
|
||||
for i, proc in enumerate(processes):
|
||||
if proc.poll() is None: # 进程还在运行
|
||||
proc.terminate()
|
||||
proc.wait(timeout=5)
|
||||
proc.wait(timeout=1)
|
||||
print(f"✓ 服务 {i+1} 已停止")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user