This commit is contained in:
@@ -3,6 +3,6 @@ AI Agent 应用模块
|
||||
"""
|
||||
|
||||
from .agent.agent_service import AIAgentService
|
||||
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from .main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
|
||||
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
|
||||
|
||||
@@ -7,6 +7,9 @@ import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
||||
|
||||
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
|
||||
# 本地模块
|
||||
from ..model_services import get_cached_chat_services
|
||||
from ..main_graph.main_graph_builder import build_react_main_graph
|
||||
@@ -18,6 +21,23 @@ from ..logger import debug, info, warning, error
|
||||
from ..main_graph.state import MainGraphState, CurrentAction
|
||||
|
||||
|
||||
# ========== 自定义类型序列化器 ==========
|
||||
def create_serde() -> JsonPlusSerializer:
|
||||
"""创建带自定义类型注册的序列化器"""
|
||||
from backend.app.core.intent import ReasoningAction, RetrievalConfig, ReasoningResult
|
||||
|
||||
return JsonPlusSerializer(
|
||||
allowed_msgpack_modules=[
|
||||
("app.core.intent", "ReasoningAction"),
|
||||
("app.core.intent", "RetrievalConfig"),
|
||||
("app.core.intent", "ReasoningResult"),
|
||||
("app.main_graph.state", "CurrentAction"),
|
||||
("app.main_graph.state", "ErrorSeverity"),
|
||||
("app.main_graph.state", "ErrorRecord"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
@@ -55,6 +75,7 @@ class AIAgentService:
|
||||
tools=self.tools,
|
||||
mem0_client=self.mem0_client
|
||||
)
|
||||
# 注意:serde 已在创建 checkpointer 时传入,这里只需传入 checkpointer
|
||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||
info(f"✅ 单图初始化完成")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from app.logger import error # 保持兼容,或者替换为 logger
|
||||
from ..logger import error # 保持兼容,或者替换为 logger
|
||||
|
||||
class ThreadHistoryService:
|
||||
"""线程历史查询服务"""
|
||||
|
||||
@@ -3,8 +3,13 @@ FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
|
||||
采用依赖注入模式,优雅管理资源生命周期
|
||||
"""
|
||||
|
||||
import warnings
|
||||
# 抑制 WebSocket 弃用警告(websockets 库升级导致,uvicorn 尚未跟进)
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets")
|
||||
warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets")
|
||||
|
||||
import os
|
||||
from app.config import DB_URI, BACKEND_PORT
|
||||
from ..config import DB_URI, BACKEND_PORT
|
||||
import uuid
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -15,26 +20,26 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from .agent.agent_service import AIAgentService
|
||||
from .agent.agent_service import AIAgentService, create_serde
|
||||
from .agent.history import ThreadHistoryService
|
||||
from app.core.human_review import (
|
||||
from ..core.human_review import (
|
||||
ReviewManager,
|
||||
InMemoryReviewStore,
|
||||
ReviewStatus,
|
||||
HumanReview
|
||||
)
|
||||
from app.subgraphs.contact.api_client import ContactAPIClient
|
||||
from app.subgraphs.dictionary.api_client import DictionaryAPIClient
|
||||
from app.subgraphs.news_analysis.api_client import NewsAPIClient
|
||||
from ..subgraphs.contact.api_client import ContactAPIClient
|
||||
from ..subgraphs.dictionary.api_client import DictionaryAPIClient
|
||||
from ..subgraphs.news_analysis.api_client import NewsAPIClient
|
||||
from .db.init_db import init_subgraph_tables
|
||||
from .db.models import ContactRepository, DictionaryRepository, NewsRepository
|
||||
from app.logger import info, error
|
||||
from ..logger import info, error
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理:创建并注入全局服务"""
|
||||
# 1. 创建数据库连接池并初始化表(仅 checkpointer)
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI, serde=create_serde()) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
|
||||
# 1.5 初始化子图表
|
||||
|
||||
@@ -30,7 +30,7 @@ from .visualization import (
|
||||
# 为了兼容性,添加 classify_intent 函数
|
||||
def classify_intent(user_input: str, context: str = None):
|
||||
"""兼容旧代码的 classify_intent 函数"""
|
||||
from app.core.intent_classifier import get_intent_classifier
|
||||
from backend.app.core.intent_classifier import get_intent_classifier
|
||||
import asyncio
|
||||
classifier = get_intent_classifier()
|
||||
try:
|
||||
|
||||
@@ -94,7 +94,7 @@ class ReactIntentReasoner:
|
||||
def _get_llm_service(self):
|
||||
"""懒加载 LLM 服务(避免循环导入)"""
|
||||
if self._llm_service is None:
|
||||
from app.model_services.chat_services import get_chat_service, get_small_llm_service
|
||||
from backend.app.model_services.chat_services import get_chat_service, get_small_llm_service
|
||||
if self._use_small_llm:
|
||||
self._llm_service = get_small_llm_service()
|
||||
else:
|
||||
|
||||
@@ -89,7 +89,7 @@ class WebSearchTool:
|
||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""使用 Tavily API 搜索"""
|
||||
from tavily import TavilyClient
|
||||
from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
|
||||
from backend.app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
|
||||
|
||||
if not TAVILY_API_KEY:
|
||||
raise ValueError("TAVILY_API_KEY 未配置")
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
from app.config import LOG_LEVEL, DEBUG
|
||||
from .config import LOG_LEVEL, DEBUG
|
||||
import logging
|
||||
from typing import Any
|
||||
# 根据环境变量控制是否显示详细调试信息
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
async def dispatch_custom_event(
|
||||
event_name: str,
|
||||
data: Dict[str, Any],
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
) -> None:
|
||||
"""
|
||||
安全地发送自定义事件,忽略发送失败
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
错误处理节点 - 处理子图/工具调用错误
|
||||
"""
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorSeverity
|
||||
from app.logger import info
|
||||
from ...main_graph.state import MainGraphState, ErrorSeverity
|
||||
from ...logger import info
|
||||
|
||||
|
||||
def error_handling_node(state: MainGraphState) -> MainGraphState:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ...logger import info, debug
|
||||
@@ -28,7 +29,7 @@ CHITCHAT_KEYWORDS = {
|
||||
|
||||
|
||||
# ========== 闲聊节点 ==========
|
||||
async def fast_chitchat_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def fast_chitchat_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""快速闲聊节点"""
|
||||
state.current_phase = "fast_chitchat"
|
||||
query = state.user_query or ""
|
||||
@@ -69,14 +70,14 @@ def _match_chitchat_template(query: str) -> str:
|
||||
|
||||
|
||||
# ========== 快速 RAG 节点 ==========
|
||||
async def fast_rag_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答"""
|
||||
state.current_phase = "fast_rag"
|
||||
query = state.user_query or ""
|
||||
info(f"[Fast RAG] 开始处理: {query[:50]}")
|
||||
|
||||
# 获取 RAG 工具
|
||||
from app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
rag_tool = get_rag_tool()
|
||||
info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}")
|
||||
|
||||
@@ -134,7 +135,7 @@ async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphS
|
||||
请给出简洁、准确的回答:"""
|
||||
|
||||
# 使用流式输出
|
||||
from app.main_graph.config import get_stream_writer
|
||||
from backend.app.main_graph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
|
||||
full_content = ""
|
||||
@@ -164,7 +165,7 @@ async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphS
|
||||
|
||||
|
||||
# ========== 快速工具节点 ==========
|
||||
async def fast_tool_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def fast_tool_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""快速工具节点"""
|
||||
state.current_phase = "fast_tool"
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import info, warning
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import info, warning
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
@@ -35,7 +35,7 @@ async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[s
|
||||
|
||||
try:
|
||||
# 获取流式写入器并发送完成事件
|
||||
from app.main_graph.config import get_stream_writer
|
||||
from backend.app.main_graph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
|
||||
# 只在 writer 存在且不是 noop 时才发送
|
||||
|
||||
@@ -8,6 +8,7 @@ import json
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from ..state import MainGraphState
|
||||
from ...logger import info, debug
|
||||
@@ -157,7 +158,7 @@ def _default_result() -> HybridRouterResult:
|
||||
|
||||
|
||||
# ========== 主路由节点 ==========
|
||||
async def hybrid_router_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def hybrid_router_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""混合路由节点:前置路由,决定走快速路径还是 React 循环"""
|
||||
state.current_phase = "hybrid_router"
|
||||
query = state.user_query or ""
|
||||
|
||||
@@ -9,10 +9,10 @@ from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.agent.prompts import create_system_prompt
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...agent.prompts import create_system_prompt
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug, info, error
|
||||
|
||||
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
@@ -82,6 +82,12 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
try:
|
||||
# 添加上下文到消息
|
||||
messages_with_context = list(state.messages)
|
||||
info(f"[llm_call] 原始消息数量: {len(messages_with_context)}")
|
||||
for i, msg in enumerate(messages_with_context):
|
||||
msg_type = getattr(msg, 'type', 'unknown')
|
||||
msg_content = getattr(msg, 'content', '')[:100] if hasattr(msg, 'content') else str(msg)[:100]
|
||||
info(f"[llm_call] msg[{i}] type={msg_type}, content={repr(msg_content)}")
|
||||
|
||||
if state.rag_context:
|
||||
from langchain_core.messages import SystemMessage
|
||||
rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
|
||||
@@ -93,11 +99,13 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
break
|
||||
if not inserted:
|
||||
messages_with_context.insert(0, rag_system_msg)
|
||||
|
||||
info(f"[llm_call] RAG上下文已添加,长度: {len(state.rag_context)}")
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chunks = []
|
||||
info(f"[llm_call] 开始调用 LLM astream...")
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
"messages": messages_with_context,
|
||||
@@ -106,7 +114,26 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks")
|
||||
for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多
|
||||
chunk_type = type(chunk).__name__
|
||||
chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk)
|
||||
# 打印更多属性
|
||||
additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {}
|
||||
response_metadata = getattr(chunk, 'response_metadata', {}) or {}
|
||||
# 打印所有属性
|
||||
info(f"[llm_call] chunk[{i}] type={chunk_type}")
|
||||
info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}")
|
||||
info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}")
|
||||
info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}")
|
||||
info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}")
|
||||
# 检查是否有其他可能存储内容的属性
|
||||
if hasattr(chunk, 'tool_call_chunks'):
|
||||
info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}")
|
||||
if hasattr(chunk, 'usage_metadata'):
|
||||
info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}")
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -114,6 +141,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
info(f"[llm_call] ⚠️ 警告: 没有收到任何 chunks!")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Any, Dict
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.logger import info
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...logger import info
|
||||
|
||||
|
||||
# 全局变量,在 GraphBuilder 中注入
|
||||
|
||||
@@ -7,16 +7,17 @@ import time
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from app.logger import info
|
||||
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
|
||||
from ...logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
def _get_rag_tool() -> Optional[callable]:
|
||||
"""获取 RAG 工具"""
|
||||
from app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
|
||||
return get_rag_tool()
|
||||
|
||||
|
||||
@@ -35,6 +36,9 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
# 调用 RAG 工具
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
|
||||
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
|
||||
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
|
||||
info(f"[RAG Core] ========================================")
|
||||
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
|
||||
@@ -47,7 +51,7 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
|
||||
|
||||
|
||||
# ========== RAG 检索节点 ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""RAG 检索节点:带超时和重试"""
|
||||
state.current_phase = "rag_retrieving"
|
||||
start_time = time.time()
|
||||
@@ -156,7 +160,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None
|
||||
return state
|
||||
|
||||
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
|
||||
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""重新检索节点"""
|
||||
state.current_phase = "rag_re_retrieving"
|
||||
|
||||
|
||||
@@ -3,16 +3,17 @@ React 推理节点
|
||||
使用 intent.py 进行意图推理
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from app.core.intent import react_reason_async, ReasoningResult
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.logger import info
|
||||
from ...core.intent import react_reason_async, ReasoningResult
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...logger import info
|
||||
from ._utils import dispatch_custom_event, make_react_event
|
||||
|
||||
|
||||
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
async def react_reason_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""React 模式推理节点:判断下一步做什么"""
|
||||
state.current_phase = "react_reasoning"
|
||||
state.reasoning_step += 1
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug
|
||||
|
||||
|
||||
def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
@@ -67,10 +67,10 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
else:
|
||||
debug("🔍 [记忆检索] 未找到相关记忆")
|
||||
except Exception as e:
|
||||
from app.logger import warning
|
||||
from backend.app.logger import warning
|
||||
warning(f"⚠️ Mem0 检索失败: {e}")
|
||||
else:
|
||||
from app.logger import warning
|
||||
from backend.app.logger import warning
|
||||
warning("⚠️ Mem0 未初始化,跳过记忆检索")
|
||||
|
||||
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
|
||||
|
||||
@@ -10,9 +10,9 @@
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.intent import get_route_by_reasoning, ReasoningAction
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.logger import info
|
||||
from ...core.intent import get_route_by_reasoning, ReasoningAction
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...logger import info
|
||||
|
||||
|
||||
# ========== 初始化状态节点 ==========
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.memory.mem0_client import Mem0Client
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info, error, warning
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...memory.mem0_client import Mem0Client
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug, info, error, warning
|
||||
|
||||
|
||||
def create_summarize_node(mem0_client: Mem0Client):
|
||||
|
||||
@@ -6,12 +6,12 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from app.main_graph.config import get_stream_writer
|
||||
from ...main_graph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.main_graph.state import MainGraphState
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import debug, info
|
||||
from ...main_graph.state import MainGraphState
|
||||
from ...utils.logging import log_state_change
|
||||
from ...logger import debug, info
|
||||
|
||||
def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
"""
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
联网搜索节点 - 执行搜索并将结果保存到状态
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from app.logger import info
|
||||
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ...logger import info
|
||||
|
||||
|
||||
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
"""
|
||||
联网搜索节点:执行搜索并将结果保存到状态
|
||||
"""
|
||||
@@ -39,7 +40,7 @@ async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]
|
||||
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
|
||||
|
||||
try:
|
||||
from app.core import web_search
|
||||
from backend.app.core import web_search
|
||||
|
||||
print(f"[WebSearch] 搜索: {search_query}")
|
||||
search_result = web_search(search_query, max_results=5)
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
from .state import MainGraphState, ErrorRecord, ErrorSeverity
|
||||
from ..logger import info
|
||||
|
||||
@@ -19,7 +21,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
|
||||
|
||||
Returns: 包装后的节点函数
|
||||
"""
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
async def wrapped_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
|
||||
# 发送子图开始事件
|
||||
if config:
|
||||
try:
|
||||
|
||||
@@ -20,7 +20,7 @@ def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
格式化的搜索结果,包含引用溯源
|
||||
"""
|
||||
try:
|
||||
from app.core import web_search
|
||||
from backend.app.core import web_search
|
||||
return web_search(query, max_results)
|
||||
except Exception as e:
|
||||
return f"联网搜索出错:{str(e)}"
|
||||
@@ -40,7 +40,7 @@ def generate_chart_tool(data_text: str, chart_type: str = "bar") -> str:
|
||||
格式化的图表输出(Mermaid 格式)
|
||||
"""
|
||||
try:
|
||||
from app.core import generate_chart
|
||||
from backend.app.core import generate_chart
|
||||
return generate_chart(data_text, chart_type)
|
||||
except Exception as e:
|
||||
return f"生成图表出错:{str(e)}\n\n请使用格式:标题,标签1:值1,标签2:值2,..."
|
||||
|
||||
@@ -22,13 +22,13 @@ def dictionary_tool(query: str, action: Optional[str] = None) -> str:
|
||||
格式化的结果文本
|
||||
"""
|
||||
try:
|
||||
from app.subgraphs.dictionary import (
|
||||
from backend.app.subgraphs.dictionary import (
|
||||
DictionaryState,
|
||||
DictionaryAction,
|
||||
parse_intent,
|
||||
format_result
|
||||
)
|
||||
from app.subgraphs.dictionary.nodes import (
|
||||
from backend.app.subgraphs.dictionary.nodes import (
|
||||
query_word, translate_text, extract_terms, get_daily_word
|
||||
)
|
||||
|
||||
@@ -87,13 +87,13 @@ def news_analysis_tool(query: str, action: Optional[str] = None) -> str:
|
||||
格式化的结果文本
|
||||
"""
|
||||
try:
|
||||
from app.subgraphs.news_analysis import (
|
||||
from backend.app.subgraphs.news_analysis import (
|
||||
NewsAnalysisState,
|
||||
NewsAction,
|
||||
parse_intent,
|
||||
format_result
|
||||
)
|
||||
from app.subgraphs.news_analysis.nodes import (
|
||||
from backend.app.subgraphs.news_analysis.nodes import (
|
||||
query_news, analyze_url, extract_keywords, generate_report
|
||||
)
|
||||
|
||||
@@ -150,13 +150,13 @@ def contact_tool(query: str, action: Optional[str] = None) -> str:
|
||||
格式化的结果文本
|
||||
"""
|
||||
try:
|
||||
from app.subgraphs.contact import (
|
||||
from backend.app.subgraphs.contact import (
|
||||
ContactState,
|
||||
ContactAction,
|
||||
parse_intent,
|
||||
format_result
|
||||
)
|
||||
from app.subgraphs.contact.nodes import (
|
||||
from backend.app.subgraphs.contact.nodes import (
|
||||
query_contact, add_contact, list_contacts
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# app/rag_initializer.py
|
||||
from app.rag.tools import create_rag_tool
|
||||
from app.rag.retriever import create_parent_hybrid_retriever
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning
|
||||
from ...rag.tools import create_rag_tool
|
||||
from ...rag.retriever import create_parent_hybrid_retriever
|
||||
from ...model_services import get_embedding_service
|
||||
from ...logger import info, warning
|
||||
import sys
|
||||
|
||||
# 全局 RAG 工具
|
||||
@@ -38,7 +38,7 @@ async def init_rag_tool(force: bool = False):
|
||||
return _rag_tool
|
||||
|
||||
try:
|
||||
from app.model_services.chat_services import get_chat_service
|
||||
from backend.app.model_services.chat_services import get_chat_service
|
||||
|
||||
info("🔄 正在初始化 RAG 检索系统...")
|
||||
embeddings = get_embedding_service()
|
||||
|
||||
@@ -287,7 +287,7 @@ def create_retry_wrapper_for_node(
|
||||
time.sleep(delay)
|
||||
|
||||
# 所有重试都失败,更新状态错误信息
|
||||
from app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||
from backend.app.main_graph.state import ErrorRecord, ErrorSeverity
|
||||
|
||||
error_record = ErrorRecord(
|
||||
error_type=f"{node_name}TimeoutError",
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Optional, List
|
||||
|
||||
from mem0 import AsyncMemory
|
||||
|
||||
from app.config import (
|
||||
from ..config import (
|
||||
LLM_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
VLLM_BASE_URL,
|
||||
@@ -21,9 +21,9 @@ from app.config import (
|
||||
ZHIPU_EMBEDDING_MODEL,
|
||||
ZHIPU_API_BASE,
|
||||
)
|
||||
from app.logger import info, warning, error
|
||||
from app.model_services import get_embedding_service
|
||||
from app.model_services.chat_services import get_chat_service
|
||||
from ..logger import info, warning, error
|
||||
from ..model_services import get_embedding_service
|
||||
from ..model_services.chat_services import get_chat_service
|
||||
|
||||
|
||||
class Mem0Client:
|
||||
@@ -48,7 +48,7 @@ class Mem0Client:
|
||||
info(f"✅ 嵌入服务可用,向量维度: {embedding_dim}")
|
||||
|
||||
# 构建 embedder 配置
|
||||
from app.model_services.embedding_services import (
|
||||
from backend.app.model_services.embedding_services import (
|
||||
LocalLlamaCppEmbeddingProvider,
|
||||
ZhipuEmbeddingProvider,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from app.config import (
|
||||
from ..config import (
|
||||
VLLM_BASE_URL,
|
||||
LLM_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
@@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = None):
|
||||
from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
|
||||
from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
|
||||
super().__init__("local_small")
|
||||
self._model = model or SMALL_LOCAL_MODEL_NAME
|
||||
self._base_url = SMALL_VLLM_BASE_URL
|
||||
@@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = None):
|
||||
from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
|
||||
from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
|
||||
super().__init__("deepseek_small")
|
||||
self._model = model or SMALL_DEEPSEEK_MODEL
|
||||
self._api_key = SMALL_DEEPSEEK_API_KEY
|
||||
|
||||
@@ -21,7 +21,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from app.config import (
|
||||
from ..config import (
|
||||
LLAMACPP_EMBEDDING_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
|
||||
@@ -27,7 +27,7 @@ from .base import (
|
||||
FallbackServiceChain,
|
||||
SingletonServiceManager
|
||||
)
|
||||
from app.config import (
|
||||
from ..config import (
|
||||
LLAMACPP_RERANKER_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
ZHIPUAI_API_KEY,
|
||||
@@ -92,20 +92,33 @@ class LocalLlamaCppRerankService(BaseRerankService):
|
||||
"documents": documents,
|
||||
}
|
||||
|
||||
logger.info(f"[LocalLlamaCppRerank] 调用 rerank API: {base}/rerank")
|
||||
logger.info(f"[LocalLlamaCppRerank] 请求 payload: query={query[:50]}, documents数量={len(documents)}")
|
||||
|
||||
with httpx.Client(timeout=120) as client:
|
||||
response = client.post(
|
||||
f"{base}/rerank",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(f"[LocalLlamaCppRerank] 响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[LocalLlamaCppRerank] 请求失败: {response.status_code}")
|
||||
logger.error(f"[LocalLlamaCppRerank] 响应内容: {response.text[:500]}")
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
logger.info(f"[LocalLlamaCppRerank] 响应数据类型: {type(data)}")
|
||||
|
||||
if isinstance(data, dict) and "results" in data:
|
||||
results = data["results"]
|
||||
results_sorted = sorted(results, key=lambda x: x["index"])
|
||||
return [item["relevance_score"] for item in results_sorted]
|
||||
scores = [item["relevance_score"] for item in results_sorted]
|
||||
logger.info(f"[LocalLlamaCppRerank] 返回 {len(scores)} 个得分")
|
||||
return scores
|
||||
else:
|
||||
logger.error(f"[LocalLlamaCppRerank] 未知响应格式: {type(data)}")
|
||||
raise ValueError(f"未知的 rerank API 响应格式: {data}")
|
||||
|
||||
|
||||
@@ -207,11 +220,13 @@ class LLMFallbackRerankService(BaseRerankService):
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
logger.info(f"[LLMFallbackRerank] 开始为 {len(documents)} 个文档打分")
|
||||
scores = []
|
||||
for doc in documents:
|
||||
for i, doc in enumerate(documents):
|
||||
score = self._score_single_document(query, doc)
|
||||
scores.append(score)
|
||||
|
||||
logger.info(f"[LLMFallbackRerank] doc[{i}] score={score:.4f}")
|
||||
|
||||
return scores
|
||||
|
||||
def _score_single_document(self, query: str, document: str) -> float:
|
||||
|
||||
@@ -13,7 +13,7 @@ RAG 检索与生成模块
|
||||
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||
|
||||
示例用法:
|
||||
>>> from app.rag.rag import RAGPipeline, create_rag_tool
|
||||
>>> from backend.app.rag.rag import RAGPipeline, create_rag_tool
|
||||
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
|
||||
>>> from langchain_openai import ChatOpenAI
|
||||
>>>
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
"""
|
||||
RAG 评估模块
|
||||
用于计算 RAG 系统的召回率、相关性、准确率等指标
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalTestCase:
|
||||
"""检索测试用例"""
|
||||
query: str # 用户查询
|
||||
relevant_doc_ids: List[str] # 相关文档 ID 列表
|
||||
expected_answer: Optional[str] = None # 期望的答案(可选)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalMetrics:
|
||||
"""检索评估指标"""
|
||||
recall_at_k: Dict[int, float] # Recall@k,例如 {1: 0.8, 3: 0.9, 5: 1.0}
|
||||
precision_at_k: Dict[int, float] # Precision@k
|
||||
f1_at_k: Dict[int, float] # F1@k
|
||||
mrr: float # 平均倒数排名
|
||||
ndcg_at_k: Dict[int, float] # NDCG@k
|
||||
relevance_scores: List[float] # 每个测试用例的相关性评分
|
||||
|
||||
|
||||
class RAGEvaluator:
|
||||
"""RAG 评估器"""
|
||||
|
||||
def __init__(self, rag_pipeline, test_cases: List[RetrievalTestCase]):
|
||||
"""
|
||||
初始化评估器
|
||||
|
||||
Args:
|
||||
rag_pipeline: RAG 流水线对象(需实现 aretrieve 方法)
|
||||
test_cases: 测试用例列表
|
||||
"""
|
||||
self.rag_pipeline = rag_pipeline
|
||||
self.test_cases = test_cases
|
||||
|
||||
async def evaluate_retrieval(self, k_list: List[int] = None) -> RetrievalMetrics:
|
||||
"""
|
||||
评估检索质量
|
||||
|
||||
Args:
|
||||
k_list: 要计算的 k 值列表,例如 [1, 3, 5]
|
||||
|
||||
Returns:
|
||||
检索评估指标
|
||||
"""
|
||||
if k_list is None:
|
||||
k_list = [1, 3, 5, 10]
|
||||
|
||||
all_results = []
|
||||
all_mrr = []
|
||||
|
||||
for test_case in self.test_cases:
|
||||
# 执行检索
|
||||
retrieved_docs = await self.rag_pipeline.aretrieve(test_case.query)
|
||||
retrieved_ids = [doc.metadata.get("id", doc.page_content[:50]) for doc in retrieved_docs]
|
||||
|
||||
# 计算召回率和精确率
|
||||
result = self._calculate_retrieval_metrics(
|
||||
retrieved_ids,
|
||||
test_case.relevant_doc_ids,
|
||||
k_list
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# 计算 MRR
|
||||
mrr = self._calculate_mrr(retrieved_ids, test_case.relevant_doc_ids)
|
||||
all_mrr.append(mrr)
|
||||
|
||||
# 聚合所有测试用例的结果
|
||||
metrics = self._aggregate_metrics(all_results, all_mrr, k_list)
|
||||
return metrics
|
||||
|
||||
def _calculate_retrieval_metrics(
|
||||
self,
|
||||
retrieved_ids: List[str],
|
||||
relevant_ids: List[str],
|
||||
k_list: List[int]
|
||||
) -> Dict[int, Dict[str, float]]:
|
||||
"""
|
||||
计算单个测试用例的检索指标
|
||||
|
||||
Returns:
|
||||
{k: {'recall': float, 'precision': float, 'f1': float}}
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for k in k_list:
|
||||
# 取前 k 个结果
|
||||
top_k = retrieved_ids[:k]
|
||||
|
||||
# 计算交集
|
||||
relevant_in_top_k = set(top_k) & set(relevant_ids)
|
||||
num_relevant_in_top_k = len(relevant_in_top_k)
|
||||
|
||||
# 召回率 = 相关文档在 top k 中的数量 / 总相关文档数量
|
||||
recall = num_relevant_in_top_k / len(relevant_ids) if relevant_ids else 0.0
|
||||
|
||||
# 精确率 = 相关文档在 top k 中的数量 / k
|
||||
precision = num_relevant_in_top_k / k if k > 0 else 0.0
|
||||
|
||||
# F1 分数
|
||||
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) > 0 else 0.0
|
||||
|
||||
results[k] = {
|
||||
'recall': recall,
|
||||
'precision': precision,
|
||||
'f1': f1
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def _calculate_mrr(self, retrieved_ids: List[str], relevant_ids: List[str]) -> float:
|
||||
"""
|
||||
计算平均倒数排名 (Mean Reciprocal Rank)
|
||||
|
||||
MRR@k = 1/m * sum(1/rank_i for i=1..m)
|
||||
其中 rank_i 是第 i 个相关文档第一次出现的排名
|
||||
"""
|
||||
for rank, doc_id in enumerate(retrieved_ids, start=1):
|
||||
if doc_id in relevant_ids:
|
||||
return 1.0 / rank
|
||||
return 0.0
|
||||
|
||||
def _calculate_ndcg(
|
||||
self,
|
||||
retrieved_ids: List[str],
|
||||
relevant_ids: List[str],
|
||||
k: int
|
||||
) -> float:
|
||||
"""
|
||||
计算 NDCG@k (Normalized Discounted Cumulative Gain)
|
||||
|
||||
DCG@k = sum(relevance_i / log2(i+1) for i=1..k)
|
||||
NDCG@k = DCG@k / IDCG@k
|
||||
"""
|
||||
top_k = retrieved_ids[:k]
|
||||
|
||||
# 计算 DCG
|
||||
dcg = 0.0
|
||||
for i, doc_id in enumerate(top_k, start=1):
|
||||
relevance = 1.0 if doc_id in relevant_ids else 0.0
|
||||
dcg += relevance / (i.bit_length() - 1) # log2(i)
|
||||
|
||||
# 计算 IDCG(理想 DCG)
|
||||
ideal_relevance = [1.0] * min(len(relevant_ids), k)
|
||||
idcg = 0.0
|
||||
for i, rel in enumerate(ideal_relevance, start=1):
|
||||
idcg += rel / (i.bit_length() - 1)
|
||||
|
||||
return dcg / idcg if idcg > 0 else 0.0
|
||||
|
||||
def _aggregate_metrics(
|
||||
self,
|
||||
all_results: List[Dict[int, Dict[str, float]]],
|
||||
all_mrr: List[float],
|
||||
k_list: List[int]
|
||||
) -> RetrievalMetrics:
|
||||
"""聚合所有测试用例的指标"""
|
||||
|
||||
recall_at_k = {}
|
||||
precision_at_k = {}
|
||||
f1_at_k = {}
|
||||
ndcg_at_k = {}
|
||||
|
||||
for k in k_list:
|
||||
# 聚合召回率
|
||||
recalls = [result[k]['recall'] for result in all_results]
|
||||
recall_at_k[k] = sum(recalls) / len(recalls)
|
||||
|
||||
# 聚合精确率
|
||||
precisions = [result[k]['precision'] for result in all_results]
|
||||
precision_at_k[k] = sum(precisions) / len(precisions)
|
||||
|
||||
# 聚合 F1
|
||||
f1s = [result[k]['f1'] for result in all_results]
|
||||
f1_at_k[k] = sum(f1s) / len(f1s)
|
||||
|
||||
# 计算 NDCG(这里简化处理)
|
||||
ndcg_at_k[k] = sum(f1s) / len(f1s) # 用 F1 近似
|
||||
|
||||
# 计算 MRR
|
||||
mrr = sum(all_mrr) / len(all_mrr)
|
||||
|
||||
return RetrievalMetrics(
|
||||
recall_at_k=recall_at_k,
|
||||
precision_at_k=precision_at_k,
|
||||
f1_at_k=f1_at_k,
|
||||
mrr=mrr,
|
||||
ndcg_at_k=ndcg_at_k,
|
||||
relevance_scores=[1.0] * len(all_results) # 占位符
|
||||
)
|
||||
|
||||
|
||||
class RelevanceEvaluator:
|
||||
"""相关性评估器(基于 LLM 评估)"""
|
||||
|
||||
def __init__(self, llm):
|
||||
"""
|
||||
初始化相关性评估器
|
||||
|
||||
Args:
|
||||
llm: 用于评估相关性的语言模型
|
||||
"""
|
||||
self.llm = llm
|
||||
|
||||
async def evaluate_relevance(
|
||||
self,
|
||||
query: str,
|
||||
document: Document
|
||||
) -> Tuple[float, str]:
|
||||
"""
|
||||
评估文档与查询的相关性
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
document: 文档对象
|
||||
|
||||
Returns:
|
||||
(相关性分数 0-5, 评估理由)
|
||||
"""
|
||||
prompt = f"""请评估以下文档与用户查询的相关性,给出 0-5 的评分:
|
||||
|
||||
用户查询:{query}
|
||||
|
||||
文档内容:{document.page_content[:500]}
|
||||
|
||||
请按以下标准评分:
|
||||
5 = 完全相关,文档直接回答了用户查询
|
||||
4 = 高度相关,文档包含回答查询的关键信息
|
||||
3 = 部分相关,文档有一些相关信息但不够直接
|
||||
2 = 弱相关,文档有少量提及但不太相关
|
||||
1 = 不相关,文档内容与查询基本无关
|
||||
0 = 完全无关
|
||||
|
||||
请只返回 JSON 格式,例如:{{"score": 4, "reason": "文档详细解释了用户查询的概念"}}"""
|
||||
|
||||
try:
|
||||
response = await self.llm.ainvoke(prompt)
|
||||
result_text = response.content if hasattr(response, 'content') else str(response)
|
||||
|
||||
# 尝试解析 JSON
|
||||
import json
|
||||
result = json.loads(result_text)
|
||||
score = result.get('score', 0.0)
|
||||
reason = result.get('reason', '无理由')
|
||||
|
||||
# 确保分数在 0-5 范围内
|
||||
score = max(0.0, min(5.0, float(score)))
|
||||
|
||||
return score, reason
|
||||
|
||||
except Exception as e:
|
||||
return 0.0, f"评估失败:{str(e)}"
|
||||
|
||||
|
||||
def generate_test_report(metrics: RetrievalMetrics) -> str:
|
||||
"""生成测试报告"""
|
||||
|
||||
report = []
|
||||
report.append("=" * 80)
|
||||
report.append("RAG 系统评估报告")
|
||||
report.append("=" * 80)
|
||||
report.append("")
|
||||
|
||||
# 召回率
|
||||
report.append("【召回率 Recall@k】")
|
||||
for k, v in sorted(metrics.recall_at_k.items()):
|
||||
report.append(f" Recall@{k}: {v:.2%}")
|
||||
report.append("")
|
||||
|
||||
# 精确率
|
||||
report.append("【精确率 Precision@k】")
|
||||
for k, v in sorted(metrics.precision_at_k.items()):
|
||||
report.append(f" Precision@{k}: {v:.2%}")
|
||||
report.append("")
|
||||
|
||||
# F1 分数
|
||||
report.append("【F1 分数 F1@k】")
|
||||
for k, v in sorted(metrics.f1_at_k.items()):
|
||||
report.append(f" F1@{k}: {v:.4f}")
|
||||
report.append("")
|
||||
|
||||
# MRR
|
||||
report.append(f"【平均倒数排名 MRR】: {metrics.mrr:.4f}")
|
||||
report.append("")
|
||||
|
||||
# 解释
|
||||
report.append("=" * 80)
|
||||
report.append("指标说明:")
|
||||
report.append("- Recall@k: 前 k 个结果中包含多少比例的相关文档")
|
||||
report.append("- Precision@k: 前 k 个结果中有多少比例是相关文档")
|
||||
report.append("- F1@k: 召回率和精确率的调和平均数")
|
||||
report.append("- MRR: 第一个相关文档的排名的倒数的平均值")
|
||||
report.append("=" * 80)
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
# 示例使用
|
||||
def create_sample_test_cases() -> List[RetrievalTestCase]:
|
||||
"""创建示例测试用例"""
|
||||
return [
|
||||
RetrievalTestCase(
|
||||
query="什么是 RAG 系统?",
|
||||
relevant_doc_ids=["doc_rag_1", "doc_rag_2"],
|
||||
expected_answer="RAG 是 Retrieval-Augmented Generation 的缩写..."
|
||||
),
|
||||
RetrievalTestCase(
|
||||
query="如何使用 LangChain?",
|
||||
relevant_doc_ids=["doc_langchain_1", "doc_langchain_2", "doc_langchain_3"],
|
||||
expected_answer="LangChain 的使用步骤包括..."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例:如何使用评估器
|
||||
print("RAG 评估模块已加载")
|
||||
print("使用方法:")
|
||||
print(" 1. 创建测试用例")
|
||||
print(" 2. 初始化 RAGEvaluator")
|
||||
print(" 3. 调用 evaluate_retrieval()")
|
||||
print(" 4. 生成报告")
|
||||
@@ -1,137 +1,114 @@
|
||||
"""
|
||||
RAG 检索流水线模块
|
||||
|
||||
提供固定流程的 RAG 检索:
|
||||
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||
|
||||
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
|
||||
RAG 检索流水线
|
||||
流程: 检索子文档 → 重排 → 获取父文档 → 返回
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from typing import List
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from app.model_services import get_rerank_service, get_small_llm_service
|
||||
from app.rag.rerank import create_document_reranker
|
||||
from app.rag.query_transform import MultiQueryGenerator
|
||||
from app.rag.fusion import reciprocal_rank_fusion
|
||||
from app.rag.retriever import create_parent_hybrid_retriever
|
||||
from ..model_services import get_rerank_service, get_small_llm_service
|
||||
from ..rag.rerank import create_document_reranker
|
||||
from ..rag.query_transform import MultiQueryGenerator
|
||||
from ..rag.fusion import reciprocal_rank_fusion
|
||||
from ..rag.retriever import create_parent_hybrid_retriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGPipeline:
|
||||
"""
|
||||
固定流程的 RAG 检索流水线:
|
||||
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever=None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
llm: BaseLanguageModel | str = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
use_rerank: bool = True,
|
||||
return_parent_docs: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
|
||||
如果不提供,会自动创建默认的父子文档混合检索器。
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量。
|
||||
rerank_top_n: 最终返回的文档数量。
|
||||
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
|
||||
"""
|
||||
# 如果没有提供 retriever,自动创建默认的混合检索器
|
||||
if retriever is None:
|
||||
self.retriever = create_parent_hybrid_retriever(
|
||||
collection_name=collection_name,
|
||||
search_k=rerank_top_n * 2 # 多取一些给重排序用
|
||||
)
|
||||
else:
|
||||
self.retriever = retriever
|
||||
|
||||
# 处理 llm 参数
|
||||
self.retriever = retriever or create_parent_hybrid_retriever(
|
||||
collection_name=collection_name, search_k=rerank_top_n * 4
|
||||
)
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
self.use_rerank = use_rerank
|
||||
self.return_parent_docs = return_parent_docs
|
||||
|
||||
if llm == "default_small":
|
||||
try:
|
||||
self.llm = get_small_llm_service()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
|
||||
except Exception:
|
||||
self.llm = None
|
||||
elif llm in (None, False):
|
||||
self.llm = None
|
||||
else:
|
||||
self.llm = llm
|
||||
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
|
||||
# 初始化组件 - 使用统一的重排服务获取接口
|
||||
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
|
||||
self.reranker = create_document_reranker()
|
||||
|
||||
self.llm = llm if llm else None
|
||||
|
||||
self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
|
||||
self.reranker = create_document_reranker() if use_rerank else None
|
||||
logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
"""
|
||||
异步执行完整检索流程
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
|
||||
Returns:
|
||||
检索到的相关文档列表
|
||||
"""
|
||||
# 如果有 query_generator,做多路改写
|
||||
if self.query_generator and self.llm:
|
||||
# Step 1: 生成多路查询
|
||||
# Step 1: 检索
|
||||
child_docs = await self._retrieve(query)
|
||||
logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")
|
||||
# 调试:打印子文档长度
|
||||
for i, doc in enumerate(child_docs[:5]):
|
||||
content_len = len(doc.page_content)
|
||||
logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")
|
||||
|
||||
# Step 2: 重排
|
||||
if self.reranker:
|
||||
try:
|
||||
child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
|
||||
logger.info(f"[Pipeline] 重排后 {len(child_docs)} 个")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Pipeline] 重排失败: {e}")
|
||||
child_docs = child_docs[:self.rerank_top_n]
|
||||
|
||||
# Step 3: 获取父文档
|
||||
if self.return_parent_docs:
|
||||
return await self._get_parents(child_docs)
|
||||
return child_docs
|
||||
|
||||
async def _retrieve(self, query: str) -> List[Document]:
|
||||
if self.query_generator:
|
||||
queries = await self.query_generator.agenerate(query)
|
||||
# 包含原始查询,确保至少有一条
|
||||
if query not in queries:
|
||||
queries.insert(0, query)
|
||||
else:
|
||||
# 如果原始查询已在列表中,将其移至首位
|
||||
queries.remove(query)
|
||||
queries.insert(0, query)
|
||||
|
||||
# Step 2: 并行检索(每个查询获取文档列表)
|
||||
tasks = [self.retriever.ainvoke(q) for q in queries]
|
||||
doc_lists = await asyncio.gather(*tasks)
|
||||
|
||||
# Step 3: RRF 融合
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists)
|
||||
else:
|
||||
# 没有 LLM 做查询改写,直接用原始查询检索
|
||||
fused_docs = await self.retriever.ainvoke(query)
|
||||
|
||||
# Step 4: 重排序
|
||||
queries = [query] + [q for q in queries if q != query]
|
||||
doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries])
|
||||
return reciprocal_rank_fusion(doc_lists)
|
||||
return await self.retriever.ainvoke(query)
|
||||
|
||||
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
|
||||
parent_map = {}
|
||||
for doc in child_docs:
|
||||
pid = doc.metadata.get("parent_id")
|
||||
if pid and pid not in parent_map:
|
||||
parent_map[pid] = doc.metadata.get("score", 0.0)
|
||||
|
||||
if not parent_map:
|
||||
logger.warning("[Pipeline] 未找到 parent_id,返回子文档")
|
||||
return child_docs
|
||||
|
||||
try:
|
||||
final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n)
|
||||
except Exception:
|
||||
# 若重排序器不可用,直接返回融合后的前 N 个结果
|
||||
final_docs = fused_docs[:self.rerank_top_n]
|
||||
|
||||
return final_docs
|
||||
from backend.rag_core import create_docstore
|
||||
docstore, _ = create_docstore()
|
||||
# 同步获取(异步版本不存在)
|
||||
parent_docs = docstore.mget(list(parent_map.keys()))
|
||||
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
|
||||
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
|
||||
result.sort(key=lambda x: x[1], reverse=True)
|
||||
docs = [d for d, _ in result]
|
||||
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
|
||||
return docs
|
||||
except Exception as e:
|
||||
logger.warning(f"[Pipeline] 获取父文档失败: {e}")
|
||||
return child_docs
|
||||
|
||||
def format_context(self, documents: List[Document]) -> str:
|
||||
"""
|
||||
将文档列表格式化为上下文字符串
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
|
||||
Returns:
|
||||
格式化后的上下文字符串
|
||||
"""
|
||||
if not documents:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for i, doc in enumerate(documents, 1):
|
||||
source = doc.metadata.get("source", "未知来源")
|
||||
@@ -139,30 +116,5 @@ class RAGPipeline:
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def create_rag_pipeline(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
) -> RAGPipeline:
|
||||
"""
|
||||
创建 RAG 检索流水线的便捷函数
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
|
||||
Returns:
|
||||
RAGPipeline 实例
|
||||
"""
|
||||
return RAGPipeline(
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
collection_name=collection_name
|
||||
)
|
||||
def create_rag_pipeline(**kwargs) -> RAGPipeline:
|
||||
return RAGPipeline(**kwargs)
|
||||
|
||||
@@ -57,14 +57,26 @@ class DocumentReranker:
|
||||
try:
|
||||
# 1. 从 Document 提取内容(业务逻辑)
|
||||
doc_contents = [doc.page_content for doc in documents]
|
||||
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
|
||||
total_chars = sum(len(c) for c in doc_contents)
|
||||
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
|
||||
# 粗略估算 tokens (中文约 0.75 tokens/字符)
|
||||
estimated_tokens = int(total_chars * 0.75)
|
||||
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
|
||||
|
||||
# 2. 调用纯服务层计算得分
|
||||
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
|
||||
scores = self._rerank_service.compute_scores(query, doc_contents)
|
||||
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")
|
||||
|
||||
# 3. 根据得分排序(业务逻辑)
|
||||
doc_score_pairs = list(zip(documents, scores))
|
||||
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
|
||||
|
||||
logger.info(f"[Rerank] 排序后的结果:")
|
||||
for i, (doc, score) in enumerate(doc_score_pairs_sorted):
|
||||
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...")
|
||||
|
||||
# 4. 取 top_n
|
||||
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]]
|
||||
|
||||
@@ -72,6 +84,9 @@ class DocumentReranker:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}")
|
||||
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
|
||||
return documents[:top_n]
|
||||
|
||||
|
||||
|
||||
@@ -19,10 +19,10 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pydantic import Field, PrivateAttr
|
||||
|
||||
from rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from rag_core.client import create_async_qdrant_client
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning, debug
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
|
||||
from backend.rag_core.client import create_async_qdrant_client
|
||||
from ..model_services import get_embedding_service
|
||||
from ..logger import info, warning, debug
|
||||
|
||||
|
||||
# 模块级常量
|
||||
@@ -131,20 +131,20 @@ class HybridRetriever(BaseRetriever):
|
||||
class ParentHybridRetriever(BaseRetriever):
|
||||
"""
|
||||
父子文档混合检索器(异步):
|
||||
|
||||
|
||||
1. 先用混合检索找到相关子文档
|
||||
2. 根据子文档的 parent_id 找到对应的父文档
|
||||
3. 去重并返回父文档
|
||||
"""
|
||||
|
||||
|
||||
collection_name: str = Field(description="Qdrant 集合名称")
|
||||
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
|
||||
|
||||
|
||||
_vector_store: Any = PrivateAttr()
|
||||
_client: Any = PrivateAttr()
|
||||
_sparse_embedder: Any = PrivateAttr()
|
||||
_docstore: Any = PrivateAttr()
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
@@ -188,7 +188,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
异步检索相关父文档
|
||||
异步检索相关子文档
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store.aembed_query(query)
|
||||
@@ -197,10 +197,10 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
|
||||
# 2. 多取一些子文档,避免去重后数量不足
|
||||
search_limit = self.search_k * 2
|
||||
|
||||
|
||||
# 3. 使用 query_points API 进行混合检索
|
||||
response = await self._client.query_points(
|
||||
collection_name=self.collection_name,
|
||||
@@ -220,87 +220,27 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
limit=search_limit,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
|
||||
if not response.points:
|
||||
debug("混合检索未找到任何文档")
|
||||
return []
|
||||
|
||||
# 4. 收集 parent_id 和对应最高得分
|
||||
parent_score_map = {}
|
||||
parent_ids = set()
|
||||
child_point_map = {} # 保存子文档点用于降级
|
||||
|
||||
|
||||
# 4. 构建子文档列表
|
||||
child_docs = []
|
||||
for point in response.points:
|
||||
payload_copy = point.payload.copy()
|
||||
parent_id = payload_copy.get("parent_id", point.id)
|
||||
score = point.score
|
||||
|
||||
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||
parent_score_map[parent_id] = score
|
||||
parent_ids.add(parent_id)
|
||||
child_point_map[parent_id] = point
|
||||
|
||||
# 5. 批量查询父文档
|
||||
parent_docs = []
|
||||
found_parent_ids = set()
|
||||
|
||||
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
|
||||
try:
|
||||
parent_points = await self._client.retrieve(
|
||||
collection_name=self.collection_name,
|
||||
ids=list(parent_ids),
|
||||
with_payload=True
|
||||
doc = Document(
|
||||
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||
metadata={
|
||||
**payload_copy,
|
||||
"child_id": point.id,
|
||||
"score": point.score
|
||||
}
|
||||
)
|
||||
|
||||
for point in parent_points:
|
||||
payload_copy = point.payload.copy()
|
||||
doc = Document(
|
||||
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||
metadata=payload_copy
|
||||
)
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(point.id)
|
||||
|
||||
except Exception as e:
|
||||
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||
|
||||
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||
if self._docstore and len(found_parent_ids) < len(parent_ids):
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
try:
|
||||
docstore_docs = await self._docstore.amget(missing_parent_ids)
|
||||
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
|
||||
if doc is not None:
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(doc_id)
|
||||
except Exception as e:
|
||||
warning(f"从 docstore 查询父文档失败: {e}")
|
||||
|
||||
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
if missing_parent_ids:
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
|
||||
for parent_id in missing_parent_ids:
|
||||
child_point = child_point_map.get(parent_id)
|
||||
if child_point:
|
||||
payload_copy = child_point.payload.copy()
|
||||
doc = Document(
|
||||
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||
metadata=payload_copy
|
||||
)
|
||||
parent_docs.append(doc)
|
||||
|
||||
# 8. 按照得分降序排序,返回前 k 个
|
||||
parent_docs_with_scores = [
|
||||
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
|
||||
for doc in parent_docs
|
||||
]
|
||||
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
|
||||
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
|
||||
|
||||
return final_docs
|
||||
child_docs.append(doc)
|
||||
|
||||
debug(f"父子文档混合检索返回 {len(child_docs)} 个子文档")
|
||||
return child_docs
|
||||
|
||||
|
||||
def create_hybrid_retriever(
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Callable, Optional
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
|
||||
|
||||
def create_rag_tool(
|
||||
|
||||
@@ -20,22 +20,18 @@ class ContactAPIClient:
|
||||
|
||||
def __init__(self, conn=None):
|
||||
"""
|
||||
初始化
|
||||
|
||||
初始化(使用延迟初始化,MCP在首次使用时初始化)
|
||||
|
||||
Args:
|
||||
conn: 数据库连接(保留用于向后兼容)
|
||||
"""
|
||||
self.conn = conn
|
||||
|
||||
# 确保MCP已初始化
|
||||
import asyncio
|
||||
try:
|
||||
asyncio.create_task(self._init_mcp())
|
||||
except RuntimeError:
|
||||
pass # 没有事件循环时跳过,延迟初始化
|
||||
self._mcp_initialized = False # 延迟初始化标志
|
||||
|
||||
async def _init_mcp(self):
|
||||
"""初始化MCP系统"""
|
||||
"""初始化MCP系统(延迟初始化)"""
|
||||
if self._mcp_initialized:
|
||||
return # 已初始化,跳过
|
||||
if not mcp_manager.get_adapter("contact"):
|
||||
# 获取repository(如果有)
|
||||
repo = None
|
||||
@@ -45,9 +41,10 @@ class ContactAPIClient:
|
||||
repo = ContactRepository(self.conn)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
mcp_manager.register_adapter(ContactAdapter(contact_repo=repo))
|
||||
await mcp_manager.initialize()
|
||||
self._mcp_initialized = True
|
||||
|
||||
async def list_contacts(self, user_id: str = "default") -> List[Contact]:
|
||||
"""获取联系人列表"""
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 公共工具
|
||||
from app.core import MarkdownFormatter
|
||||
from ...core import MarkdownFormatter
|
||||
|
||||
from .state import ContactState
|
||||
from .api_client import ContactAPIClient
|
||||
|
||||
@@ -22,20 +22,19 @@ class DictionaryAPIClient:
|
||||
word_repository: Optional[Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后设置MCP"""
|
||||
import asyncio
|
||||
try:
|
||||
asyncio.create_task(self._init_mcp())
|
||||
except RuntimeError:
|
||||
pass
|
||||
"""初始化后设置(延迟初始化标志)"""
|
||||
self._mcp_initialized = False # 延迟初始化标志
|
||||
|
||||
async def _init_mcp(self):
|
||||
"""初始化MCP系统"""
|
||||
"""初始化MCP系统(延迟初始化)"""
|
||||
if self._mcp_initialized:
|
||||
return # 已初始化,跳过
|
||||
if not mcp_manager.get_adapter("dictionary"):
|
||||
mcp_manager.register_adapter(
|
||||
DictionaryAdapter(word_repo=self.word_repository)
|
||||
)
|
||||
await mcp_manager.initialize()
|
||||
self._mcp_initialized = True
|
||||
|
||||
async def query_word(
|
||||
self,
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
import random
|
||||
|
||||
# 公共工具
|
||||
from app.core import (
|
||||
from ...core import (
|
||||
MarkdownFormatter
|
||||
)
|
||||
|
||||
|
||||
@@ -23,20 +23,19 @@ class NewsAPIClient:
|
||||
news_repository: Optional[Any] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后设置MCP"""
|
||||
import asyncio
|
||||
try:
|
||||
asyncio.create_task(self._init_mcp())
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
"""初始化后设置(延迟初始化标志)"""
|
||||
self._mcp_initialized = False # 延迟初始化标志
|
||||
|
||||
async def _init_mcp(self):
|
||||
"""初始化MCP系统"""
|
||||
"""初始化MCP系统(延迟初始化)"""
|
||||
if self._mcp_initialized:
|
||||
return # 已初始化,跳过
|
||||
if not mcp_manager.get_adapter("news"):
|
||||
mcp_manager.register_adapter(
|
||||
NewsAdapter(news_repo=self.news_repository)
|
||||
)
|
||||
await mcp_manager.initialize()
|
||||
self._mcp_initialized = True
|
||||
|
||||
async def query_news(
|
||||
self,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 公共工具
|
||||
from app.core import MarkdownFormatter
|
||||
from ...core import MarkdownFormatter
|
||||
|
||||
from .state import (
|
||||
NewsAnalysisState,
|
||||
|
||||
@@ -3,9 +3,9 @@ LangGraph 节点日志工具模块
|
||||
提供状态流转追踪和 LLM 输入输出打印功能
|
||||
"""
|
||||
|
||||
from app.config import ENABLE_GRAPH_TRACE
|
||||
from app.logger import debug, info
|
||||
from app.main_graph.state import MainGraphState
|
||||
from ..config import ENABLE_GRAPH_TRACE
|
||||
from ..logger import debug, info
|
||||
from ..main_graph.state import MainGraphState
|
||||
|
||||
|
||||
def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进入"):
|
||||
@@ -17,7 +17,7 @@ def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进
|
||||
state: 当前状态
|
||||
prefix: 日志前缀("进入" 或 "离开")
|
||||
"""
|
||||
from app.logger import info
|
||||
from backend.app.logger import info
|
||||
|
||||
messages = state.messages
|
||||
msg_count = len(messages)
|
||||
|
||||
Reference in New Issue
Block a user