导入方式修改
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m44s

This commit is contained in:
2026-05-05 23:17:00 +08:00
parent b5c15ef445
commit 3ae9daa01a
51 changed files with 445 additions and 532 deletions

View File

@@ -1154,7 +1154,7 @@ RAG 系统分为两个独立但协同的阶段:
❌ 只能捕捉"语义相似",专有名词匹配差
实现代码:
from app.rag.retriever import create_base_retriever
from backend.app.rag.retriever import create_base_retriever
retriever = create_base_retriever(
collection_name="rag_documents",
@@ -1175,7 +1175,7 @@ RAG 系统分为两个独立但协同的阶段:
两路结果并行获取,等待融合
实现代码:
from app.rag.retriever import create_hybrid_retriever
from backend.app.rag.retriever import create_hybrid_retriever
retriever = create_hybrid_retriever(
collection_name="rag_documents",
@@ -1195,7 +1195,7 @@ RAG 系统分为两个独立但协同的阶段:
由模型直接输出 0~1 的相关性得分,精度极高
实现代码:
from app.rag.reranker import LLaMaCPPReranker
from backend.app.rag.reranker import LLaMaCPPReranker
reranker = LLaMaCPPReranker(
base_url="http://127.0.0.1:8083",
@@ -1215,7 +1215,7 @@ RAG 系统分为两个独立但协同的阶段:
通过 LLM 将单一问题改写为多个不同角度的查询
实现代码:
from app.rag.query_transform import MultiQueryGenerator
from backend.app.rag.query_transform import MultiQueryGenerator
generator = MultiQueryGenerator(llm=llm, num_queries=3)
queries = await generator.agenerate("如何申请项目资金?")
@@ -1231,7 +1231,7 @@ RAG 系统分为两个独立但协同的阶段:
有效避免某一极端检索结果主导全局
实现代码:
from app.rag.fusion import reciprocal_rank_fusion
from backend.app.rag.fusion import reciprocal_rank_fusion
# 多个查询的检索结果
doc_lists = [result1, result2, result3]
@@ -1260,8 +1260,8 @@ RAG 系统分为两个独立但协同的阶段:
└────────── └────────────── └──────────┘ └────────┘
实现代码:
from app.rag.tools import search_knowledge_base
from app.main_graph.utils.main_graph_builder import MainGraphBuilder
from backend.app.rag.tools import search_knowledge_base
from backend.app.main_graph.utils.main_graph_builder import MainGraphBuilder
# 构建图
builder = MainGraphBuilder()

View File

@@ -3,6 +3,6 @@ AI Agent 应用模块
"""
from .agent.agent_service import AIAgentService
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from .main_graph.tools.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]

View File

@@ -7,6 +7,9 @@ import json
import asyncio
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
# 本地模块
from ..model_services import get_cached_chat_services
from ..main_graph.main_graph_builder import build_react_main_graph
@@ -18,6 +21,23 @@ from ..logger import debug, info, warning, error
from ..main_graph.state import MainGraphState, CurrentAction
# ========== 自定义类型序列化器 ==========
def create_serde() -> JsonPlusSerializer:
"""创建带自定义类型注册的序列化器"""
from backend.app.core.intent import ReasoningAction, RetrievalConfig, ReasoningResult
return JsonPlusSerializer(
allowed_msgpack_modules=[
("app.core.intent", "ReasoningAction"),
("app.core.intent", "RetrievalConfig"),
("app.core.intent", "ReasoningResult"),
("app.main_graph.state", "CurrentAction"),
("app.main_graph.state", "ErrorSeverity"),
("app.main_graph.state", "ErrorRecord"),
]
)
class AIAgentService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer
@@ -55,6 +75,7 @@ class AIAgentService:
tools=self.tools,
mem0_client=self.mem0_client
)
# 注意serde 已在创建 checkpointer 时传入,这里只需传入 checkpointer
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
info(f"✅ 单图初始化完成")

View File

@@ -4,7 +4,7 @@
"""
from typing import List, Dict, Any
from app.logger import error # 保持兼容,或者替换为 logger
from ..logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
"""线程历史查询服务"""

View File

@@ -3,8 +3,13 @@ FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
采用依赖注入模式,优雅管理资源生命周期
"""
import warnings
# 抑制 WebSocket 弃用警告websockets 库升级导致uvicorn 尚未跟进)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets")
import os
from app.config import DB_URI, BACKEND_PORT
from ..config import DB_URI, BACKEND_PORT
import uuid
import json
from contextlib import asynccontextmanager
@@ -15,26 +20,26 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from .agent.agent_service import AIAgentService
from .agent.agent_service import AIAgentService, create_serde
from .agent.history import ThreadHistoryService
from app.core.human_review import (
from ..core.human_review import (
ReviewManager,
InMemoryReviewStore,
ReviewStatus,
HumanReview
)
from app.subgraphs.contact.api_client import ContactAPIClient
from app.subgraphs.dictionary.api_client import DictionaryAPIClient
from app.subgraphs.news_analysis.api_client import NewsAPIClient
from ..subgraphs.contact.api_client import ContactAPIClient
from ..subgraphs.dictionary.api_client import DictionaryAPIClient
from ..subgraphs.news_analysis.api_client import NewsAPIClient
from .db.init_db import init_subgraph_tables
from .db.models import ContactRepository, DictionaryRepository, NewsRepository
from app.logger import info, error
from ..logger import info, error
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
# 1. 创建数据库连接池并初始化表(仅 checkpointer
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
async with AsyncPostgresSaver.from_conn_string(DB_URI, serde=create_serde()) as checkpointer:
await checkpointer.setup()
# 1.5 初始化子图表

View File

@@ -30,7 +30,7 @@ from .visualization import (
# 为了兼容性,添加 classify_intent 函数
def classify_intent(user_input: str, context: str = None):
"""兼容旧代码的 classify_intent 函数"""
from app.core.intent_classifier import get_intent_classifier
from backend.app.core.intent_classifier import get_intent_classifier
import asyncio
classifier = get_intent_classifier()
try:

View File

@@ -94,7 +94,7 @@ class ReactIntentReasoner:
def _get_llm_service(self):
"""懒加载 LLM 服务(避免循环导入)"""
if self._llm_service is None:
from app.model_services.chat_services import get_chat_service, get_small_llm_service
from backend.app.model_services.chat_services import get_chat_service, get_small_llm_service
if self._use_small_llm:
self._llm_service = get_small_llm_service()
else:

View File

@@ -89,7 +89,7 @@ class WebSearchTool:
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
"""使用 Tavily API 搜索"""
from tavily import TavilyClient
from app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
from backend.app.config import TAVILY_API_KEY, TAVILY_MAX_RESULTS
if not TAVILY_API_KEY:
raise ValueError("TAVILY_API_KEY 未配置")

View File

@@ -4,7 +4,7 @@
"""
import os
from app.config import LOG_LEVEL, DEBUG
from .config import LOG_LEVEL, DEBUG
import logging
from typing import Any
# 根据环境变量控制是否显示详细调试信息

View File

@@ -4,12 +4,13 @@
"""
from typing import Dict, Any, Optional
from langchain_core.runnables.config import RunnableConfig
async def dispatch_custom_event(
event_name: str,
data: Dict[str, Any],
config: Optional[Dict[str, Any]] = None,
config: Optional[RunnableConfig] = None,
) -> None:
"""
安全地发送自定义事件,忽略发送失败

View File

@@ -2,8 +2,8 @@
错误处理节点 - 处理子图/工具调用错误
"""
from app.main_graph.state import MainGraphState, ErrorSeverity
from app.logger import info
from ...main_graph.state import MainGraphState, ErrorSeverity
from ...logger import info
def error_handling_node(state: MainGraphState) -> MainGraphState:

View File

@@ -4,6 +4,7 @@
"""
from typing import Optional
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
@@ -28,7 +29,7 @@ CHITCHAT_KEYWORDS = {
# ========== 闲聊节点 ==========
async def fast_chitchat_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
async def fast_chitchat_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""快速闲聊节点"""
state.current_phase = "fast_chitchat"
query = state.user_query or ""
@@ -69,14 +70,14 @@ def _match_chitchat_template(query: str) -> str:
# ========== 快速 RAG 节点 ==========
async def fast_rag_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
async def fast_rag_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""快速 RAG 节点:只负责 RAG 检索,然后交给 llm_call 生成回答"""
state.current_phase = "fast_rag"
query = state.user_query or ""
info(f"[Fast RAG] 开始处理: {query[:50]}")
# 获取 RAG 工具
from app.main_graph.utils.rag_initializer import get_rag_tool
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
rag_tool = get_rag_tool()
info(f"[Fast RAG] 获取到 rag_tool: {rag_tool is not None}")
@@ -134,7 +135,7 @@ async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphS
请给出简洁、准确的回答:"""
# 使用流式输出
from app.main_graph.config import get_stream_writer
from backend.app.main_graph.config import get_stream_writer
writer = get_stream_writer()
full_content = ""
@@ -164,7 +165,7 @@ async def _generate_fast_answer(state: MainGraphState, query: str) -> MainGraphS
# ========== 快速工具节点 ==========
async def fast_tool_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
async def fast_tool_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""快速工具节点"""
state.current_phase = "fast_tool"

View File

@@ -6,9 +6,9 @@
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MainGraphState
from app.utils.logging import log_state_change
from app.logger import info, warning
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from ...logger import info, warning
from langchain_core.runnables.config import RunnableConfig
@@ -35,7 +35,7 @@ async def finalize_node(state: MainGraphState, config: RunnableConfig) -> Dict[s
try:
# 获取流式写入器并发送完成事件
from app.main_graph.config import get_stream_writer
from backend.app.main_graph.config import get_stream_writer
writer = get_stream_writer()
# 只在 writer 存在且不是 noop 时才发送

View File

@@ -8,6 +8,7 @@ import json
from typing import Optional
from dataclasses import dataclass, field
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from ..state import MainGraphState
from ...logger import info, debug
@@ -157,7 +158,7 @@ def _default_result() -> HybridRouterResult:
# ========== 主路由节点 ==========
async def hybrid_router_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
async def hybrid_router_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""混合路由节点:前置路由,决定走快速路径还是 React 循环"""
state.current_phase = "hybrid_router"
query = state.user_query or ""

View File

@@ -9,10 +9,10 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
# 本地模块
from app.main_graph.state import MainGraphState
from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
from ...main_graph.state import MainGraphState
from ...agent.prompts import create_system_prompt
from ...utils.logging import log_state_change
from ...logger import debug, info, error
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
@@ -82,6 +82,12 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
try:
# 添加上下文到消息
messages_with_context = list(state.messages)
info(f"[llm_call] 原始消息数量: {len(messages_with_context)}")
for i, msg in enumerate(messages_with_context):
msg_type = getattr(msg, 'type', 'unknown')
msg_content = getattr(msg, 'content', '')[:100] if hasattr(msg, 'content') else str(msg)[:100]
info(f"[llm_call] msg[{i}] type={msg_type}, content={repr(msg_content)}")
if state.rag_context:
from langchain_core.messages import SystemMessage
rag_system_msg = SystemMessage(content=f"以下是检索到的相关信息:\n{state.rag_context}")
@@ -93,11 +99,13 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
break
if not inserted:
messages_with_context.insert(0, rag_system_msg)
info(f"[llm_call] RAG上下文已添加长度: {len(state.rag_context)}")
# 恢复为:手动进行 astream并将所有的 chunk 拼接成最终的 response 返回。
# LangGraph 会自动监听这期间产生的所有 token。
chain = prompt | llm_with_tools
chunks = []
info(f"[llm_call] 开始调用 LLM astream...")
async for chunk in chain.astream(
{
"messages": messages_with_context,
@@ -106,7 +114,26 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
config=config
):
chunks.append(chunk)
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks")
for i, chunk in enumerate(chunks[:10]): # 只打印前10个避免日志过多
chunk_type = type(chunk).__name__
chunk_content = getattr(chunk, 'content', '') if hasattr(chunk, 'content') else str(chunk)
# 打印更多属性
additional_kwargs = getattr(chunk, 'additional_kwargs', {}) or {}
response_metadata = getattr(chunk, 'response_metadata', {}) or {}
# 打印所有属性
info(f"[llm_call] chunk[{i}] type={chunk_type}")
info(f"[llm_call] chunk[{i}] content长度={len(chunk_content) if chunk_content else 0}, content={repr(chunk_content[:200] if chunk_content else '')}")
info(f"[llm_call] chunk[{i}] additional_kwargs={additional_kwargs}")
info(f"[llm_call] chunk[{i}] response_metadata keys={list(response_metadata.keys()) if response_metadata else []}")
info(f"[llm_call] chunk[{i}] response_metadata={response_metadata}")
# 检查是否有其他可能存储内容的属性
if hasattr(chunk, 'tool_call_chunks'):
info(f"[llm_call] chunk[{i}] tool_call_chunks={chunk.tool_call_chunks}")
if hasattr(chunk, 'usage_metadata'):
info(f"[llm_call] chunk[{i}] usage_metadata={chunk.usage_metadata}")
# 将所有 chunk 合并成最终的 AIMessage
if chunks:
response = chunks[0]
@@ -114,6 +141,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
response = response + chunk
else:
response = AIMessage(content="")
info(f"[llm_call] ⚠️ 警告: 没有收到任何 chunks")
elapsed_time = time.time() - start_time

View File

@@ -1,8 +1,8 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MainGraphState
from app.memory.mem0_client import Mem0Client
from app.logger import info
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...logger import info
# 全局变量,在 GraphBuilder 中注入

View File

@@ -7,16 +7,17 @@ import time
import asyncio
from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from app.logger import info
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...main_graph.utils.retry_utils import RAG_RETRY_CONFIG
from ...logger import info
from ._utils import dispatch_custom_event, make_react_event
def _get_rag_tool() -> Optional[callable]:
"""获取 RAG 工具"""
from app.main_graph.utils.rag_initializer import get_rag_tool
from backend.app.main_graph.utils.rag_initializer import get_rag_tool
return get_rag_tool()
@@ -35,6 +36,9 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
# 调用 RAG 工具
rag_context = await rag_tool.ainvoke(retrieval_query)
info(f"[RAG Core] 获取到 rag_context: {type(rag_context)}, 长度={len(rag_context) if rag_context else 0}")
info(f"[RAG Core] ========== RAG 返回的知识内容 ==========")
info(f"{rag_context[:500]}..." if len(rag_context) > 500 else rag_context)
info(f"[RAG Core] ========================================")
state.rag_context = rag_context
state.rag_docs = [{"source": "rag_retrieval", "content": rag_context}]
@@ -47,7 +51,7 @@ async def _rag_retrieve_core(state: MainGraphState, rag_tool: callable) -> MainG
# ========== RAG 检索节点 ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
async def rag_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""RAG 检索节点:带超时和重试"""
state.current_phase = "rag_retrieving"
start_time = time.time()
@@ -156,7 +160,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[dict] = None
return state
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[dict] = None) -> MainGraphState:
async def rag_re_retrieve_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""重新检索节点"""
state.current_phase = "rag_re_retrieving"

View File

@@ -3,16 +3,17 @@ React 推理节点
使用 intent.py 进行意图推理
"""
from typing import Dict, Any, Optional
from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from app.core.intent import react_reason_async, ReasoningResult
from app.main_graph.state import MainGraphState
from app.logger import info
from ...core.intent import react_reason_async, ReasoningResult
from ...main_graph.state import MainGraphState
from ...logger import info
from ._utils import dispatch_custom_event, make_react_event
async def react_reason_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
async def react_reason_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""React 模式推理节点:判断下一步做什么"""
state.current_phase = "react_reasoning"
state.reasoning_step += 1

View File

@@ -6,10 +6,10 @@
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MainGraphState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change
from ...logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
@@ -67,10 +67,10 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
else:
debug("🔍 [记忆检索] 未找到相关记忆")
except Exception as e:
from app.logger import warning
from backend.app.logger import warning
warning(f"⚠️ Mem0 检索失败: {e}")
else:
from app.logger import warning
from backend.app.logger import warning
warning("⚠️ Mem0 未初始化,跳过记忆检索")
memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"

View File

@@ -10,9 +10,9 @@
from datetime import datetime
from app.core.intent import get_route_by_reasoning, ReasoningAction
from app.main_graph.state import MainGraphState
from app.logger import info
from ...core.intent import get_route_by_reasoning, ReasoningAction
from ...main_graph.state import MainGraphState
from ...logger import info
# ========== 初始化状态节点 ==========

View File

@@ -6,10 +6,10 @@
from typing import Any, Dict
# 本地模块
from app.main_graph.state import MainGraphState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
from ...main_graph.state import MainGraphState
from ...memory.mem0_client import Mem0Client
from ...utils.logging import log_state_change
from ...logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):

View File

@@ -6,12 +6,12 @@
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from app.main_graph.config import get_stream_writer
from ...main_graph.config import get_stream_writer
# 本地模块
from app.main_graph.state import MainGraphState
from app.utils.logging import log_state_change
from app.logger import debug, info
from ...main_graph.state import MainGraphState
from ...utils.logging import log_state_change
from ...logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""

View File

@@ -2,14 +2,15 @@
联网搜索节点 - 执行搜索并将结果保存到状态
"""
from typing import Dict, Any, Optional
from typing import Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from app.main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from app.logger import info
from ...main_graph.state import MainGraphState, ErrorRecord, ErrorSeverity
from ...logger import info
async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
async def web_search_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
"""
联网搜索节点:执行搜索并将结果保存到状态
"""
@@ -39,7 +40,7 @@ async def web_search_node(state: MainGraphState, config: Optional[Dict[str, Any]
search_query = reasoning_result.metadata.get("search_query", state.user_query) if reasoning_result else state.user_query
try:
from app.core import web_search
from backend.app.core import web_search
print(f"[WebSearch] 搜索: {search_query}")
search_result = web_search(search_query, max_results=5)

View File

@@ -5,6 +5,8 @@
from typing import Dict, Any, Optional
from datetime import datetime
from langchain_core.runnables.config import RunnableConfig
from .state import MainGraphState, ErrorRecord, ErrorSeverity
from ..logger import info
@@ -19,7 +21,7 @@ def wrap_subgraph_for_error_handling(subgraph, name: str):
Returns: 包装后的节点函数
"""
async def wrapped_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
async def wrapped_node(state: MainGraphState, config: Optional[RunnableConfig] = None) -> MainGraphState:
# 发送子图开始事件
if config:
try:

View File

@@ -20,7 +20,7 @@ def web_search_tool(query: str, max_results: int = 5) -> str:
格式化的搜索结果,包含引用溯源
"""
try:
from app.core import web_search
from backend.app.core import web_search
return web_search(query, max_results)
except Exception as e:
return f"联网搜索出错:{str(e)}"
@@ -40,7 +40,7 @@ def generate_chart_tool(data_text: str, chart_type: str = "bar") -> str:
格式化的图表输出Mermaid 格式)
"""
try:
from app.core import generate_chart
from backend.app.core import generate_chart
return generate_chart(data_text, chart_type)
except Exception as e:
return f"生成图表出错:{str(e)}\n\n请使用格式:标题,标签1:值1,标签2:值2,..."

View File

@@ -22,13 +22,13 @@ def dictionary_tool(query: str, action: Optional[str] = None) -> str:
格式化的结果文本
"""
try:
from app.subgraphs.dictionary import (
from backend.app.subgraphs.dictionary import (
DictionaryState,
DictionaryAction,
parse_intent,
format_result
)
from app.subgraphs.dictionary.nodes import (
from backend.app.subgraphs.dictionary.nodes import (
query_word, translate_text, extract_terms, get_daily_word
)
@@ -87,13 +87,13 @@ def news_analysis_tool(query: str, action: Optional[str] = None) -> str:
格式化的结果文本
"""
try:
from app.subgraphs.news_analysis import (
from backend.app.subgraphs.news_analysis import (
NewsAnalysisState,
NewsAction,
parse_intent,
format_result
)
from app.subgraphs.news_analysis.nodes import (
from backend.app.subgraphs.news_analysis.nodes import (
query_news, analyze_url, extract_keywords, generate_report
)
@@ -150,13 +150,13 @@ def contact_tool(query: str, action: Optional[str] = None) -> str:
格式化的结果文本
"""
try:
from app.subgraphs.contact import (
from backend.app.subgraphs.contact import (
ContactState,
ContactAction,
parse_intent,
format_result
)
from app.subgraphs.contact.nodes import (
from backend.app.subgraphs.contact.nodes import (
query_contact, add_contact, list_contacts
)

View File

@@ -1,8 +1,8 @@
# app/rag_initializer.py
from app.rag.tools import create_rag_tool
from app.rag.retriever import create_parent_hybrid_retriever
from app.model_services import get_embedding_service
from app.logger import info, warning
from ...rag.tools import create_rag_tool
from ...rag.retriever import create_parent_hybrid_retriever
from ...model_services import get_embedding_service
from ...logger import info, warning
import sys
# 全局 RAG 工具
@@ -38,7 +38,7 @@ async def init_rag_tool(force: bool = False):
return _rag_tool
try:
from app.model_services.chat_services import get_chat_service
from backend.app.model_services.chat_services import get_chat_service
info("🔄 正在初始化 RAG 检索系统...")
embeddings = get_embedding_service()

View File

@@ -287,7 +287,7 @@ def create_retry_wrapper_for_node(
time.sleep(delay)
# 所有重试都失败,更新状态错误信息
from app.main_graph.state import ErrorRecord, ErrorSeverity
from backend.app.main_graph.state import ErrorRecord, ErrorSeverity
error_record = ErrorRecord(
error_type=f"{node_name}TimeoutError",

View File

@@ -9,7 +9,7 @@ from typing import Optional, List
from mem0 import AsyncMemory
from app.config import (
from ..config import (
LLM_API_KEY,
ZHIPUAI_API_KEY,
VLLM_BASE_URL,
@@ -21,9 +21,9 @@ from app.config import (
ZHIPU_EMBEDDING_MODEL,
ZHIPU_API_BASE,
)
from app.logger import info, warning, error
from app.model_services import get_embedding_service
from app.model_services.chat_services import get_chat_service
from ..logger import info, warning, error
from ..model_services import get_embedding_service
from ..model_services.chat_services import get_chat_service
class Mem0Client:
@@ -48,7 +48,7 @@ class Mem0Client:
info(f"✅ 嵌入服务可用,向量维度: {embedding_dim}")
# 构建 embedder 配置
from app.model_services.embedding_services import (
from backend.app.model_services.embedding_services import (
LocalLlamaCppEmbeddingProvider,
ZhipuEmbeddingProvider,
)

View File

@@ -23,7 +23,7 @@ from .base import (
FallbackServiceChain,
SingletonServiceManager
)
from app.config import (
from ..config import (
VLLM_BASE_URL,
LLM_API_KEY,
ZHIPUAI_API_KEY,
@@ -203,7 +203,7 @@ class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
"""
def __init__(self, model: str = None):
from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
from ..config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
super().__init__("local_small")
self._model = model or SMALL_LOCAL_MODEL_NAME
self._base_url = SMALL_VLLM_BASE_URL
@@ -242,7 +242,7 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
"""
def __init__(self, model: str = None):
from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
from ..config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
super().__init__("deepseek_small")
self._model = model or SMALL_DEEPSEEK_MODEL
self._api_key = SMALL_DEEPSEEK_API_KEY

View File

@@ -21,7 +21,7 @@ from .base import (
FallbackServiceChain,
SingletonServiceManager
)
from app.config import (
from ..config import (
LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,

View File

@@ -27,7 +27,7 @@ from .base import (
FallbackServiceChain,
SingletonServiceManager
)
from app.config import (
from ..config import (
LLAMACPP_RERANKER_URL,
LLAMACPP_API_KEY,
ZHIPUAI_API_KEY,
@@ -92,20 +92,33 @@ class LocalLlamaCppRerankService(BaseRerankService):
"documents": documents,
}
logger.info(f"[LocalLlamaCppRerank] 调用 rerank API: {base}/rerank")
logger.info(f"[LocalLlamaCppRerank] 请求 payload: query={query[:50]}, documents数量={len(documents)}")
with httpx.Client(timeout=120) as client:
response = client.post(
f"{base}/rerank",
headers=headers,
json=payload,
)
response.raise_for_status()
logger.info(f"[LocalLlamaCppRerank] 响应状态码: {response.status_code}")
if response.status_code != 200:
logger.error(f"[LocalLlamaCppRerank] 请求失败: {response.status_code}")
logger.error(f"[LocalLlamaCppRerank] 响应内容: {response.text[:500]}")
response.raise_for_status()
data = response.json()
logger.info(f"[LocalLlamaCppRerank] 响应数据类型: {type(data)}")
if isinstance(data, dict) and "results" in data:
results = data["results"]
results_sorted = sorted(results, key=lambda x: x["index"])
return [item["relevance_score"] for item in results_sorted]
scores = [item["relevance_score"] for item in results_sorted]
logger.info(f"[LocalLlamaCppRerank] 返回 {len(scores)} 个得分")
return scores
else:
logger.error(f"[LocalLlamaCppRerank] 未知响应格式: {type(data)}")
raise ValueError(f"未知的 rerank API 响应格式: {data}")
@@ -207,11 +220,13 @@ class LLMFallbackRerankService(BaseRerankService):
if not documents:
return []
logger.info(f"[LLMFallbackRerank] 开始为 {len(documents)} 个文档打分")
scores = []
for doc in documents:
for i, doc in enumerate(documents):
score = self._score_single_document(query, doc)
scores.append(score)
logger.info(f"[LLMFallbackRerank] doc[{i}] score={score:.4f}")
return scores
def _score_single_document(self, query: str, document: str) -> float:

View File

@@ -13,7 +13,7 @@ RAG 检索与生成模块
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法:
>>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from backend.app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI
>>>

View File

@@ -1,137 +1,114 @@
"""
RAG 检索流水线模块
提供固定流程的 RAG 检索:
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+稀疏)+ 父子文档模式。
RAG 检索流水线
流程: 检索子文档 → 重排 → 获取父文档 → 返回
"""
import asyncio
import os
from typing import List, Optional
import logging
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from app.model_services import get_rerank_service, get_small_llm_service
from app.rag.rerank import create_document_reranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.retriever import create_parent_hybrid_retriever
from ..model_services import get_rerank_service, get_small_llm_service
from ..rag.rerank import create_document_reranker
from ..rag.query_transform import MultiQueryGenerator
from ..rag.fusion import reciprocal_rank_fusion
from ..rag.retriever import create_parent_hybrid_retriever
logger = logging.getLogger(__name__)
class RAGPipeline:
"""
固定流程的 RAG 检索流水线:
多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
"""
def __init__(
self,
retriever=None,
llm: Optional[BaseLanguageModel] = "default_small",
llm: BaseLanguageModel | str = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
use_rerank: bool = True,
return_parent_docs: bool = True,
):
"""
Args:
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
如果不提供,会自动创建默认的父子文档混合检索器。
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量。
rerank_top_n: 最终返回的文档数量。
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
"""
# 如果没有提供 retriever自动创建默认的混合检索器
if retriever is None:
self.retriever = create_parent_hybrid_retriever(
collection_name=collection_name,
search_k=rerank_top_n * 2 # 多取一些给重排序用
)
else:
self.retriever = retriever
# 处理 llm 参数
self.retriever = retriever or create_parent_hybrid_retriever(
collection_name=collection_name, search_k=rerank_top_n * 4
)
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
self.use_rerank = use_rerank
self.return_parent_docs = return_parent_docs
if llm == "default_small":
try:
self.llm = get_small_llm_service()
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
except Exception:
self.llm = None
elif llm in (None, False):
self.llm = None
else:
self.llm = llm
self.num_queries = num_queries
self.rerank_top_n = rerank_top_n
# 初始化组件 - 使用统一的重排服务获取接口
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
self.reranker = create_document_reranker()
self.llm = llm if llm else None
self.query_generator = MultiQueryGenerator(self.llm, num_queries) if self.llm else None
self.reranker = create_document_reranker() if use_rerank else None
logger.info(f"[Pipeline] init: rerank={use_rerank}, return_parent={return_parent_docs}")
async def aretrieve(self, query: str) -> List[Document]:
"""
异步执行完整检索流程
Args:
query: 用户查询
Returns:
检索到的相关文档列表
"""
# 如果有 query_generator做多路改写
if self.query_generator and self.llm:
# Step 1: 生成多路查询
# Step 1: 检索
child_docs = await self._retrieve(query)
logger.info(f"[Pipeline] 检索到 {len(child_docs)} 个子文档")
# 调试:打印子文档长度
for i, doc in enumerate(child_docs[:5]):
content_len = len(doc.page_content)
logger.info(f"[Pipeline] 子文档[{i}] 长度={content_len}字符")
# Step 2: 重排
if self.reranker:
try:
child_docs = self.reranker.compress_documents(child_docs, query, self.rerank_top_n)
logger.info(f"[Pipeline] 重排后 {len(child_docs)}")
except Exception as e:
logger.warning(f"[Pipeline] 重排失败: {e}")
child_docs = child_docs[:self.rerank_top_n]
# Step 3: 获取父文档
if self.return_parent_docs:
return await self._get_parents(child_docs)
return child_docs
async def _retrieve(self, query: str) -> List[Document]:
if self.query_generator:
queries = await self.query_generator.agenerate(query)
# 包含原始查询,确保至少有一条
if query not in queries:
queries.insert(0, query)
else:
# 如果原始查询已在列表中,将其移至首位
queries.remove(query)
queries.insert(0, query)
# Step 2: 并行检索(每个查询获取文档列表)
tasks = [self.retriever.ainvoke(q) for q in queries]
doc_lists = await asyncio.gather(*tasks)
# Step 3: RRF 融合
fused_docs = reciprocal_rank_fusion(doc_lists)
else:
# 没有 LLM 做查询改写,直接用原始查询检索
fused_docs = await self.retriever.ainvoke(query)
# Step 4: 重排序
queries = [query] + [q for q in queries if q != query]
doc_lists = await asyncio.gather(*[self.retriever.ainvoke(q) for q in queries])
return reciprocal_rank_fusion(doc_lists)
return await self.retriever.ainvoke(query)
async def _get_parents(self, child_docs: List[Document]) -> List[Document]:
parent_map = {}
for doc in child_docs:
pid = doc.metadata.get("parent_id")
if pid and pid not in parent_map:
parent_map[pid] = doc.metadata.get("score", 0.0)
if not parent_map:
logger.warning("[Pipeline] 未找到 parent_id返回子文档")
return child_docs
try:
final_docs = self.reranker.compress_documents(fused_docs, query, top_n=self.rerank_top_n)
except Exception:
# 若重排序器不可用,直接返回融合后的前 N 个结果
final_docs = fused_docs[:self.rerank_top_n]
return final_docs
from backend.rag_core import create_docstore
docstore, _ = create_docstore()
# 同步获取(异步版本不存在)
parent_docs = docstore.mget(list(parent_map.keys()))
parent_map2 = {d.metadata.get("id"): d for d in parent_docs if d}
result = [(parent_map2[pid], score) for pid, score in parent_map.items() if pid in parent_map2]
result.sort(key=lambda x: x[1], reverse=True)
docs = [d for d, _ in result]
logger.info(f"[Pipeline] 获取到 {len(docs)} 个父文档")
return docs
except Exception as e:
logger.warning(f"[Pipeline] 获取父文档失败: {e}")
return child_docs
def format_context(self, documents: List[Document]) -> str:
"""
将文档列表格式化为上下文字符串
Args:
documents: 文档列表
Returns:
格式化后的上下文字符串
"""
if not documents:
return ""
parts = []
for i, doc in enumerate(documents, 1):
source = doc.metadata.get("source", "未知来源")
@@ -139,30 +116,5 @@ class RAGPipeline:
return "\n".join(parts)
def create_rag_pipeline(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = "default_small",
num_queries: int = 3,
rerank_top_n: int = 5,
) -> RAGPipeline:
"""
创建 RAG 检索流水线的便捷函数
Args:
collection_name: Qdrant 集合名称
llm: 用于生成多路查询的语言模型。
- "default_small": (默认) 使用小模型(本地 + DeepSeek
- None / False: 不做查询改写
- BaseLanguageModel 实例: 自定义模型
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
Returns:
RAGPipeline 实例
"""
return RAGPipeline(
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name
)
def create_rag_pipeline(**kwargs) -> RAGPipeline:
return RAGPipeline(**kwargs)

View File

@@ -57,14 +57,26 @@ class DocumentReranker:
try:
# 1. 从 Document 提取内容(业务逻辑)
doc_contents = [doc.page_content for doc in documents]
logger.info(f"[Rerank] 收到 {len(documents)} 个文档待重排, query={query[:50]}")
total_chars = sum(len(c) for c in doc_contents)
logger.info(f"[Rerank] 各文档长度: {[len(c) for c in doc_contents]}, 总字符数: {total_chars}")
# 粗略估算 tokens (中文约 0.75 tokens/字符)
estimated_tokens = int(total_chars * 0.75)
logger.info(f"[Rerank] 估算总 tokens: ~{estimated_tokens} (假设中文)")
# 2. 调用纯服务层计算得分
logger.info(f"[Rerank] 正在调用 rerank service: {type(self._rerank_service).__name__}")
scores = self._rerank_service.compute_scores(query, doc_contents)
logger.info(f"[Rerank] 获取到 {len(scores)} 个得分: {scores}")
# 3. 根据得分排序(业务逻辑)
doc_score_pairs = list(zip(documents, scores))
doc_score_pairs_sorted = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
logger.info(f"[Rerank] 排序后的结果:")
for i, (doc, score) in enumerate(doc_score_pairs_sorted):
logger.info(f" [{i}] score={score:.4f}, content={doc.page_content[:80]}...")
# 4. 取 top_n
top_docs = [pair[0] for pair in doc_score_pairs_sorted[:top_n]]
@@ -72,6 +84,9 @@ class DocumentReranker:
except Exception as e:
logger.warning(f"重排过程出错,返回原始前 {top_n} 个结果: {e}")
logger.warning(f"[Rerank] 异常详情: {type(e).__name__}: {e}")
import traceback
logger.warning(f"[Rerank] 堆栈: {traceback.format_exc()}")
return documents[:top_n]

View File

@@ -19,10 +19,10 @@ from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pydantic import Field, PrivateAttr
from rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from rag_core.client import create_async_qdrant_client
from app.model_services import get_embedding_service
from app.logger import info, warning, debug
from backend.rag_core import QdrantHybridStore, get_sparse_embedder, create_docstore
from backend.rag_core.client import create_async_qdrant_client
from ..model_services import get_embedding_service
from ..logger import info, warning, debug
# 模块级常量
@@ -131,20 +131,20 @@ class HybridRetriever(BaseRetriever):
class ParentHybridRetriever(BaseRetriever):
"""
父子文档混合检索器(异步):
1. 先用混合检索找到相关子文档
2. 根据子文档的 parent_id 找到对应的父文档
3. 去重并返回父文档
"""
collection_name: str = Field(description="Qdrant 集合名称")
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
_vector_store: Any = PrivateAttr()
_client: Any = PrivateAttr()
_sparse_embedder: Any = PrivateAttr()
_docstore: Any = PrivateAttr()
def __init__(
self,
collection_name: str,
@@ -188,7 +188,7 @@ class ParentHybridRetriever(BaseRetriever):
self, query: str, *, run_manager: Any = None
) -> List[Document]:
"""
异步检索相关文档
异步检索相关文档
"""
# 1. 生成查询向量
dense_query = await self._vector_store.aembed_query(query)
@@ -197,10 +197,10 @@ class ParentHybridRetriever(BaseRetriever):
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 2. 多取一些子文档,避免去重后数量不足
search_limit = self.search_k * 2
# 3. 使用 query_points API 进行混合检索
response = await self._client.query_points(
collection_name=self.collection_name,
@@ -220,87 +220,27 @@ class ParentHybridRetriever(BaseRetriever):
limit=search_limit,
with_payload=True
)
if not response.points:
debug("混合检索未找到任何文档")
return []
# 4. 收集 parent_id 和对应最高得分
parent_score_map = {}
parent_ids = set()
child_point_map = {} # 保存子文档点用于降级
# 4. 构建子文档列表
child_docs = []
for point in response.points:
payload_copy = point.payload.copy()
parent_id = payload_copy.get("parent_id", point.id)
score = point.score
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
parent_ids.add(parent_id)
child_point_map[parent_id] = point
# 5. 批量查询父文档
parent_docs = []
found_parent_ids = set()
# 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中)
try:
parent_points = await self._client.retrieve(
collection_name=self.collection_name,
ids=list(parent_ids),
with_payload=True
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata={
**payload_copy,
"child_id": point.id,
"score": point.score
}
)
for point in parent_points:
payload_copy = point.payload.copy()
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata=payload_copy
)
parent_docs.append(doc)
found_parent_ids.add(point.id)
except Exception as e:
warning(f"从 Qdrant 查询父文档失败: {e}")
# 6. 如果有 docstore尝试从 docstore 查询剩余的父文档
if self._docstore and len(found_parent_ids) < len(parent_ids):
missing_parent_ids = parent_ids - found_parent_ids
try:
docstore_docs = await self._docstore.amget(missing_parent_ids)
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
if doc is not None:
parent_docs.append(doc)
found_parent_ids.add(doc_id)
except Exception as e:
warning(f"从 docstore 查询父文档失败: {e}")
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
missing_parent_ids = parent_ids - found_parent_ids
if missing_parent_ids:
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
for parent_id in missing_parent_ids:
child_point = child_point_map.get(parent_id)
if child_point:
payload_copy = child_point.payload.copy()
doc = Document(
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
metadata=payload_copy
)
parent_docs.append(doc)
# 8. 按照得分降序排序,返回前 k 个
parent_docs_with_scores = [
(doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0))
for doc in parent_docs
]
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
return final_docs
child_docs.append(doc)
debug(f"父子文档混合检索返回 {len(child_docs)} 个子文档")
return child_docs
def create_hybrid_retriever(

View File

@@ -10,7 +10,7 @@ from typing import Callable, Optional
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
def create_rag_tool(

View File

@@ -20,22 +20,18 @@ class ContactAPIClient:
def __init__(self, conn=None):
"""
初始化
初始化使用延迟初始化MCP在首次使用时初始化
Args:
conn: 数据库连接(保留用于向后兼容)
"""
self.conn = conn
# 确保MCP已初始化
import asyncio
try:
asyncio.create_task(self._init_mcp())
except RuntimeError:
pass # 没有事件循环时跳过,延迟初始化
self._mcp_initialized = False # 延迟初始化标志
async def _init_mcp(self):
"""初始化MCP系统"""
"""初始化MCP系统(延迟初始化)"""
if self._mcp_initialized:
return # 已初始化,跳过
if not mcp_manager.get_adapter("contact"):
# 获取repository如果有
repo = None
@@ -45,9 +41,10 @@ class ContactAPIClient:
repo = ContactRepository(self.conn)
except Exception:
pass
mcp_manager.register_adapter(ContactAdapter(contact_repo=repo))
await mcp_manager.initialize()
self._mcp_initialized = True
async def list_contacts(self, user_id: str = "default") -> List[Contact]:
"""获取联系人列表"""

View File

@@ -8,7 +8,7 @@ from typing import Dict, Any
from datetime import datetime
# 公共工具
from app.core import MarkdownFormatter
from ...core import MarkdownFormatter
from .state import ContactState
from .api_client import ContactAPIClient

View File

@@ -22,20 +22,19 @@ class DictionaryAPIClient:
word_repository: Optional[Any] = None
def __post_init__(self):
"""初始化后设置MCP"""
import asyncio
try:
asyncio.create_task(self._init_mcp())
except RuntimeError:
pass
"""初始化后设置(延迟初始化标志)"""
self._mcp_initialized = False # 延迟初始化标志
async def _init_mcp(self):
"""初始化MCP系统"""
"""初始化MCP系统(延迟初始化)"""
if self._mcp_initialized:
return # 已初始化,跳过
if not mcp_manager.get_adapter("dictionary"):
mcp_manager.register_adapter(
DictionaryAdapter(word_repo=self.word_repository)
)
await mcp_manager.initialize()
self._mcp_initialized = True
async def query_word(
self,

View File

@@ -8,7 +8,7 @@ from datetime import datetime
import random
# 公共工具
from app.core import (
from ...core import (
MarkdownFormatter
)

View File

@@ -23,20 +23,19 @@ class NewsAPIClient:
news_repository: Optional[Any] = None
def __post_init__(self):
"""初始化后设置MCP"""
import asyncio
try:
asyncio.create_task(self._init_mcp())
except RuntimeError:
pass
"""初始化后设置(延迟初始化标志)"""
self._mcp_initialized = False # 延迟初始化标志
async def _init_mcp(self):
"""初始化MCP系统"""
"""初始化MCP系统(延迟初始化)"""
if self._mcp_initialized:
return # 已初始化,跳过
if not mcp_manager.get_adapter("news"):
mcp_manager.register_adapter(
NewsAdapter(news_repo=self.news_repository)
)
await mcp_manager.initialize()
self._mcp_initialized = True
async def query_news(
self,

View File

@@ -7,7 +7,7 @@ from typing import Dict, Any
from datetime import datetime
# 公共工具
from app.core import MarkdownFormatter
from ...core import MarkdownFormatter
from .state import (
NewsAnalysisState,

View File

@@ -3,9 +3,9 @@ LangGraph 节点日志工具模块
提供状态流转追踪和 LLM 输入输出打印功能
"""
from app.config import ENABLE_GRAPH_TRACE
from app.logger import debug, info
from app.main_graph.state import MainGraphState
from ..config import ENABLE_GRAPH_TRACE
from ..logger import debug, info
from ..main_graph.state import MainGraphState
def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进入"):
@@ -17,7 +17,7 @@ def log_state_change(node_name: str, state: MainGraphState, prefix: str = "进
state: 当前状态
prefix: 日志前缀("进入""离开"
"""
from app.logger import info
from backend.app.logger import info
messages = state.messages
msg_count = len(messages)

View File

@@ -35,6 +35,7 @@ def get_input_path() -> Path:
return Path(sys.argv[1])
# 默认测试路径(可按需修改)
return Path("data/corpus/三国演义.txt")
#return Path("data/user_docs/doublestory.txt")
async def main():

View File

@@ -1,34 +0,0 @@
#!/usr/bin/env python3
"""
删除 Qdrant 集合并重新索引
"""
import asyncio
import os
import sys
from backend.rag_core import QdrantHybridStore
async def delete_and_recreate():
"""删除并重新创建集合"""
print("="*70)
print("删除旧集合并重新创建...")
print("="*70)
vs = QdrantHybridStore(collection_name="rag_documents")
# 删除旧集合
try:
vs.delete_collection()
print("✅ 旧集合已删除")
except Exception as e:
print(f"⚠️ 删除集合时出错(可能不存在): {e}")
# 重新创建
vs.create_collection()
print("✅ 新集合已创建")
if __name__ == "__main__":
asyncio.run(delete_and_recreate())

View File

@@ -17,10 +17,8 @@ class SplitterType(str, Enum):
PARENT_CHILD = "parent_child"
# ---------- 配置数据类,统一参数 ----------
@dataclass
class RecursiveSplitterConfig:
"""递归字符切分器配置"""
chunk_size: int = 500
chunk_overlap: int = 50
separators: List[str] = field(default_factory=lambda: ["\n\n", "\n", "", "", "", " ", ""])
@@ -30,33 +28,31 @@ class RecursiveSplitterConfig:
@dataclass
class SemanticSplitterConfig:
"""语义切分器配置,仅包含 SemanticChunker 支持的参数。"""
embeddings: Any
buffer_size: int = 1
add_start_index: bool = False
breakpoint_threshold_type: str = "percentile"
breakpoint_threshold_amount: Optional[float] = None
breakpoint_threshold_amount: float = 0.6 # 非 None切分更积极
number_of_chunks: Optional[int] = None
sentence_split_regex: str = r"(?<=[.?!。?!])\s+"
sentence_split_regex: str = r"(?<=[。!?;.!?;])" # 中文友好
min_chunk_size: int = 100
@dataclass
class ParentChildSplitterConfig:
"""父子切分器配置"""
embeddings: Any # 子块语义切分所需
parent_chunk_size: int = 1000
parent_chunk_overlap: int = 100
child_buffer_size: int = 1
child_breakpoint_threshold_type: str = "percentile"
child_breakpoint_threshold_amount: Optional[float] = None
child_min_chunk_size: int = 100
child_max_chunk_size: Optional[int] = 200
embeddings: Any
# 语义切分(用于父块)
semantic_threshold_type: str = "percentile"
semantic_threshold_amount: float = 0.6
semantic_buffer_size: int = 1
semantic_min_chunk_size: int = 100
# 子块(递归字符切分)
child_chunk_size: int = 400
child_chunk_overlap: int = 50
# ---------- 适配器:让 SemanticChunker 实现 TextSplitter 接口 ----------
# ---------- 适配器 ----------
class SemanticChunkerAdapter(TextSplitter):
"""将 SemanticChunker 适配为 LangChain TextSplitter 接口。"""
def __init__(self, config: SemanticSplitterConfig, **kwargs):
super().__init__(**kwargs)
self._config = config
@@ -86,12 +82,8 @@ class SemanticChunkerAdapter(TextSplitter):
return result
# ---------- 工厂函数,统一创建切分器 ----------
# ---------- 工厂函数 ----------
def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
"""
根据类型创建切分器。
支持传入配置对象或直接参数。
"""
if splitter_type == SplitterType.RECURSIVE:
config = RecursiveSplitterConfig(
chunk_size=kwargs.get("chunk_size", 500),
@@ -114,98 +106,90 @@ def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter:
if "config" in kwargs and isinstance(kwargs["config"], SemanticSplitterConfig):
config = kwargs["config"]
else:
# 过滤出 SemanticSplitterConfig 支持的字段
config_kwargs = {
"embeddings": embeddings,
"buffer_size": kwargs.get("buffer_size", 1),
"breakpoint_threshold_type": kwargs.get("breakpoint_threshold_type", "percentile"),
"breakpoint_threshold_amount": kwargs.get("breakpoint_threshold_amount"),
"number_of_chunks": kwargs.get("number_of_chunks"),
"min_chunk_size": kwargs.get("min_chunk_size", 100),
}
config = SemanticSplitterConfig(**config_kwargs)
config = SemanticSplitterConfig(
embeddings=embeddings,
buffer_size=kwargs.get("buffer_size", 1),
breakpoint_threshold_type=kwargs.get("breakpoint_threshold_type", "percentile"),
breakpoint_threshold_amount=kwargs.get("breakpoint_threshold_amount", 0.6),
number_of_chunks=kwargs.get("number_of_chunks"),
min_chunk_size=kwargs.get("min_chunk_size", 100),
)
return SemanticChunkerAdapter(config)
elif splitter_type == SplitterType.PARENT_CHILD:
# 父子切分器在 builder 中单独处理,不通过本函数创建
raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建")
raise ValueError("父子切分器应通过 ParentChildSplitter 直接创建")
else:
raise ValueError(f"不支持的切分器类型: {splitter_type}")
# ---------- 父子切分器实现 ----------
# ---------- 父子切分器 ----------
class ParentChildSplitter:
"""
将文档切分为父块(大块,用于上下文)和子块(小块,用于索引检索)。
内部维护父子块之间的映射关系。
切分流程:
1. 语义切分 → 父块
2. 递归字符切分 → 子块
"""
def __init__(self, config: ParentChildSplitterConfig):
self.config = config
# 父块使用递归字符切分
self.parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.parent_chunk_size,
chunk_overlap=config.parent_chunk_overlap,
)
# 子块使用语义切分
# 语义切分(父块)
semantic_config = SemanticSplitterConfig(
embeddings=config.embeddings,
buffer_size=config.child_buffer_size,
breakpoint_threshold_type=config.child_breakpoint_threshold_type,
breakpoint_threshold_amount=config.child_breakpoint_threshold_amount,
min_chunk_size=config.child_min_chunk_size,
buffer_size=config.semantic_buffer_size,
breakpoint_threshold_type=config.semantic_threshold_type,
breakpoint_threshold_amount=config.semantic_threshold_amount,
min_chunk_size=config.semantic_min_chunk_size,
)
self.semantic_splitter = SemanticChunkerAdapter(semantic_config)
# 递归字符切分(子块,大小由 child_chunk_size 控制)
self.recursive_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.child_chunk_size,
chunk_overlap=config.child_chunk_overlap,
separators=["\n\n", "\n", "", "", "", "", "", " ", ""]
)
self.child_splitter = SemanticChunkerAdapter(semantic_config)
# 存储父子块映射关系(可选)
self.parent_to_children: Dict[str, List[str]] = {}
self.child_to_parent: Dict[str, str] = {}
def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]:
"""
返回:
(父块列表, 子块列表)
同时填充内部映射字典。
"""
parent_chunks = self.parent_splitter.split_documents(documents)
child_chunks = self.child_splitter.split_documents(documents)
parent_chunks = []
child_chunks = []
# 建立映射关系(简化示例:根据文本包含关系粗略匹配,实际需更精确的算法)
# 这里仅作示意,生产环境建议使用 embedding 相似度或精确子串定位
self._build_mappings(parent_chunks, child_chunks)
for doc in documents:
# Step 1: 语义切分(父块)
semantic_blocks = self.semantic_splitter.split_text(doc.page_content)
for p_idx, semantic_block in enumerate(semantic_blocks):
parent_id = f"parent_{len(parent_chunks)}"
parent_doc = Document(
page_content=semantic_block,
metadata={**doc.metadata, "id": parent_id, "chunk_index": p_idx}
)
parent_chunks.append(parent_doc)
# Step 2: 递归字符切分(子块)
sub_chunks = self.recursive_splitter.split_text(semantic_block)
for c_idx, sub_chunk in enumerate(sub_chunks):
child_id = f"child_{len(child_chunks)}"
child_doc = Document(
page_content=sub_chunk,
metadata={**doc.metadata, "id": child_id, "parent_id": parent_id, "child_index": c_idx}
)
child_chunks.append(child_doc)
self.child_to_parent[child_id] = parent_id
if parent_id not in self.parent_to_children:
self.parent_to_children[parent_id] = []
self.parent_to_children[parent_id].append(child_id)
return parent_chunks, child_chunks
def _build_mappings(self, parents: List[Document], children: List[Document]) -> None:
"""
根据文本内容建立父子映射。
本方法为简化实现,实际使用时请替换为更可靠的匹配逻辑。
"""
self.parent_to_children.clear()
self.child_to_parent.clear()
# 为每个父块生成唯一 ID若无则使用索引
for p_idx, parent in enumerate(parents):
parent_id = parent.metadata.get("id", f"parent_{p_idx}")
parent.metadata["id"] = parent_id
self.parent_to_children[parent_id] = []
# 将每个子块分配给包含其文本的第一个父块
for c_idx, child in enumerate(children):
child_id = child.metadata.get("id", f"child_{c_idx}")
child.metadata["id"] = child_id
for parent in parents:
if child.page_content in parent.page_content:
parent_id = parent.metadata["id"]
self.parent_to_children[parent_id].append(child_id)
self.child_to_parent[child_id] = parent_id
child.metadata["parent_id"] = parent_id
break
def get_parent_for_child(self, child_id: str) -> Optional[str]:
"""根据子块 ID 获取父块 ID"""
return self.child_to_parent.get(child_id)
def get_children_for_parent(self, parent_id: str) -> List[str]:
"""根据父块 ID 获取所有子块 ID"""
return self.parent_to_children.get(parent_id, [])
return self.parent_to_children.get(parent_id, [])

View File

@@ -1,15 +0,0 @@
#!/usr/bin/env python3
"""统一入口:设置路径后运行测试"""
import sys
from pathlib import Path
from dotenv import load_dotenv
# 路径设置 - 只添加 backend 目录
project_root = Path(__file__).resolve().parent.parent
backend_path = project_root
sys.path.insert(0, str(backend_path))
load_dotenv(project_root / ".env")
if __name__ == "__main__":
from tools.test import test_tavily_search
test_tavily_search.main()

View File

@@ -2,29 +2,24 @@
"""
主图完整测试 - 覆盖各个分支
"""
import sys
import asyncio
from pathlib import Path
from dotenv import load_dotenv
# 添加 backend 到路径
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "backend"))
from app.main_graph.state import MainGraphState, CurrentAction
from app.main_graph.utils.main_graph_builder import build_react_main_graph
from app.model_services.chat_services import get_all_chat_services
from app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
from app.main_graph.utils.rag_initializer import init_rag_tool
from backend.app.main_graph.state import MainGraphState, CurrentAction
from backend.app.main_graph.main_graph_builder import build_react_main_graph
from backend.app.model_services.chat_services import get_all_chat_services
from backend.app.main_graph.tools.graph_tools import AVAILABLE_TOOLS
from backend.app.main_graph.utils.rag_initializer import init_rag_tool
# ========== 测试用例配置 ==========
TEST_CASES = [
# # 测试1: 简单闲聊 - 应该走 fast_chitchat
# {
# "name": "闲聊测试",
# "query": "你好!",
# "description": "测试快速闲聊分支"
# },
# 测试1: 简单闲聊 - 应该走 fast_chitchat
{
"name": "闲聊测试",
"query": "你好!",
"description": "测试快速闲聊分支"
},
# 测试2: 知识查询 - 应该走 fast_rag然后可能升级到 react
{
"name": "知识查询测试",
@@ -37,12 +32,12 @@ TEST_CASES = [
"query": "请帮我分析如果我有10万元想要在一年内获得15%的收益,有哪些低风险的投资方案?",
"description": "测试 React 循环推理分支"
},
# # 测试4: 需要工具调用的问题
# {
# "name": "工具调用测试",
# "query": "搜索一下今天的天气怎么样",
# "description": "测试工具调用分支"
# },
# 测试4: 需要工具调用的问题
{
"name": "联网工具调用测试",
"query": "搜索一下今天的天气怎么样",
"description": "测试工具调用分支"
},
# 测试5: 带记忆的对话
{
"name": "记忆测试",
@@ -64,22 +59,18 @@ async def setup_test_environment():
if not chat_services:
raise RuntimeError("没有可用的 LLM 服务")
llm = list(chat_services.values())[0]
print(f"✓ 使用 LLM: {list(chat_services.keys())[0]}")
print(f"✓ 可用模型: {list(chat_services.keys())}")
# 初始化 RAG 工具
def create_local_llm():
return llm
rag_tool = await init_rag_tool(create_local_llm)
rag_tool = await init_rag_tool()
tools = AVAILABLE_TOOLS.copy()
if rag_tool:
tools.append(rag_tool)
print(f"✓ RAG 工具初始化成功")
# 构建图
# 构建图(使用新的 API: chat_services 而不是 llm
graph = build_react_main_graph(
llm=llm,
chat_services=chat_services,
tools=tools,
use_hybrid_router=True
).compile()

View File

@@ -10,16 +10,22 @@ from pathlib import Path
# 路径设置
PROJECT_ROOT = Path(__file__).parent.parent
BACKEND_DIR = PROJECT_ROOT / "backend"
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(BACKEND_DIR))
import warnings
# 抑制 WebSocket 弃用警告
warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets")
warnings.filterwarnings("ignore", category=DeprecationWarning, module="uvicorn.protocols.websockets")
from dotenv import load_dotenv
load_dotenv(PROJECT_ROOT / ".env")
import asyncio
from backend.app.agent.agent_service import AIAgentService
from backend.app.agent.agent_service import AIAgentService, create_serde
from backend.app.config import DB_URI
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
async def visualize_graph():
"""可视化 LangGraph 结构"""
print("=" * 80)
@@ -28,9 +34,7 @@ async def visualize_graph():
print(f"项目根目录: {PROJECT_ROOT}")
print(f"Backend 目录: {BACKEND_DIR}")
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
async with AsyncPostgresSaver.from_conn_string(DB_URI, serde=create_serde()) as checkpointer:
await checkpointer.setup()
# 创建服务实例