diff --git a/README.md b/README.md index 022c23b..00d2818 100644 --- a/README.md +++ b/README.md @@ -1154,7 +1154,7 @@ RAG 系统分为两个独立但协同的阶段: ❌ 只能捕捉"语义相似",专有名词匹配差 实现代码: - from app.rag.retriever import create_base_retriever + from backend.app.rag.retriever import create_base_retriever retriever = create_base_retriever( collection_name="rag_documents", @@ -1175,7 +1175,7 @@ RAG 系统分为两个独立但协同的阶段: 两路结果并行获取,等待融合 实现代码: - from app.rag.retriever import create_hybrid_retriever + from backend.app.rag.retriever import create_hybrid_retriever retriever = create_hybrid_retriever( collection_name="rag_documents", @@ -1195,7 +1195,7 @@ RAG 系统分为两个独立但协同的阶段: 由模型直接输出 0~1 的相关性得分,精度极高 实现代码: - from app.rag.reranker import LLaMaCPPReranker + from backend.app.rag.reranker import LLaMaCPPReranker reranker = LLaMaCPPReranker( base_url="http://127.0.0.1:8083", @@ -1215,7 +1215,7 @@ RAG 系统分为两个独立但协同的阶段: 通过 LLM 将单一问题改写为多个不同角度的查询 实现代码: - from app.rag.query_transform import MultiQueryGenerator + from backend.app.rag.query_transform import MultiQueryGenerator generator = MultiQueryGenerator(llm=llm, num_queries=3) queries = await generator.agenerate("如何申请项目资金?") @@ -1231,7 +1231,7 @@ RAG 系统分为两个独立但协同的阶段: 有效避免某一极端检索结果主导全局 实现代码: - from app.rag.fusion import reciprocal_rank_fusion + from backend.app.rag.fusion import reciprocal_rank_fusion # 多个查询的检索结果 doc_lists = [result1, result2, result3] @@ -1260,8 +1260,8 @@ RAG 系统分为两个独立但协同的阶段: └────────── └────────────── └──────────┘ └────────┘ 实现代码: - from app.rag.tools import search_knowledge_base - from app.main_graph.utils.main_graph_builder import MainGraphBuilder + from backend.app.rag.tools import search_knowledge_base + from backend.app.main_graph.utils.main_graph_builder import MainGraphBuilder # 构建图 builder = MainGraphBuilder() diff --git a/backend/app/__init__.py b/backend/app/__init__.py index 5900015..552fda9 100644 --- a/backend/app/__init__.py +++ b/backend/app/__init__.py @@ -3,6 +3,6 @@ AI Agent 应用模块 """ from .agent.agent_service import AIAgentService -from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from .main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME __all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"] diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index f3d7a79..16ff75a 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -7,6 +7,9 @@ import json import asyncio from typing import AsyncGenerator, Dict, Any, Optional, Tuple +# LangGraph 序列化器(修复 checkpoint 反序列化警告) +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer + # 本地模块 from ..model_services import get_cached_chat_services from ..main_graph.main_graph_builder import build_react_main_graph @@ -18,6 +21,23 @@ from ..logger import debug, info, warning, error from ..main_graph.state import MainGraphState, CurrentAction +# ========== 自定义类型序列化器 ========== +def create_serde() -> JsonPlusSerializer: + """创建带自定义类型注册的序列化器""" + from backend.app.core.intent import ReasoningAction, RetrievalConfig, ReasoningResult + + return JsonPlusSerializer( + allowed_msgpack_modules=[ + ("app.core.intent", "ReasoningAction"), + ("app.core.intent", "RetrievalConfig"), + ("app.core.intent", "ReasoningResult"), + ("app.main_graph.state", "CurrentAction"), + ("app.main_graph.state", "ErrorSeverity"), + ("app.main_graph.state", "ErrorRecord"), + ] + ) + + class AIAgentService: def __init__(self, checkpointer): self.checkpointer = checkpointer @@ -55,6 +75,7 @@ class AIAgentService: tools=self.tools, mem0_client=self.mem0_client ) + # 注意:serde 已在创建 checkpointer 时传入,这里只需传入 checkpointer self.graph = graph_builder.compile(checkpointer=self.checkpointer) info(f"✅ 单图初始化完成") diff --git a/backend/app/agent/history.py b/backend/app/agent/history.py index c2107fe..a87f772 100644 --- a/backend/app/agent/history.py +++ b/backend/app/agent/history.py @@ -4,7 +4,7 @@ """ from typing import List, Dict, Any -from app.logger import error # 保持兼容,或者替换为 logger +from ..logger import error # 保持兼容,或者替换为 logger class ThreadHistoryService: """线程历史查询服务""" diff --git a/backend/app/backend.py b/backend/app/backend.py index 76b7675..c5e23e6 100644 --- a/backend/app/backend.py +++ b/backend/app/backend.py @@ -3,8 +3,13 @@ FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆 采用依赖注入模式,优雅管理资源生命周期 """ +import warnings +# 抑制 WebSocket 弃用警告(websockets 库升级导致,uvicorn 尚未跟进) +warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets") + import os -from app.config import DB_URI, BACKEND_PORT +from ..config import DB_URI, BACKEND_PORT import uuid import json from contextlib import asynccontextmanager @@ -15,26 +20,26 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from pydantic import BaseModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from .agent.agent_service import AIAgentService +from .agent.agent_service import AIAgentService, create_serde from .agent.history import ThreadHistoryService -from app.core.human_review import ( +from ..core.human_review import ( ReviewManager, InMemoryReviewStore, ReviewStatus, HumanReview ) -from app.subgraphs.contact.api_client import ContactAPIClient -from app.subgraphs.dictionary.api_client import DictionaryAPIClient -from app.subgraphs.news_analysis.api_client import NewsAPIClient +from ..subgraphs.contact.api_client import ContactAPIClient +from ..subgraphs.dictionary.api_client import DictionaryAPIClient +from ..subgraphs.news_analysis.api_client import NewsAPIClient from .db.init_db import init_subgraph_tables from .db.models import ContactRepository, DictionaryRepository, NewsRepository -from app.logger import info, error +from ..logger import info, error @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理:创建并注入全局服务""" # 1. 创建数据库连接池并初始化表(仅 checkpointer) - async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: + async with AsyncPostgresSaver.from_conn_string(DB_URI, serde=create_serde()) as checkpointer: await checkpointer.setup() # 1.5 初始化子图表 diff --git a/backend/app/core/__init__.py b/backend/app/core/__init__.py index aa03c1b..7e0f5b3 100644 --- a/backend/app/core/__init__.py +++ b/backend/app/core/__init__.py @@ -30,7 +30,7 @@ from .visualization import ( # 为了兼容性,添加 classify_intent 函数 def classify_intent(user_input: str, context: str = None): """兼容旧代码的 classify_intent 函数""" - from app.core.intent_classifier import get_intent_classifier + from backend.app.core.intent_classifier import get_intent_classifier import asyncio classifier = get_intent_classifier() try: diff --git a/backend/app/core/intent.py b/backend/app/core/intent.py index 78d4fa7..ba1b6f7 100644 --- a/backend/app/core/intent.py +++ b/backend/app/core/intent.py @@ -94,7 +94,7 @@ class ReactIntentReasoner: def _get_llm_service(self): """懒加载 LLM 服务(避免循环导入)""" if self._llm_service is None: - from app.model_services.chat_services import get_chat_service, get_small_llm_service + from backend.app.model_services.chat_services import get_chat_service, get_small_llm_service if self._use_small_llm: self._llm_service = get_small_llm_service() else: diff --git a/backend/app/core/web_search.py b/backend/app/core/web_search.py index db8fac2..fe9d79b 100644 --- a/backend/app/core/web_search.py +++ b/backend/app/core/web_search.py @@ -89,7 +89,7 @@ class WebSearchTool: def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]: """使用 Tavily API 搜索""" from tavily import TavilyClient - from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS + from backend.app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS if not TAVILY_API_KEY: raise ValueError("TAVILY_API_KEY 未配置") diff --git a/backend/app/logger.py b/backend/app/logger.py index 2f99040..8b263c6 100644 --- a/backend/app/logger.py +++ b/backend/app/logger.py @@ -4,7 +4,7 @@ """ import os -from app.config import LOG_LEVEL, DEBUG +from .config import LOG_LEVEL, DEBUG import logging from typing import Any # 根据环境变量控制是否显示详细调试信息 diff --git a/backend/app/main_graph/nodes/_utils.py b/backend/app/main_graph/nodes/_utils.py index 930f255..abd0e1b 100644 --- a/backend/app/main_graph/nodes/_utils.py +++ b/backend/app/main_graph/nodes/_utils.py @@ -4,12 +4,13 @@ """ from typing import Dict, Any, Optional +from langchain_core.runnables.config import RunnableConfig async def dispatch_custom_event( event_name: str, data: Dict[str, Any], - config: Optional[Dict[str, Any]] = None, + config: Optional[RunnableConfig] = None, ) -> None: """ 安全地发送自定义事件,忽略发送失败 diff --git a/backend/app/main_graph/nodes/error_handling.py b/backend/app/main_graph/nodes/error_handling.py index eccb611..646afb0 100644 --- a/backend/app/main_graph/nodes/error_handling.py +++ b/backend/app/main_graph/nodes/error_handling.py @@ -2,8 +2,8 @@ 错误处理节点 - 处理子图/工具调用错误 """ -from app.main_graph.state import MainGraphState, ErrorSeverity -from app.logger import info +from ...main_graph.state import MainGraphState, ErrorSeverity +from ...logger import info def error_handling_node(state: MainGraphState) -> MainGraphState: diff --git a/backend/app/main_graph/nodes/fast_paths.py b/backend/app/main_graph/nodes/fast_paths.py index 77139f6..52357f5 100644 --- a/backend/app/main_graph/nodes/fast_paths.py +++ b/backend/app/main_graph/nodes/fast_paths.py @@ -4,6 +4,7 @@ """ from typing import Optional +from langchain_core.runnables.config import RunnableConfig from ..state import MainGraphState from ...logger import info, debug @@ -28,7 +29,7 @@ CHITCHAT_KEYWORDS = { # ========== 闲聊节点 ========== -async def fast_chitchat_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: +async def fast_chitchat_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """快速闲聊节点""" state.current_phase = "fast_chitchat" query = state.user_query or "" @@ -69,14 +70,14 @@ def _match_chitchat_template(query: str) -> str: # ========== 快速 RAG 节点 ========== -async def fast_rag_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: +async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答""" state.current_phase = "fast_rag" query = state.user_query or "" info(f"[Fast RAG] 开始处理: {query[:50]}") # 获取 RAG 工具 - from app.main_graph.utils.rag_initializer import get_rag_tool + from backend.app.main_graph.utils.rag_initializer import get_rag_tool rag_tool = get_rag_tool() info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}") @@ -134,7 +135,7 @@ async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphS 请给出简洁、准确的回答:""" # 使用流式输出 - from app.main_graph.config import get_stream_writer + from backend.app.main_graph.config import get_stream_writer writer = get_stream_writer() full_content = "" @@ -164,7 +165,7 @@ async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphS # ========== 快速工具节点 ========== -async def fast_tool_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: +async def fast_tool_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """快速工具节点""" state.current_phase = "fast_tool" diff --git a/backend/app/main_graph/nodes/finalize.py b/backend/app/main_graph/nodes/finalize.py index 25882d4..4dc859d 100644 --- a/backend/app/main_graph/nodes/finalize.py +++ b/backend/app/main_graph/nodes/finalize.py @@ -6,9 +6,9 @@ from typing import Any, Dict # 本地模块 -from app.main_graph.state import MainGraphState -from app.utils.logging import log_state_change -from app.logger import info, warning +from ...main_graph.state import MainGraphState +from ...utils.logging import log_state_change +from ...logger import info, warning from langchain_core.runnables.config import RunnableConfig @@ -35,7 +35,7 @@ async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[s try: # 获取流式写入器并发送完成事件 - from app.main_graph.config import get_stream_writer + from backend.app.main_graph.config import get_stream_writer writer = get_stream_writer() # 只在 writer 存在且不是 noop 时才发送 diff --git a/backend/app/main_graph/nodes/hybrid_router.py b/backend/app/main_graph/nodes/hybrid_router.py index 8d8d7b8..729a07c 100644 --- a/backend/app/main_graph/nodes/hybrid_router.py +++ b/backend/app/main_graph/nodes/hybrid_router.py @@ -8,6 +8,7 @@ import json from typing import Optional from dataclasses import dataclass, field from datetime import datetime +from langchain_core.runnables.config import RunnableConfig from ..state import MainGraphState from ...logger import info, debug @@ -157,7 +158,7 @@ def _default_result() -> HybridRouterResult: # ========== 主路由节点 ========== -async def hybrid_router_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: +async def hybrid_router_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """混合路由节点:前置路由,决定走快速路径还是 React 循环""" state.current_phase = "hybrid_router" query = state.user_query or "" diff --git a/backend/app/main_graph/nodes/llm_call.py b/backend/app/main_graph/nodes/llm_call.py index 40c278f..fc0bbdb 100644 --- a/backend/app/main_graph/nodes/llm_call.py +++ b/backend/app/main_graph/nodes/llm_call.py @@ -9,10 +9,10 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage # 本地模块 -from app.main_graph.state import MainGraphState -from app.agent.prompts import create_system_prompt -from app.utils.logging import log_state_change -from app.logger import debug, info, error +from ...main_graph.state import MainGraphState +from ...agent.prompts import create_system_prompt +from ...utils.logging import log_state_change +from ...logger import debug, info, error def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list): @@ -82,6 +82,12 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: try: # 添加上下文到消息 messages_with_context = list(state.messages) + info(f"[llm_call] 原始消息数量: {len(messages_with_context)}") + for i, msg in enumerate(messages_with_context): + msg_type = getattr(msg, 'type', 'unknown') + msg_content = getattr(msg, 'content', '')[:100] if hasattr(msg, 'content') else str(msg)[:100] + info(f"[llm_call] msg[{i}] type={msg_type}, content={repr(msg_content)}") + if state.rag_context: from langchain_core.messages import SystemMessage rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}") @@ -93,11 +99,13 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: break if not inserted: messages_with_context.insert(0, rag_system_msg) - + info(f"[llm_call] RAG上下文已添加,长度: {len(state.rag_context)}") + # 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。 # LangGraph 会自动监听这期间产生的所有 token。 chain = prompt | llm_with_tools chunks = [] + info(f"[llm_call] 开始调用 LLM astream...") async for chunk in chain.astream( { "messages": messages_with_context, @@ -106,7 +114,26 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: config=config ): chunks.append(chunk) - + + info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks") + for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多 + chunk_type = type(chunk).__name__ + chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk) + # 打印更多属性 + additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {} + response_metadata = getattr(chunk, 'response_metadata', {}) or {} + # 打印所有属性 + info(f"[llm_call] chunk[{i}] type={chunk_type}") + info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}") + info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}") + info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}") + info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}") + # 检查是否有其他可能存储内容的属性 + if hasattr(chunk, 'tool_call_chunks'): + info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}") + if hasattr(chunk, 'usage_metadata'): + info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}") + # 将所有 chunk 合并成最终的 AIMessage if chunks: response = chunks[0] @@ -114,6 +141,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: response = response + chunk else: response = AIMessage(content="") + info(f"[llm_call] ⚠️ 警告: 没有收到任何 chunks!") elapsed_time = time.time() - start_time diff --git a/backend/app/main_graph/nodes/memory_trigger.py b/backend/app/main_graph/nodes/memory_trigger.py index e848559..3dd0d73 100644 --- a/backend/app/main_graph/nodes/memory_trigger.py +++ b/backend/app/main_graph/nodes/memory_trigger.py @@ -1,8 +1,8 @@ from typing import Any, Dict from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MainGraphState -from app.memory.mem0_client import Mem0Client -from app.logger import info +from ...main_graph.state import MainGraphState +from ...memory.mem0_client import Mem0Client +from ...logger import info # 全局变量,在 GraphBuilder 中注入 diff --git a/backend/app/main_graph/nodes/rag_nodes.py b/backend/app/main_graph/nodes/rag_nodes.py index b362a5b..c04b664 100644 --- a/backend/app/main_graph/nodes/rag_nodes.py +++ b/backend/app/main_graph/nodes/rag_nodes.py @@ -7,16 +7,17 @@ import time import asyncio from typing import Optional from datetime import datetime +from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity -from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG -from app.logger import info +from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity +from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG +from ...logger import info from ._utils import dispatch_custom_event, make_react_event def _get_rag_tool() -> Optional[callable]: """获取 RAG 工具""" - from app.main_graph.utils.rag_initializer import get_rag_tool + from backend.app.main_graph.utils.rag_initializer import get_rag_tool return get_rag_tool() @@ -35,6 +36,9 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG # 调用 RAG 工具 rag_context = await rag_tool.ainvoke(retrieval_query) info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}") + info(f"[RAG Core] ========== RAG 返回的知识内容 ==========") + info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context) + info(f"[RAG Core] ========================================") state.rag_context = rag_context state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}] @@ -47,7 +51,7 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG # ========== RAG 检索节点 ========== -async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: +async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """RAG 检索节点:带超时和重试""" state.current_phase = "rag_retrieving" start_time = time.time() @@ -156,7 +160,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None return state -async def rag_re_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState: +async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """重新检索节点""" state.current_phase = "rag_re_retrieving" diff --git a/backend/app/main_graph/nodes/reasoning.py b/backend/app/main_graph/nodes/reasoning.py index 73b50c9..09ef0a3 100644 --- a/backend/app/main_graph/nodes/reasoning.py +++ b/backend/app/main_graph/nodes/reasoning.py @@ -3,16 +3,17 @@ React 推理节点 使用 intent.py 进行意图推理 """ -from typing import Dict, Any, Optional +from typing import Optional from datetime import datetime +from langchain_core.runnables.config import RunnableConfig -from app.core.intent import react_reason_async, ReasoningResult -from app.main_graph.state import MainGraphState -from app.logger import info +from ...core.intent import react_reason_async, ReasoningResult +from ...main_graph.state import MainGraphState +from ...logger import info from ._utils import dispatch_custom_event, make_react_event -async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: +async def react_reason_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """React 模式推理节点:判断下一步做什么""" state.current_phase = "react_reasoning" state.reasoning_step += 1 diff --git a/backend/app/main_graph/nodes/retrieve_memory.py b/backend/app/main_graph/nodes/retrieve_memory.py index ed8b4fa..1a48855 100644 --- a/backend/app/main_graph/nodes/retrieve_memory.py +++ b/backend/app/main_graph/nodes/retrieve_memory.py @@ -6,10 +6,10 @@ from typing import Any, Dict # 本地模块 -from app.main_graph.state import MainGraphState -from app.memory.mem0_client import Mem0Client -from app.utils.logging import log_state_change -from app.logger import debug +from ...main_graph.state import MainGraphState +from ...memory.mem0_client import Mem0Client +from ...utils.logging import log_state_change +from ...logger import debug def create_retrieve_memory_node(mem0_client: Mem0Client): @@ -67,10 +67,10 @@ def create_retrieve_memory_node(mem0_client: Mem0Client): else: debug("🔍 [记忆检索] 未找到相关记忆") except Exception as e: - from app.logger import warning + from backend.app.logger import warning warning(f"⚠️ Mem0 检索失败: {e}") else: - from app.logger import warning + from backend.app.logger import warning warning("⚠️ Mem0 未初始化,跳过记忆检索") memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息" diff --git a/backend/app/main_graph/nodes/routing.py b/backend/app/main_graph/nodes/routing.py index 24985c4..d9b62dd 100644 --- a/backend/app/main_graph/nodes/routing.py +++ b/backend/app/main_graph/nodes/routing.py @@ -10,9 +10,9 @@ from datetime import datetime -from app.core.intent import get_route_by_reasoning, ReasoningAction -from app.main_graph.state import MainGraphState -from app.logger import info +from ...core.intent import get_route_by_reasoning, ReasoningAction +from ...main_graph.state import MainGraphState +from ...logger import info # ========== 初始化状态节点 ========== diff --git a/backend/app/main_graph/nodes/summarize.py b/backend/app/main_graph/nodes/summarize.py index d75ba1d..1969d10 100644 --- a/backend/app/main_graph/nodes/summarize.py +++ b/backend/app/main_graph/nodes/summarize.py @@ -6,10 +6,10 @@ from typing import Any, Dict # 本地模块 -from app.main_graph.state import MainGraphState -from app.memory.mem0_client import Mem0Client -from app.utils.logging import log_state_change -from app.logger import debug, info, error, warning +from ...main_graph.state import MainGraphState +from ...memory.mem0_client import Mem0Client +from ...utils.logging import log_state_change +from ...logger import debug, info, error, warning def create_summarize_node(mem0_client: Mem0Client): diff --git a/backend/app/main_graph/nodes/tool_call.py b/backend/app/main_graph/nodes/tool_call.py index b871865..315f119 100644 --- a/backend/app/main_graph/nodes/tool_call.py +++ b/backend/app/main_graph/nodes/tool_call.py @@ -6,12 +6,12 @@ import asyncio from typing import Any, Dict from langchain_core.messages import AIMessage, ToolMessage -from app.main_graph.config import get_stream_writer +from ...main_graph.config import get_stream_writer # 本地模块 -from app.main_graph.state import MainGraphState -from app.utils.logging import log_state_change -from app.logger import debug, info +from ...main_graph.state import MainGraphState +from ...utils.logging import log_state_change +from ...logger import debug, info def create_tool_call_node(tools_by_name: Dict[str, Any]): """ diff --git a/backend/app/main_graph/nodes/web_search.py b/backend/app/main_graph/nodes/web_search.py index 71d0681..f9aa62e 100644 --- a/backend/app/main_graph/nodes/web_search.py +++ b/backend/app/main_graph/nodes/web_search.py @@ -2,14 +2,15 @@ 联网搜索节点 - 执行搜索并将结果保存到状态 """ -from typing import Dict, Any, Optional +from typing import Optional from datetime import datetime +from langchain_core.runnables.config import RunnableConfig -from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity -from app.logger import info +from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity +from ...logger import info -async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: +async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: """ 联网搜索节点:执行搜索并将结果保存到状态 """ @@ -39,7 +40,7 @@ async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any] search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query try: - from app.core import web_search + from backend.app.core import web_search print(f"[WebSearch] 搜索: {search_query}") search_result = web_search(search_query, max_results=5) diff --git a/backend/app/main_graph/subgraph_wrapper.py b/backend/app/main_graph/subgraph_wrapper.py index 0be9b4c..f01cec1 100644 --- a/backend/app/main_graph/subgraph_wrapper.py +++ b/backend/app/main_graph/subgraph_wrapper.py @@ -5,6 +5,8 @@ from typing import Dict, Any, Optional from datetime import datetime +from langchain_core.runnables.config import RunnableConfig + from .state import MainGraphState, ErrorRecord, ErrorSeverity from ..logger import info @@ -19,7 +21,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str): Returns: 包装后的节点函数 """ - async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: + async def wrapped_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState: # 发送子图开始事件 if config: try: diff --git a/backend/app/main_graph/tools/common_tools.py b/backend/app/main_graph/tools/common_tools.py index a52c908..e62e6e7 100644 --- a/backend/app/main_graph/tools/common_tools.py +++ b/backend/app/main_graph/tools/common_tools.py @@ -20,7 +20,7 @@ def web_search_tool(query: str, max_results: int = 5) -> str: 格式化的搜索结果,包含引用溯源 """ try: - from app.core import web_search + from backend.app.core import web_search return web_search(query, max_results) except Exception as e: return f"联网搜索出错:{str(e)}" @@ -40,7 +40,7 @@ def generate_chart_tool(data_text: str, chart_type: str = "bar") -> str: 格式化的图表输出(Mermaid 格式) """ try: - from app.core import generate_chart + from backend.app.core import generate_chart return generate_chart(data_text, chart_type) except Exception as e: return f"生成图表出错:{str(e)}\n\n请使用格式:标题,标签1:值1,标签2:值2,..." diff --git a/backend/app/main_graph/tools/subgraph_tools.py b/backend/app/main_graph/tools/subgraph_tools.py index e29fbf8..51fc3de 100644 --- a/backend/app/main_graph/tools/subgraph_tools.py +++ b/backend/app/main_graph/tools/subgraph_tools.py @@ -22,13 +22,13 @@ def dictionary_tool(query: str, action: Optional[str] = None) -> str: 格式化的结果文本 """ try: - from app.subgraphs.dictionary import ( + from backend.app.subgraphs.dictionary import ( DictionaryState, DictionaryAction, parse_intent, format_result ) - from app.subgraphs.dictionary.nodes import ( + from backend.app.subgraphs.dictionary.nodes import ( query_word, translate_text, extract_terms, get_daily_word ) @@ -87,13 +87,13 @@ def news_analysis_tool(query: str, action: Optional[str] = None) -> str: 格式化的结果文本 """ try: - from app.subgraphs.news_analysis import ( + from backend.app.subgraphs.news_analysis import ( NewsAnalysisState, NewsAction, parse_intent, format_result ) - from app.subgraphs.news_analysis.nodes import ( + from backend.app.subgraphs.news_analysis.nodes import ( query_news, analyze_url, extract_keywords, generate_report ) @@ -150,13 +150,13 @@ def contact_tool(query: str, action: Optional[str] = None) -> str: 格式化的结果文本 """ try: - from app.subgraphs.contact import ( + from backend.app.subgraphs.contact import ( ContactState, ContactAction, parse_intent, format_result ) - from app.subgraphs.contact.nodes import ( + from backend.app.subgraphs.contact.nodes import ( query_contact, add_contact, list_contacts ) diff --git a/backend/app/main_graph/utils/rag_initializer.py b/backend/app/main_graph/utils/rag_initializer.py index becd5e7..62c8d7b 100644 --- a/backend/app/main_graph/utils/rag_initializer.py +++ b/backend/app/main_graph/utils/rag_initializer.py @@ -1,8 +1,8 @@ # app/rag_initializer.py -from app.rag.tools import create_rag_tool -from app.rag.retriever import create_parent_hybrid_retriever -from app.model_services import get_embedding_service -from app.logger import info, warning +from ...rag.tools import create_rag_tool +from ...rag.retriever import create_parent_hybrid_retriever +from ...model_services import get_embedding_service +from ...logger import info, warning import sys # 全局 RAG 工具 @@ -38,7 +38,7 @@ async def init_rag_tool(force: bool = False): return _rag_tool try: - from app.model_services.chat_services import get_chat_service + from backend.app.model_services.chat_services import get_chat_service info("🔄 正在初始化 RAG 检索系统...") embeddings = get_embedding_service() diff --git a/backend/app/main_graph/utils/retry_utils.py b/backend/app/main_graph/utils/retry_utils.py index 1b77844..69c44a5 100644 --- a/backend/app/main_graph/utils/retry_utils.py +++ b/backend/app/main_graph/utils/retry_utils.py @@ -287,7 +287,7 @@ def create_retry_wrapper_for_node( time.sleep(delay) # 所有重试都失败,更新状态错误信息 - from app.main_graph.state import ErrorRecord, ErrorSeverity + from backend.app.main_graph.state import ErrorRecord, ErrorSeverity error_record = ErrorRecord( error_type=f"{node_name}TimeoutError", diff --git a/backend/app/memory/mem0_client.py b/backend/app/memory/mem0_client.py index 1c53bf7..948be98 100644 --- a/backend/app/memory/mem0_client.py +++ b/backend/app/memory/mem0_client.py @@ -9,7 +9,7 @@ from typing import Optional, List from mem0 import AsyncMemory -from app.config import ( +from ..config import ( LLM_API_KEY, ZHIPUAI_API_KEY, VLLM_BASE_URL, @@ -21,9 +21,9 @@ from app.config import ( ZHIPU_EMBEDDING_MODEL, ZHIPU_API_BASE, ) -from app.logger import info, warning, error -from app.model_services import get_embedding_service -from app.model_services.chat_services import get_chat_service +from ..logger import info, warning, error +from ..model_services import get_embedding_service +from ..model_services.chat_services import get_chat_service class Mem0Client: @@ -48,7 +48,7 @@ class Mem0Client: info(f"✅ 嵌入服务可用,向量维度: {embedding_dim}") # 构建 embedder 配置 - from app.model_services.embedding_services import ( + from backend.app.model_services.embedding_services import ( LocalLlamaCppEmbeddingProvider, ZhipuEmbeddingProvider, ) diff --git a/backend/app/model_services/chat_services.py b/backend/app/model_services/chat_services.py index 581a2c3..4843ce8 100644 --- a/backend/app/model_services/chat_services.py +++ b/backend/app/model_services/chat_services.py @@ -23,7 +23,7 @@ from .base import ( FallbackServiceChain, SingletonServiceManager ) -from app.config import ( +from ..config import ( VLLM_BASE_URL, LLM_API_KEY, ZHIPUAI_API_KEY, @@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]): """ def __init__(self, model: str = None): - from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY + from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY super().__init__("local_small") self._model = model or SMALL_LOCAL_MODEL_NAME self._base_url = SMALL_VLLM_BASE_URL @@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]): """ def __init__(self, model: str = None): - from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE + from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE super().__init__("deepseek_small") self._model = model or SMALL_DEEPSEEK_MODEL self._api_key = SMALL_DEEPSEEK_API_KEY diff --git a/backend/app/model_services/embedding_services.py b/backend/app/model_services/embedding_services.py index 532c2b2..f371571 100644 --- a/backend/app/model_services/embedding_services.py +++ b/backend/app/model_services/embedding_services.py @@ -21,7 +21,7 @@ from .base import ( FallbackServiceChain, SingletonServiceManager ) -from app.config import ( +from ..config import ( LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY, ZHIPUAI_API_KEY, diff --git a/backend/app/model_services/rerank_services.py b/backend/app/model_services/rerank_services.py index 1aabe00..475995b 100644 --- a/backend/app/model_services/rerank_services.py +++ b/backend/app/model_services/rerank_services.py @@ -27,7 +27,7 @@ from .base import ( FallbackServiceChain, SingletonServiceManager ) -from app.config import ( +from ..config import ( LLAMACPP_RERANKER_URL, LLAMACPP_API_KEY, ZHIPUAI_API_KEY, @@ -92,20 +92,33 @@ class LocalLlamaCppRerankService(BaseRerankService): "documents": documents, } + logger.info(f"[LocalLlamaCppRerank] 调用 rerank API: {base}/rerank") + logger.info(f"[LocalLlamaCppRerank] 请求 payload: query={query[:50]}, documents数量={len(documents)}") + with httpx.Client(timeout=120) as client: response = client.post( f"{base}/rerank", headers=headers, json=payload, ) - response.raise_for_status() + logger.info(f"[LocalLlamaCppRerank] 响应状态码: {response.status_code}") + + if response.status_code != 200: + logger.error(f"[LocalLlamaCppRerank] 请求失败: {response.status_code}") + logger.error(f"[LocalLlamaCppRerank] 响应内容: {response.text[:500]}") + response.raise_for_status() + data = response.json() + logger.info(f"[LocalLlamaCppRerank] 响应数据类型: {type(data)}") if isinstance(data, dict) and "results" in data: results = data["results"] results_sorted = sorted(results, key=lambda x: x["index"]) - return [item["relevance_score"] for item in results_sorted] + scores = [item["relevance_score"] for item in results_sorted] + logger.info(f"[LocalLlamaCppRerank] 返回 {len(scores)} 个得分") + return scores else: + logger.error(f"[LocalLlamaCppRerank] 未知响应格式: {type(data)}") raise ValueError(f"未知的 rerank API 响应格式: {data}") @@ -207,11 +220,13 @@ class LLMFallbackRerankService(BaseRerankService): if not documents: return [] + logger.info(f"[LLMFallbackRerank] 开始为 {len(documents)} 个文档打分") scores = [] - for doc in documents: + for i, doc in enumerate(documents): score = self._score_single_document(query, doc) scores.append(score) - + logger.info(f"[LLMFallbackRerank] doc[{i}] score={score:.4f}") + return scores def _score_single_document(self, query: str, document: str) -> float: diff --git a/backend/app/rag/__init__.py b/backend/app/rag/__init__.py index 04e9462..95a8d56 100644 --- a/backend/app/rag/__init__.py +++ b/backend/app/rag/__init__.py @@ -13,7 +13,7 @@ RAG 检索与生成模块 用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 示例用法: - >>> from app.rag.rag import RAGPipeline, create_rag_tool + >>> from backend.app.rag.rag import RAGPipeline, create_rag_tool >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig >>> from langchain_openai import ChatOpenAI >>> diff --git a/backend/app/rag/pipeline.py b/backend/app/rag/pipeline.py index 4853000..8fd67d2 100644 --- a/backend/app/rag/pipeline.py +++ b/backend/app/rag/pipeline.py @@ -1,137 +1,114 @@ """ -RAG 检索流水线模块 - -提供固定流程的 RAG 检索: -多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 - -默认使用混合检索(稠密+稀疏)+ 父子文档模式。 +RAG 检索流水线 +流程: 检索子文档 → 重排 → 获取父文档 → 返回 """ import asyncio -import os -from typing import List, Optional +import logging +from typing import List from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from app.model_services import get_rerank_service, get_small_llm_service -from app.rag.rerank import create_document_reranker -from app.rag.query_transform import MultiQueryGenerator -from app.rag.fusion import reciprocal_rank_fusion -from app.rag.retriever import create_parent_hybrid_retriever +from ..model_services import get_rerank_service, get_small_llm_service +from ..rag.rerank import create_document_reranker +from ..rag.query_transform import MultiQueryGenerator +from ..rag.fusion import reciprocal_rank_fusion +from ..rag.retriever import create_parent_hybrid_retriever + +logger = logging.getLogger(__name__) class RAGPipeline: - """ - 固定流程的 RAG 检索流水线: - 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 - - 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 - """ - def __init__( self, retriever=None, - llm: Optional[BaseLanguageModel] = "default_small", + llm: BaseLanguageModel | str = "default_small", num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", + use_rerank: bool = True, + return_parent_docs: bool = True, ): - """ - Args: - retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。 - 如果不提供,会自动创建默认的父子文档混合检索器。 - llm: 用于生成多路查询的语言模型。 - - "default_small": (默认) 使用小模型(本地 + DeepSeek) - - None / False: 不做查询改写 - - BaseLanguageModel 实例: 自定义模型 - num_queries: 生成的查询变体数量。 - rerank_top_n: 最终返回的文档数量。 - collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。 - """ - # 如果没有提供 retriever,自动创建默认的混合检索器 - if retriever is None: - self.retriever = create_parent_hybrid_retriever( - collection_name=collection_name, - search_k=rerank_top_n * 2 # 多取一些给重排序用 - ) - else: - self.retriever = retriever - - # 处理 llm 参数 + self.retriever = retriever or create_parent_hybrid_retriever( + collection_name=collection_name, search_k=rerank_top_n * 4 + ) + self.num_queries = num_queries + self.rerank_top_n = rerank_top_n + self.use_rerank = use_rerank + self.return_parent_docs = return_parent_docs + if llm == "default_small": try: self.llm = get_small_llm_service() - except Exception as e: - import logging - logger = logging.getLogger(__name__) - logger.warning(f"小模型初始化失败,将不做查询改写: {e}") + except Exception: self.llm = None - elif llm in (None, False): - self.llm = None else: - self.llm = llm - - self.num_queries = num_queries - self.rerank_top_n = rerank_top_n - - # 初始化组件 - 使用统一的重排服务获取接口 - self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None - self.reranker = create_document_reranker() - + self.llm = llm if llm else None + + self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None + self.reranker = create_document_reranker() if use_rerank else None + logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}") + async def aretrieve(self, query: str) -> List[Document]: - """ - 异步执行完整检索流程 - - Args: - query: 用户查询 - - Returns: - 检索到的相关文档列表 - """ - # 如果有 query_generator,做多路改写 - if self.query_generator and self.llm: - # Step 1: 生成多路查询 + # Step 1: 检索 + child_docs = await self._retrieve(query) + logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档") + # 调试:打印子文档长度 + for i, doc in enumerate(child_docs[:5]): + content_len = len(doc.page_content) + logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符") + + # Step 2: 重排 + if self.reranker: + try: + child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n) + logger.info(f"[Pipeline] 重排后 {len(child_docs)} 个") + except Exception as e: + logger.warning(f"[Pipeline] 重排失败: {e}") + child_docs = child_docs[:self.rerank_top_n] + + # Step 3: 获取父文档 + if self.return_parent_docs: + return await self._get_parents(child_docs) + return child_docs + + async def _retrieve(self, query: str) -> List[Document]: + if self.query_generator: queries = await self.query_generator.agenerate(query) - # 包含原始查询,确保至少有一条 - if query not in queries: - queries.insert(0, query) - else: - # 如果原始查询已在列表中,将其移至首位 - queries.remove(query) - queries.insert(0, query) - - # Step 2: 并行检索(每个查询获取文档列表) - tasks = [self.retriever.ainvoke(q) for q in queries] - doc_lists = await asyncio.gather(*tasks) - - # Step 3: RRF 融合 - fused_docs = reciprocal_rank_fusion(doc_lists) - else: - # 没有 LLM 做查询改写,直接用原始查询检索 - fused_docs = await self.retriever.ainvoke(query) - - # Step 4: 重排序 + queries = [query] + [q for q in queries if q != query] + doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries]) + return reciprocal_rank_fusion(doc_lists) + return await self.retriever.ainvoke(query) + + async def _get_parents(self, child_docs: List[Document]) -> List[Document]: + parent_map = {} + for doc in child_docs: + pid = doc.metadata.get("parent_id") + if pid and pid not in parent_map: + parent_map[pid] = doc.metadata.get("score", 0.0) + + if not parent_map: + logger.warning("[Pipeline] 未找到 parent_id,返回子文档") + return child_docs + try: - final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n) - except Exception: - # 若重排序器不可用,直接返回融合后的前 N 个结果 - final_docs = fused_docs[:self.rerank_top_n] - - return final_docs + from backend.rag_core import create_docstore + docstore, _ = create_docstore() + # 同步获取(异步版本不存在) + parent_docs = docstore.mget(list(parent_map.keys())) + parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d} + result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2] + result.sort(key=lambda x: x[1], reverse=True) + docs = [d for d, _ in result] + logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档") + return docs + except Exception as e: + logger.warning(f"[Pipeline] 获取父文档失败: {e}") + return child_docs def format_context(self, documents: List[Document]) -> str: - """ - 将文档列表格式化为上下文字符串 - - Args: - documents: 文档列表 - - Returns: - 格式化后的上下文字符串 - """ if not documents: return "" - parts = [] for i, doc in enumerate(documents, 1): source = doc.metadata.get("source", "未知来源") @@ -139,30 +116,5 @@ class RAGPipeline: return "\n".join(parts) -def create_rag_pipeline( - collection_name: str = "rag_documents", - llm: Optional[BaseLanguageModel] = "default_small", - num_queries: int = 3, - rerank_top_n: int = 5, -) -> RAGPipeline: - """ - 创建 RAG 检索流水线的便捷函数 - - Args: - collection_name: Qdrant 集合名称 - llm: 用于生成多路查询的语言模型。 - - "default_small": (默认) 使用小模型(本地 + DeepSeek) - - None / False: 不做查询改写 - - BaseLanguageModel 实例: 自定义模型 - num_queries: 生成的查询变体数量 - rerank_top_n: 最终返回的文档数量 - - Returns: - RAGPipeline 实例 - """ - return RAGPipeline( - llm=llm, - num_queries=num_queries, - rerank_top_n=rerank_top_n, - collection_name=collection_name - ) +def create_rag_pipeline(**kwargs) -> RAGPipeline: + return RAGPipeline(**kwargs) diff --git a/backend/app/rag/rerank.py b/backend/app/rag/rerank.py index 34a7133..d63d303 100644 --- a/backend/app/rag/rerank.py +++ b/backend/app/rag/rerank.py @@ -57,14 +57,26 @@ class DocumentReranker: try: # 1. 从 Document 提取内容(业务逻辑) doc_contents = [doc.page_content for doc in documents] + logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}") + total_chars = sum(len(c) for c in doc_contents) + logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}") + # 粗略估算 tokens (中文约 0.75 tokens/字符) + estimated_tokens = int(total_chars * 0.75) + logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)") # 2. 调用纯服务层计算得分 + logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}") scores = self._rerank_service.compute_scores(query, doc_contents) + logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}") # 3. 根据得分排序(业务逻辑) doc_score_pairs = list(zip(documents, scores)) doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True) + logger.info(f"[Rerank] 排序后的结果:") + for i, (doc, score) in enumerate(doc_score_pairs_sorted): + logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...") + # 4. 取 top_n top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]] @@ -72,6 +84,9 @@ class DocumentReranker: except Exception as e: logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}") + logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}") + import traceback + logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}") return documents[:top_n] diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index a288970..5f70b4c 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -19,10 +19,10 @@ from langchain_core.embeddings import Embeddings from langchain_core.retrievers import BaseRetriever from pydantic import Field, PrivateAttr -from rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore -from rag_core.client import create_async_qdrant_client -from app.model_services import get_embedding_service -from app.logger import info, warning, debug +from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore +from backend.rag_core.client import create_async_qdrant_client +from ..model_services import get_embedding_service +from ..logger import info, warning, debug # 模块级常量 @@ -131,20 +131,20 @@ class HybridRetriever(BaseRetriever): class ParentHybridRetriever(BaseRetriever): """ 父子文档混合检索器(异步): - + 1. 先用混合检索找到相关子文档 2. 根据子文档的 parent_id 找到对应的父文档 3. 去重并返回父文档 """ - + collection_name: str = Field(description="Qdrant 集合名称") search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数") - + _vector_store: Any = PrivateAttr() _client: Any = PrivateAttr() _sparse_embedder: Any = PrivateAttr() _docstore: Any = PrivateAttr() - + def __init__( self, collection_name: str, @@ -188,7 +188,7 @@ class ParentHybridRetriever(BaseRetriever): self, query: str, *, run_manager: Any = None ) -> List[Document]: """ - 异步检索相关父文档 + 异步检索相关子文档 """ # 1. 生成查询向量 dense_query = await self._vector_store.aembed_query(query) @@ -197,10 +197,10 @@ class ParentHybridRetriever(BaseRetriever): indices=sparse_query["indices"], values=sparse_query["values"] ) - + # 2. 多取一些子文档,避免去重后数量不足 search_limit = self.search_k * 2 - + # 3. 使用 query_points API 进行混合检索 response = await self._client.query_points( collection_name=self.collection_name, @@ -220,87 +220,27 @@ class ParentHybridRetriever(BaseRetriever): limit=search_limit, with_payload=True ) - + if not response.points: debug("混合检索未找到任何文档") return [] - - # 4. 收集 parent_id 和对应最高得分 - parent_score_map = {} - parent_ids = set() - child_point_map = {} # 保存子文档点用于降级 - + + # 4. 构建子文档列表 + child_docs = [] for point in response.points: payload_copy = point.payload.copy() - parent_id = payload_copy.get("parent_id", point.id) - score = point.score - - if parent_id not in parent_score_map or score > parent_score_map[parent_id]: - parent_score_map[parent_id] = score - parent_ids.add(parent_id) - child_point_map[parent_id] = point - - # 5. 批量查询父文档 - parent_docs = [] - found_parent_ids = set() - - # 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中) - try: - parent_points = await self._client.retrieve( - collection_name=self.collection_name, - ids=list(parent_ids), - with_payload=True + doc = Document( + page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")), + metadata={ + **payload_copy, + "child_id": point.id, + "score": point.score + } ) - - for point in parent_points: - payload_copy = point.payload.copy() - doc = Document( - page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")), - metadata=payload_copy - ) - parent_docs.append(doc) - found_parent_ids.add(point.id) - - except Exception as e: - warning(f"从 Qdrant 查询父文档失败: {e}") - - # 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档 - if self._docstore and len(found_parent_ids) < len(parent_ids): - missing_parent_ids = parent_ids - found_parent_ids - try: - docstore_docs = await self._docstore.amget(missing_parent_ids) - for doc_id, doc in zip(missing_parent_ids, docstore_docs): - if doc is not None: - parent_docs.append(doc) - found_parent_ids.add(doc_id) - except Exception as e: - warning(f"从 docstore 查询父文档失败: {e}") - - # 7. 降级:对于仍未找到的父文档,用子文档本身代替 - missing_parent_ids = parent_ids - found_parent_ids - if missing_parent_ids: - warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}") - for parent_id in missing_parent_ids: - child_point = child_point_map.get(parent_id) - if child_point: - payload_copy = child_point.payload.copy() - doc = Document( - page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")), - metadata=payload_copy - ) - parent_docs.append(doc) - - # 8. 按照得分降序排序,返回前 k 个 - parent_docs_with_scores = [ - (doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0)) - for doc in parent_docs - ] - parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True) - - final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]] - debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档") - - return final_docs + child_docs.append(doc) + + debug(f"父子文档混合检索返回 {len(child_docs)} 个子文档") + return child_docs def create_hybrid_retriever( diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 9a069ad..d66e688 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -10,7 +10,7 @@ from typing import Callable, Optional from langchain_core.tools import tool from langchain_core.language_models import BaseLanguageModel from langchain_core.retrievers import BaseRetriever -from app.rag.pipeline import RAGPipeline, create_rag_pipeline +from ..rag.pipeline import RAGPipeline, create_rag_pipeline def create_rag_tool( diff --git a/backend/app/subgraphs/contact/api_client.py b/backend/app/subgraphs/contact/api_client.py index 21187f7..d452123 100644 --- a/backend/app/subgraphs/contact/api_client.py +++ b/backend/app/subgraphs/contact/api_client.py @@ -20,22 +20,18 @@ class ContactAPIClient: def __init__(self, conn=None): """ - 初始化 - + 初始化(使用延迟初始化,MCP在首次使用时初始化) + Args: conn: 数据库连接(保留用于向后兼容) """ self.conn = conn - - # 确保MCP已初始化 - import asyncio - try: - asyncio.create_task(self._init_mcp()) - except RuntimeError: - pass # 没有事件循环时跳过,延迟初始化 + self._mcp_initialized = False # 延迟初始化标志 async def _init_mcp(self): - """初始化MCP系统""" + """初始化MCP系统(延迟初始化)""" + if self._mcp_initialized: + return # 已初始化,跳过 if not mcp_manager.get_adapter("contact"): # 获取repository(如果有) repo = None @@ -45,9 +41,10 @@ class ContactAPIClient: repo = ContactRepository(self.conn) except Exception: pass - + mcp_manager.register_adapter(ContactAdapter(contact_repo=repo)) await mcp_manager.initialize() + self._mcp_initialized = True async def list_contacts(self, user_id: str = "default") -> List[Contact]: """获取联系人列表""" diff --git a/backend/app/subgraphs/contact/nodes.py b/backend/app/subgraphs/contact/nodes.py index 66519c6..19c7692 100644 --- a/backend/app/subgraphs/contact/nodes.py +++ b/backend/app/subgraphs/contact/nodes.py @@ -8,7 +8,7 @@ from typing import Dict, Any from datetime import datetime # 公共工具 -from app.core import MarkdownFormatter +from ...core import MarkdownFormatter from .state import ContactState from .api_client import ContactAPIClient diff --git a/backend/app/subgraphs/dictionary/api_client.py b/backend/app/subgraphs/dictionary/api_client.py index a9a403c..0ec009a 100644 --- a/backend/app/subgraphs/dictionary/api_client.py +++ b/backend/app/subgraphs/dictionary/api_client.py @@ -22,20 +22,19 @@ class DictionaryAPIClient: word_repository: Optional[Any] = None def __post_init__(self): - """初始化后设置MCP""" - import asyncio - try: - asyncio.create_task(self._init_mcp()) - except RuntimeError: - pass + """初始化后设置(延迟初始化标志)""" + self._mcp_initialized = False # 延迟初始化标志 async def _init_mcp(self): - """初始化MCP系统""" + """初始化MCP系统(延迟初始化)""" + if self._mcp_initialized: + return # 已初始化,跳过 if not mcp_manager.get_adapter("dictionary"): mcp_manager.register_adapter( DictionaryAdapter(word_repo=self.word_repository) ) await mcp_manager.initialize() + self._mcp_initialized = True async def query_word( self, diff --git a/backend/app/subgraphs/dictionary/nodes.py b/backend/app/subgraphs/dictionary/nodes.py index 7b76099..3b094a6 100644 --- a/backend/app/subgraphs/dictionary/nodes.py +++ b/backend/app/subgraphs/dictionary/nodes.py @@ -8,7 +8,7 @@ from datetime import datetime import random # 公共工具 -from app.core import ( +from ...core import ( MarkdownFormatter ) diff --git a/backend/app/subgraphs/news_analysis/api_client.py b/backend/app/subgraphs/news_analysis/api_client.py index 8ae19ba..7b35174 100644 --- a/backend/app/subgraphs/news_analysis/api_client.py +++ b/backend/app/subgraphs/news_analysis/api_client.py @@ -23,20 +23,19 @@ class NewsAPIClient: news_repository: Optional[Any] = None def __post_init__(self): - """初始化后设置MCP""" - import asyncio - try: - asyncio.create_task(self._init_mcp()) - except RuntimeError: - pass - + """初始化后设置(延迟初始化标志)""" + self._mcp_initialized = False # 延迟初始化标志 + async def _init_mcp(self): - """初始化MCP系统""" + """初始化MCP系统(延迟初始化)""" + if self._mcp_initialized: + return # 已初始化,跳过 if not mcp_manager.get_adapter("news"): mcp_manager.register_adapter( NewsAdapter(news_repo=self.news_repository) ) await mcp_manager.initialize() + self._mcp_initialized = True async def query_news( self, diff --git a/backend/app/subgraphs/news_analysis/nodes.py b/backend/app/subgraphs/news_analysis/nodes.py index fbab16b..99d6d17 100644 --- a/backend/app/subgraphs/news_analysis/nodes.py +++ b/backend/app/subgraphs/news_analysis/nodes.py @@ -7,7 +7,7 @@ from typing import Dict, Any from datetime import datetime # 公共工具 -from app.core import MarkdownFormatter +from ...core import MarkdownFormatter from .state import ( NewsAnalysisState, diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py index 7802895..9cd48b8 100644 --- a/backend/app/utils/logging.py +++ b/backend/app/utils/logging.py @@ -3,9 +3,9 @@ LangGraph 节点日志工具模块 提供状态流转追踪和 LLM 输入输出打印功能 """ -from app.config import ENABLE_GRAPH_TRACE -from app.logger import debug, info -from app.main_graph.state import MainGraphState +from ..config import ENABLE_GRAPH_TRACE +from ..logger import debug, info +from ..main_graph.state import MainGraphState def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进入"): @@ -17,7 +17,7 @@ def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进 state: 当前状态 prefix: 日志前缀("进入" 或 "离开") """ - from app.logger import info + from backend.app.logger import info messages = state.messages msg_count = len(messages) diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py index 5aa6076..ae10785 100755 --- a/rag_indexer/cli.py +++ b/rag_indexer/cli.py @@ -35,6 +35,7 @@ def get_input_path() -> Path: return Path(sys.argv[1]) # 默认测试路径(可按需修改) return Path("data/corpus/三国演义.txt") + #return Path("data/user_docs/doublestory.txt") async def main(): diff --git a/rag_indexer/reset_qdrant.py b/rag_indexer/reset_qdrant.py deleted file mode 100644 index 305e11f..0000000 --- a/rag_indexer/reset_qdrant.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 -""" -删除 Qdrant 集合并重新索引 -""" - -import asyncio -import os -import sys - -from backend.rag_core import QdrantHybridStore - - -async def delete_and_recreate(): - """删除并重新创建集合""" - print("="*70) - print("删除旧集合并重新创建...") - print("="*70) - - vs = QdrantHybridStore(collection_name="rag_documents") - - # 删除旧集合 - try: - vs.delete_collection() - print("✅ 旧集合已删除") - except Exception as e: - print(f"⚠️ 删除集合时出错(可能不存在): {e}") - - # 重新创建 - vs.create_collection() - print("✅ 新集合已创建") - - -if __name__ == "__main__": - asyncio.run(delete_and_recreate()) diff --git a/rag_indexer/splitters.py b/rag_indexer/splitters.py index 006e8ab..1b1eadd 100644 --- a/rag_indexer/splitters.py +++ b/rag_indexer/splitters.py @@ -17,10 +17,8 @@ class SplitterType(str, Enum): PARENT_CHILD = "parent_child" -# ---------- 配置数据类,统一参数 ---------- @dataclass class RecursiveSplitterConfig: - """递归字符切分器配置""" chunk_size: int = 500 chunk_overlap: int = 50 separators: List[str] = field(default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""]) @@ -30,33 +28,31 @@ class RecursiveSplitterConfig: @dataclass class SemanticSplitterConfig: - """语义切分器配置,仅包含 SemanticChunker 支持的参数。""" embeddings: Any buffer_size: int = 1 add_start_index: bool = False breakpoint_threshold_type: str = "percentile" - breakpoint_threshold_amount: Optional[float] = None + breakpoint_threshold_amount: float = 0.6 # 非 None,切分更积极 number_of_chunks: Optional[int] = None - sentence_split_regex: str = r"(?<=[.?!。?!])\s+" + sentence_split_regex: str = r"(?<=[。!?;.!?;])" # 中文友好 min_chunk_size: int = 100 + @dataclass class ParentChildSplitterConfig: - """父子切分器配置""" - embeddings: Any # 子块语义切分所需 - parent_chunk_size: int = 1000 - parent_chunk_overlap: int = 100 - child_buffer_size: int = 1 - child_breakpoint_threshold_type: str = "percentile" - child_breakpoint_threshold_amount: Optional[float] = None - child_min_chunk_size: int = 100 - child_max_chunk_size: Optional[int] = 200 + embeddings: Any + # 语义切分(用于父块) + semantic_threshold_type: str = "percentile" + semantic_threshold_amount: float = 0.6 + semantic_buffer_size: int = 1 + semantic_min_chunk_size: int = 100 + # 子块(递归字符切分) + child_chunk_size: int = 400 + child_chunk_overlap: int = 50 -# ---------- 适配器:让 SemanticChunker 实现 TextSplitter 接口 ---------- +# ---------- 适配器 ---------- class SemanticChunkerAdapter(TextSplitter): - """将 SemanticChunker 适配为 LangChain TextSplitter 接口。""" - def __init__(self, config: SemanticSplitterConfig, **kwargs): super().__init__(**kwargs) self._config = config @@ -86,12 +82,8 @@ class SemanticChunkerAdapter(TextSplitter): return result -# ---------- 工厂函数,统一创建切分器 ---------- +# ---------- 工厂函数 ---------- def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter: - """ - 根据类型创建切分器。 - 支持传入配置对象或直接参数。 - """ if splitter_type == SplitterType.RECURSIVE: config = RecursiveSplitterConfig( chunk_size=kwargs.get("chunk_size", 500), @@ -114,98 +106,90 @@ def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter: if "config" in kwargs and isinstance(kwargs["config"], SemanticSplitterConfig): config = kwargs["config"] else: - # 过滤出 SemanticSplitterConfig 支持的字段 - config_kwargs = { - "embeddings": embeddings, - "buffer_size": kwargs.get("buffer_size", 1), - "breakpoint_threshold_type": kwargs.get("breakpoint_threshold_type", "percentile"), - "breakpoint_threshold_amount": kwargs.get("breakpoint_threshold_amount"), - "number_of_chunks": kwargs.get("number_of_chunks"), - "min_chunk_size": kwargs.get("min_chunk_size", 100), - } - config = SemanticSplitterConfig(**config_kwargs) + config = SemanticSplitterConfig( + embeddings=embeddings, + buffer_size=kwargs.get("buffer_size", 1), + breakpoint_threshold_type=kwargs.get("breakpoint_threshold_type", "percentile"), + breakpoint_threshold_amount=kwargs.get("breakpoint_threshold_amount", 0.6), + number_of_chunks=kwargs.get("number_of_chunks"), + min_chunk_size=kwargs.get("min_chunk_size", 100), + ) return SemanticChunkerAdapter(config) elif splitter_type == SplitterType.PARENT_CHILD: - # 父子切分器在 builder 中单独处理,不通过本函数创建 - raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建") + raise ValueError("父子切分器应通过 ParentChildSplitter 直接创建") else: raise ValueError(f"不支持的切分器类型: {splitter_type}") - -# ---------- 父子切分器实现 ---------- + + +# ---------- 父子切分器 ---------- class ParentChildSplitter: """ - 将文档切分为父块(大块,用于上下文)和子块(小块,用于索引检索)。 - 内部维护父子块之间的映射关系。 + 切分流程: + 1. 语义切分 → 父块 + 2. 递归字符切分 → 子块 """ def __init__(self, config: ParentChildSplitterConfig): self.config = config - # 父块使用递归字符切分 - self.parent_splitter = RecursiveCharacterTextSplitter( - chunk_size=config.parent_chunk_size, - chunk_overlap=config.parent_chunk_overlap, - ) - # 子块使用语义切分 + + # 语义切分(父块) semantic_config = SemanticSplitterConfig( embeddings=config.embeddings, - buffer_size=config.child_buffer_size, - breakpoint_threshold_type=config.child_breakpoint_threshold_type, - breakpoint_threshold_amount=config.child_breakpoint_threshold_amount, - min_chunk_size=config.child_min_chunk_size, + buffer_size=config.semantic_buffer_size, + breakpoint_threshold_type=config.semantic_threshold_type, + breakpoint_threshold_amount=config.semantic_threshold_amount, + min_chunk_size=config.semantic_min_chunk_size, + ) + self.semantic_splitter = SemanticChunkerAdapter(semantic_config) + + # 递归字符切分(子块,大小由 child_chunk_size 控制) + self.recursive_splitter = RecursiveCharacterTextSplitter( + chunk_size=config.child_chunk_size, + chunk_overlap=config.child_chunk_overlap, + separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""] ) - self.child_splitter = SemanticChunkerAdapter(semantic_config) - # 存储父子块映射关系(可选) self.parent_to_children: Dict[str, List[str]] = {} self.child_to_parent: Dict[str, str] = {} def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]: - """ - 返回: - (父块列表, 子块列表) - 同时填充内部映射字典。 - """ - parent_chunks = self.parent_splitter.split_documents(documents) - child_chunks = self.child_splitter.split_documents(documents) + parent_chunks = [] + child_chunks = [] - # 建立映射关系(简化示例:根据文本包含关系粗略匹配,实际需更精确的算法) - # 这里仅作示意,生产环境建议使用 embedding 相似度或精确子串定位 - self._build_mappings(parent_chunks, child_chunks) + for doc in documents: + # Step 1: 语义切分(父块) + semantic_blocks = self.semantic_splitter.split_text(doc.page_content) + + for p_idx, semantic_block in enumerate(semantic_blocks): + parent_id = f"parent_{len(parent_chunks)}" + parent_doc = Document( + page_content=semantic_block, + metadata={**doc.metadata, "id": parent_id, "chunk_index": p_idx} + ) + parent_chunks.append(parent_doc) + + # Step 2: 递归字符切分(子块) + sub_chunks = self.recursive_splitter.split_text(semantic_block) + + for c_idx, sub_chunk in enumerate(sub_chunks): + child_id = f"child_{len(child_chunks)}" + child_doc = Document( + page_content=sub_chunk, + metadata={**doc.metadata, "id": child_id, "parent_id": parent_id, "child_index": c_idx} + ) + child_chunks.append(child_doc) + + self.child_to_parent[child_id] = parent_id + if parent_id not in self.parent_to_children: + self.parent_to_children[parent_id] = [] + self.parent_to_children[parent_id].append(child_id) return parent_chunks, child_chunks - def _build_mappings(self, parents: List[Document], children: List[Document]) -> None: - """ - 根据文本内容建立父子映射。 - 本方法为简化实现,实际使用时请替换为更可靠的匹配逻辑。 - """ - self.parent_to_children.clear() - self.child_to_parent.clear() - - # 为每个父块生成唯一 ID(若无则使用索引) - for p_idx, parent in enumerate(parents): - parent_id = parent.metadata.get("id", f"parent_{p_idx}") - parent.metadata["id"] = parent_id - self.parent_to_children[parent_id] = [] - - # 将每个子块分配给包含其文本的第一个父块 - for c_idx, child in enumerate(children): - child_id = child.metadata.get("id", f"child_{c_idx}") - child.metadata["id"] = child_id - for parent in parents: - if child.page_content in parent.page_content: - parent_id = parent.metadata["id"] - self.parent_to_children[parent_id].append(child_id) - self.child_to_parent[child_id] = parent_id - child.metadata["parent_id"] = parent_id - break - def get_parent_for_child(self, child_id: str) -> Optional[str]: - """根据子块 ID 获取父块 ID""" return self.child_to_parent.get(child_id) def get_children_for_parent(self, parent_id: str) -> List[str]: - """根据父块 ID 获取所有子块 ID""" - return self.parent_to_children.get(parent_id, []) \ No newline at end of file + return self.parent_to_children.get(parent_id, []) diff --git a/backend/app/rag/evaluate.py b/tools/evaluate.py similarity index 100% rename from backend/app/rag/evaluate.py rename to tools/evaluate.py diff --git a/tools/run.py b/tools/run.py deleted file mode 100644 index d6a6ae9..0000000 --- a/tools/run.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python3 -"""统一入口:设置路径后运行测试""" -import sys -from pathlib import Path -from dotenv import load_dotenv - -# 路径设置 - 只添加 backend 目录 -project_root = Path(__file__).resolve().parent.parent -backend_path = project_root -sys.path.insert(0, str(backend_path)) -load_dotenv(project_root / ".env") - -if __name__ == "__main__": - from tools.test import test_tavily_search - test_tavily_search.main() diff --git a/tools/test/test_graph_branches.py b/tools/test/test_graph_branches.py index a25aa80..687fd45 100644 --- a/tools/test/test_graph_branches.py +++ b/tools/test/test_graph_branches.py @@ -2,29 +2,24 @@ """ 主图完整测试 - 覆盖各个分支 """ -import sys + import asyncio -from pathlib import Path -from dotenv import load_dotenv -# 添加 backend 到路径 -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "backend")) - -from app.main_graph.state import MainGraphState, CurrentAction -from app.main_graph.utils.main_graph_builder import build_react_main_graph -from app.model_services.chat_services import get_all_chat_services -from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS -from app.main_graph.utils.rag_initializer import init_rag_tool +from backend.app.main_graph.state import MainGraphState, CurrentAction +from backend.app.main_graph.main_graph_builder import build_react_main_graph +from backend.app.model_services.chat_services import get_all_chat_services +from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS +from backend.app.main_graph.utils.rag_initializer import init_rag_tool # ========== 测试用例配置 ========== TEST_CASES = [ - # # 测试1: 简单闲聊 - 应该走 fast_chitchat - # { - # "name": "闲聊测试", - # "query": "你好!", - # "description": "测试快速闲聊分支" - # }, + # 测试1: 简单闲聊 - 应该走 fast_chitchat + { + "name": "闲聊测试", + "query": "你好!", + "description": "测试快速闲聊分支" + }, # 测试2: 知识查询 - 应该走 fast_rag,然后可能升级到 react { "name": "知识查询测试", @@ -37,12 +32,12 @@ TEST_CASES = [ "query": "请帮我分析:如果我有10万元,想要在一年内获得15%的收益,有哪些低风险的投资方案?", "description": "测试 React 循环推理分支" }, - # # 测试4: 需要工具调用的问题 - # { - # "name": "工具调用测试", - # "query": "搜索一下今天的天气怎么样", - # "description": "测试工具调用分支" - # }, + # 测试4: 需要工具调用的问题 + { + "name": "联网工具调用测试", + "query": "搜索一下今天的天气怎么样", + "description": "测试工具调用分支" + }, # 测试5: 带记忆的对话 { "name": "记忆测试", @@ -64,22 +59,18 @@ async def setup_test_environment(): if not chat_services: raise RuntimeError("没有可用的 LLM 服务") - llm = list(chat_services.values())[0] - print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}") + print(f"✓ 可用模型: {list(chat_services.keys())}") # 初始化 RAG 工具 - def create_local_llm(): - return llm - - rag_tool = await init_rag_tool(create_local_llm) + rag_tool = await init_rag_tool() tools = AVAILABLE_TOOLS.copy() if rag_tool: tools.append(rag_tool) print(f"✓ RAG 工具初始化成功") - # 构建图 + # 构建图(使用新的 API: chat_services 而不是 llm) graph = build_react_main_graph( - llm=llm, + chat_services=chat_services, tools=tools, use_hybrid_router=True ).compile() diff --git a/tools/visualize_graph.py b/tools/visualize_graph.py index bb0557d..ae39f92 100644 --- a/tools/visualize_graph.py +++ b/tools/visualize_graph.py @@ -10,16 +10,22 @@ from pathlib import Path # 路径设置 PROJECT_ROOT = Path(__file__).parent.parent BACKEND_DIR = PROJECT_ROOT / "backend" -sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(BACKEND_DIR)) + +import warnings +# 抑制 WebSocket 弃用警告 +warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets") +warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets") from dotenv import load_dotenv load_dotenv(PROJECT_ROOT / ".env") import asyncio -from backend.app.agent.agent_service import AIAgentService +from backend.app.agent.agent_service import AIAgentService, create_serde from backend.app.config import DB_URI from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver + async def visualize_graph(): """可视化 LangGraph 结构""" print("=" * 80) @@ -28,9 +34,7 @@ async def visualize_graph(): print(f"项目根目录: {PROJECT_ROOT}") print(f"Backend 目录: {BACKEND_DIR}") - - - async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: + async with AsyncPostgresSaver.from_conn_string(DB_URI, serde=create_serde()) as checkpointer: await checkpointer.setup() # 创建服务实例