This commit is contained in:
@@ -1,198 +1,88 @@
|
||||
"""
|
||||
AI Agent 服务类 - 完全简化版本!
|
||||
按照指南实现,不用 stream_mode="messages" 避免重复 token!
|
||||
AI Agent 服务类
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any, Optional, Tuple
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
|
||||
# LangGraph 序列化器(修复 checkpoint 反序列化警告)
|
||||
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
|
||||
|
||||
# 本地模块
|
||||
from backend.app.model_services import get_cached_chat_services
|
||||
from backend.app.main_graph.main_graph_builder import build_agent_graph
|
||||
from backend.app.logger import debug, info, warning, error
|
||||
from backend.app.main_graph.state import AgentState
|
||||
from .stream_context import set_stream_queue
|
||||
from backend.app.logger import info
|
||||
from backend.app.memory.mem0_client import Mem0Client
|
||||
|
||||
from .service_config import ServiceConfig
|
||||
from .stream_handler import run_graph_stream
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
self.graph = None
|
||||
self.chat_services = None
|
||||
# Mem0 客户端
|
||||
self.config: ServiceConfig = None
|
||||
self.mem0_client = None
|
||||
|
||||
async def initialize(self):
|
||||
# 0. 初始化 Mem0 客户端
|
||||
from ..memory.mem0_client import Mem0Client
|
||||
async def initialize(self) -> "AIAgentService":
|
||||
"""初始化 Agent 服务"""
|
||||
self.mem0_client = Mem0Client()
|
||||
|
||||
# 1. 获取缓存的模型字典
|
||||
|
||||
self.chat_services = get_cached_chat_services()
|
||||
info(f"✅ 加载了 {len(self.chat_services)} 个可用模型: {list(self.chat_services.keys())}")
|
||||
|
||||
# 2. 构建图
|
||||
info(f"🔄 构建 Agent 图...")
|
||||
|
||||
graph_builder = build_agent_graph(
|
||||
chat_services=self.chat_services,
|
||||
mem0_client=self.mem0_client
|
||||
)
|
||||
|
||||
# 编译图
|
||||
self.graph = graph_builder.compile(checkpointer=self.checkpointer)
|
||||
|
||||
self.config = ServiceConfig(self.chat_services)
|
||||
info(f"✅ Agent 图初始化完成")
|
||||
|
||||
|
||||
return self
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""
|
||||
解析并验证模型名称,不可用时回退到第一个可用模型
|
||||
|
||||
Args:
|
||||
model: 目标模型名称
|
||||
|
||||
Returns:
|
||||
实际使用的模型名称
|
||||
"""
|
||||
if not model or model not in self.chat_services:
|
||||
fallback = next(iter(self.chat_services.keys()))
|
||||
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
||||
return fallback
|
||||
return model
|
||||
|
||||
def _build_invocation(
|
||||
def _resolve_and_build(
|
||||
self, message: str, thread_id: str, model: str, user_id: str
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
构建图调用所需的 config 和 input_state
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
thread_id: 会话 ID
|
||||
model: 模型名称
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
(config, input_state) 元组
|
||||
"""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
|
||||
input_state = {
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"user_id": user_id,
|
||||
}
|
||||
return config, input_state
|
||||
):
|
||||
"""解析模型并构建调用参数"""
|
||||
resolved_model = self.config.resolve_model(model)
|
||||
return resolved_model, self.config.build_invocation(
|
||||
message, thread_id, resolved_model, user_id
|
||||
)
|
||||
|
||||
async def process_message(
|
||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||
) -> dict:
|
||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||
# 解析模型名称
|
||||
resolved_model = self._resolve_model(model)
|
||||
|
||||
# 构建调用参数
|
||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
||||
resolved_model, (config, input_state) = self._resolve_and_build(
|
||||
message, thread_id, model, user_id
|
||||
)
|
||||
|
||||
result = await self.graph.ainvoke(input_state, config=config)
|
||||
|
||||
# 优先使用 final_reply(finalize 节点返回)
|
||||
reply = result.get("final_reply", "")
|
||||
if not reply and result.get("messages"):
|
||||
reply = result["messages"][-1].content
|
||||
|
||||
token_usage = result.get("last_token_usage", {})
|
||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
||||
|
||||
# 获取元数据
|
||||
metadata = result.get("metadata", {})
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": result.get("last_token_usage", {}),
|
||||
"elapsed_time": result.get("last_elapsed_time", 0.0),
|
||||
"model_used": resolved_model,
|
||||
"metadata": metadata
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
async def process_message_stream(
|
||||
self, message: str, thread_id: str, model: str = "", user_id: str = "default_user"
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""流式处理消息 - 完全简化!"""
|
||||
# 解析模型名称
|
||||
resolved_model = self._resolve_model(model)
|
||||
|
||||
# 构建调用参数
|
||||
config, input_state = self._build_invocation(message, thread_id, resolved_model, user_id)
|
||||
"""流式处理消息"""
|
||||
resolved_model, (config, input_state) = self._resolve_and_build(
|
||||
message, thread_id, model, user_id
|
||||
)
|
||||
|
||||
info(f"🚀 开始执行 Agent 图,指定模型: {resolved_model}")
|
||||
actual_model_used = resolved_model
|
||||
|
||||
# 创建 token 队列
|
||||
queue = asyncio.Queue()
|
||||
set_stream_queue(queue) # 设置上下文变量
|
||||
|
||||
async def run_graph():
|
||||
"""后台任务:运行 graph,流式事件都从 agent 节点内部发送!"""
|
||||
try:
|
||||
info(f"📡 开始调用 graph.astream()...")
|
||||
|
||||
# 注意:只用 stream_mode=["updates"],不要 "messages"!避免重复 token!
|
||||
async for _ in self.graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
stream_mode=["updates"],
|
||||
version="v2",
|
||||
subgraphs=True
|
||||
):
|
||||
# 流式事件都从 agent.py 节点内部通过队列发送了
|
||||
# 这里不需要再发送任何事件
|
||||
pass
|
||||
except Exception as e:
|
||||
error(f"❌ 执行图时出错: {e}")
|
||||
import traceback
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
await queue.put({"type": "error", "message": str(e)})
|
||||
finally:
|
||||
await queue.put(None) # 结束哨兵
|
||||
|
||||
# 启动后台任务
|
||||
bg_task = asyncio.create_task(run_graph())
|
||||
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
if event is None:
|
||||
break
|
||||
async for event in run_graph_stream(self.graph, input_state, config):
|
||||
if event.get("type") != "done":
|
||||
yield event
|
||||
|
||||
except GeneratorExit:
|
||||
# 客户端断开连接,取消后台任务
|
||||
info("⚠️ GeneratorExit,取消后台任务")
|
||||
bg_task.cancel()
|
||||
raise
|
||||
finally:
|
||||
# 保证任务被清理
|
||||
if not bg_task.done():
|
||||
info("⏹️ 清理后台任务")
|
||||
bg_task.cancel()
|
||||
try:
|
||||
await bg_task
|
||||
except asyncio.CancelledError:
|
||||
info("✅ 后台任务已取消")
|
||||
|
||||
# 发送结束事件,保证前端平稳关闭
|
||||
yield {
|
||||
"type": "done",
|
||||
"model_used": actual_model_used
|
||||
}
|
||||
else:
|
||||
yield {**event, "model_used": resolved_model}
|
||||
|
||||
46
backend/app/agent/service_config.py
Normal file
46
backend/app/agent/service_config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Agent Service 配置模块 - 配置构建和解析
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from langchain_core.messages import HumanMessage
|
||||
from backend.app.logger import warning
|
||||
|
||||
|
||||
class ServiceConfig:
|
||||
"""配置构建器"""
|
||||
|
||||
def __init__(self, chat_services: dict):
|
||||
self.chat_services = chat_services
|
||||
|
||||
def resolve_model(self, model: Optional[str]) -> str:
|
||||
"""
|
||||
解析并验证模型名称,不可用时回退到第一个可用模型
|
||||
"""
|
||||
if not model or model not in self.chat_services:
|
||||
fallback = next(iter(self.chat_services.keys()))
|
||||
warning(f"模型 '{model}' 不可用,回退到 '{fallback}'")
|
||||
return fallback
|
||||
return model
|
||||
|
||||
def build_invocation(
|
||||
self, message: str, thread_id: str, model: str, user_id: str
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
构建图调用所需的 config 和 input_state
|
||||
|
||||
Returns:
|
||||
(config, input_state) 元组
|
||||
"""
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
|
||||
input_state = {
|
||||
"messages": [HumanMessage(content=message)],
|
||||
"user_id": user_id,
|
||||
}
|
||||
return config, input_state
|
||||
78
backend/app/agent/stream_handler.py
Normal file
78
backend/app/agent/stream_handler.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
流式处理模块 - 处理 Agent 执行的流式输出
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Dict, Any
|
||||
|
||||
from backend.app.logger import info, error
|
||||
from .stream_context import set_stream_queue
|
||||
|
||||
|
||||
async def run_graph_stream(
|
||||
graph,
|
||||
input_state: Dict[str, Any],
|
||||
config: Dict[str, Any],
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
运行图并通过队列流式输出事件
|
||||
|
||||
Args:
|
||||
graph: 编译后的 LangGraph
|
||||
input_state: 输入状态
|
||||
config: 配置
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
set_stream_queue(queue)
|
||||
|
||||
async def run_graph():
|
||||
"""后台任务:运行 graph"""
|
||||
try:
|
||||
info(f"📡 开始调用 graph.astream()...")
|
||||
async for _ in graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
stream_mode=["updates"],
|
||||
version="v2",
|
||||
subgraphs=True
|
||||
):
|
||||
# 流式事件都从 agent.py 节点内部通过队列发送了
|
||||
pass
|
||||
except Exception as e:
|
||||
error(f"❌ 执行图时出错: {e}")
|
||||
import traceback
|
||||
error(f"📋 堆栈: {traceback.format_exc()}")
|
||||
await queue.put({"type": "error", "message": str(e)})
|
||||
finally:
|
||||
await queue.put(None) # 结束哨兵
|
||||
|
||||
# 启动后台任务
|
||||
bg_task = asyncio.create_task(run_graph())
|
||||
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
if event is None:
|
||||
break
|
||||
yield event
|
||||
|
||||
except GeneratorExit:
|
||||
info("⚠️ GeneratorExit,取消后台任务")
|
||||
bg_task.cancel()
|
||||
raise
|
||||
finally:
|
||||
await _cleanup_task(bg_task)
|
||||
|
||||
|
||||
async def _cleanup_task(bg_task: asyncio.Task) -> None:
|
||||
"""清理后台任务"""
|
||||
if not bg_task.done():
|
||||
info("⏹️ 清理后台任务")
|
||||
bg_task.cancel()
|
||||
try:
|
||||
await bg_task
|
||||
except asyncio.CancelledError:
|
||||
info("✅ 后台任务已取消")
|
||||
@@ -1,73 +0,0 @@
|
||||
# app/rag_initializer.py
|
||||
from ...rag.tools import create_rag_tool
|
||||
from ...rag.retriever import create_parent_hybrid_retriever
|
||||
from ...model_services import get_embedding_service
|
||||
from backend.app.logger import info, warning
|
||||
import sys
|
||||
|
||||
# 全局 RAG 工具
|
||||
_rag_tool = None
|
||||
_initialized = False
|
||||
|
||||
|
||||
def get_rag_tool() -> callable:
|
||||
"""获取全局 RAG 工具"""
|
||||
return _rag_tool
|
||||
|
||||
|
||||
def is_initialized() -> bool:
|
||||
"""检查是否已初始化"""
|
||||
return _initialized
|
||||
|
||||
|
||||
async def init_rag_tool(force: bool = False):
|
||||
"""
|
||||
初始化 RAG 工具(注册到模块级变量,内部获取所需服务)
|
||||
|
||||
Args:
|
||||
force: 是否强制重新初始化
|
||||
|
||||
Returns:
|
||||
RAG 工具(@tool 装饰函数)或 None
|
||||
"""
|
||||
global _rag_tool, _initialized
|
||||
|
||||
# 防止重复初始化
|
||||
if _initialized and not force:
|
||||
info("[RAG] 已初始化,跳过")
|
||||
return _rag_tool
|
||||
|
||||
try:
|
||||
from backend.app.model_services.chat_services import get_chat_service
|
||||
|
||||
info("🔄 正在初始化 RAG 检索系统...")
|
||||
embeddings = get_embedding_service()
|
||||
retriever = create_parent_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
search_k=5,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
rewrite_llm = get_chat_service()
|
||||
|
||||
rag_tool = create_rag_tool(
|
||||
retriever=retriever,
|
||||
llm=rewrite_llm,
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
)
|
||||
|
||||
_rag_tool = rag_tool
|
||||
_initialized = True
|
||||
info(f"✅ RAG 检索工具初始化成功 (id={id(rag_tool)})")
|
||||
return rag_tool
|
||||
|
||||
except Exception as e:
|
||||
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def reset():
|
||||
"""重置(用于测试)"""
|
||||
global _rag_tool, _initialized
|
||||
_rag_tool = None
|
||||
_initialized = False
|
||||
@@ -3,12 +3,11 @@
|
||||
Web Search Public Utility - Free, no API Key, using DuckDuckGo
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import requests
|
||||
import warnings
|
||||
import re
|
||||
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,47 +43,31 @@ class WebSearchTool:
|
||||
"""
|
||||
num_results = max_results or self.max_results
|
||||
|
||||
# 方式 1: Tavily (需要 API Key,质量最高)
|
||||
# 尝试搜索方式,按优先级
|
||||
result = self._try_tavily(query, num_results)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
result = self._try_ddgs(query, num_results)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# 兜底方案
|
||||
return self._get_mock_results(query, num_results)
|
||||
|
||||
def _try_tavily(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||
"""尝试 Tavily API 搜索"""
|
||||
try:
|
||||
return self._search_tavily(query, num_results)
|
||||
return self._search_tavily(query, max_results)
|
||||
except ImportError:
|
||||
print("[WebSearch] tavily 未安装,尝试其他搜索方式")
|
||||
info("[WebSearch] tavily 未安装")
|
||||
except Exception as e:
|
||||
if "API_KEY" in str(e) or "未配置" in str(e):
|
||||
print(f"[WebSearch] Tavily API Key 未配置: {e}")
|
||||
error_msg = str(e)
|
||||
if "API_KEY" in error_msg or "未配置" in error_msg:
|
||||
info(f"[WebSearch] Tavily API Key 未配置")
|
||||
else:
|
||||
print(f"[WebSearch] Tavily 搜索失败: {e}")
|
||||
|
||||
# 方式 2: 尝试用 ddgs 包
|
||||
try:
|
||||
from ddgs import DDGS
|
||||
print(f"[WebSearch] 使用 ddgs 搜索: {query}")
|
||||
with DDGS() as ddgs:
|
||||
results = list(ddgs.text(query, max_results=num_results))
|
||||
if results:
|
||||
search_results = []
|
||||
for r in results:
|
||||
search_results.append(SearchResult(
|
||||
title=r.get("title", ""),
|
||||
url=r.get("href", ""),
|
||||
snippet=r.get("body", ""),
|
||||
source="DuckDuckGo"
|
||||
))
|
||||
print(f"[WebSearch] ddgs 返回 {len(search_results)} 条结果")
|
||||
return search_results
|
||||
except ImportError:
|
||||
print("[WebSearch] ddgs 未安装,尝试 duckduckgo-search")
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||
|
||||
# 方式 3: 尝试用简单 HTTP 请求
|
||||
try:
|
||||
return self._search_http(query, num_results)
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] HTTP 搜索也失败: {e}")
|
||||
|
||||
# 方式 4: 返回模拟数据作为最后兜底
|
||||
return self._search_mock(query, num_results)
|
||||
info(f"[WebSearch] Tavily 搜索失败: {e}")
|
||||
return None
|
||||
|
||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""使用 Tavily API 搜索"""
|
||||
@@ -111,56 +94,40 @@ class WebSearchTool:
|
||||
source="Tavily"
|
||||
))
|
||||
|
||||
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||
info(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
|
||||
print(f"[WebSearch] 尝试 HTTP 搜索")
|
||||
|
||||
# 方式 1: 尝试百度搜索(简单方式)
|
||||
def _try_ddgs(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||
"""尝试 DuckDuckGo 搜索"""
|
||||
try:
|
||||
return self._search_baidu(query, max_results)
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] 百度搜索失败: {e}")
|
||||
|
||||
# 方式 2: 返回模拟数据
|
||||
return self._search_mock(query, max_results)
|
||||
from ddgs import DDGS
|
||||
|
||||
def _search_baidu(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""尝试百度搜索"""
|
||||
import requests
|
||||
from urllib.parse import quote
|
||||
|
||||
url = f"https://www.baidu.com/s?wd={quote(query)}"
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
# 简单解析百度搜索结果(简化版)
|
||||
results = []
|
||||
# 这里只是示意,真实百度搜索需要更复杂的解析
|
||||
results.append(SearchResult(
|
||||
title=f"百度搜索: {query}",
|
||||
url=url,
|
||||
snippet="如需要真实搜索结果,请考虑使用百度搜索 API",
|
||||
source="百度"
|
||||
))
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] 百度搜索也失败: {e}")
|
||||
raise
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.text(query, max_results=max_results):
|
||||
results.append(SearchResult(
|
||||
title=r.get("title", ""),
|
||||
url=r.get("href", ""),
|
||||
snippet=r.get("body", ""),
|
||||
source="DuckDuckGo"
|
||||
))
|
||||
|
||||
def _search_mock(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""模拟搜索结果(兜底方案)"""
|
||||
print(f"[WebSearch] 使用模拟搜索结果 (查询: {query})")
|
||||
|
||||
# 根据查询内容生成更有意义的模拟结果
|
||||
mock_templates = [
|
||||
if results:
|
||||
info(f"[WebSearch] ddgs 返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
except ImportError:
|
||||
info("[WebSearch] ddgs 未安装")
|
||||
except Exception as e:
|
||||
info(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _get_mock_results(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""获取模拟搜索结果(兜底方案)"""
|
||||
info(f"[WebSearch] 使用模拟搜索结果")
|
||||
|
||||
templates = [
|
||||
{
|
||||
"title": f"关于「{query}」的相关介绍",
|
||||
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
|
||||
@@ -177,50 +144,48 @@ class WebSearchTool:
|
||||
"url": "https://example.com/more"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
num = max_results or self.max_results
|
||||
results = []
|
||||
|
||||
for i, template in enumerate(mock_templates[:num]):
|
||||
|
||||
for template in templates[:num]:
|
||||
results.append(SearchResult(
|
||||
title=template["title"],
|
||||
url=template["url"],
|
||||
snippet=template["snippet"],
|
||||
source="模拟数据"
|
||||
))
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def format_search_results(self, results: List[SearchResult]) -> str:
|
||||
"""
|
||||
格式化搜索结果(带引用溯源)
|
||||
|
||||
|
||||
Args:
|
||||
results: 搜索结果列表
|
||||
|
||||
|
||||
Returns:
|
||||
格式化后的 Markdown 文本
|
||||
"""
|
||||
if not results:
|
||||
return "未找到相关搜索结果"
|
||||
|
||||
lines = []
|
||||
lines.append("## 🔍 联网搜索结果\n")
|
||||
|
||||
|
||||
lines = ["## 🔍 联网搜索结果\n"]
|
||||
|
||||
for idx, result in enumerate(results, 1):
|
||||
lines.append(f"### [{idx}] {result.title}")
|
||||
lines.append(f"- 🔗 来源:[{result.url}]({result.url})")
|
||||
lines.append(f"- 📝 摘要:{result.snippet}")
|
||||
lines.append(f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
lines.append("")
|
||||
|
||||
# 添加引用溯源说明
|
||||
|
||||
lines.append("---")
|
||||
lines.append("💡 **引用溯源说明**:")
|
||||
lines.append("- 以上搜索结果均标注了来源链接")
|
||||
lines.append("- 使用方括号数字标识引用(如 [1]、[2])")
|
||||
lines.append("- 可通过链接追溯原始信息")
|
||||
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -239,11 +204,11 @@ def get_web_search_tool() -> WebSearchTool:
|
||||
def web_search(query: str, max_results: int = 5) -> str:
|
||||
"""
|
||||
便捷函数:联网搜索并返回格式化结果
|
||||
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量
|
||||
|
||||
|
||||
Returns:
|
||||
格式化后的搜索结果文本
|
||||
"""
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
RAG 工具模块(完全异步)
|
||||
|
||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
"""
|
||||
from typing import Callable, Optional
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from ..rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
|
||||
|
||||
def create_rag_tool(
|
||||
retriever: Optional[BaseRetriever] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
) -> Callable:
|
||||
"""
|
||||
创建一个配置好的 RAG 检索工具(完全异步)。
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
|
||||
Args:
|
||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
collection_name: Qdrant 集合名称
|
||||
|
||||
Returns:
|
||||
Async LangChain Tool 函数
|
||||
"""
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@tool
|
||||
async def search_knowledge_base(query: str) -> str:
|
||||
"""
|
||||
在知识库中搜索与查询相关的文档片段(完全异步)。
|
||||
|
||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
||||
检索效果最优。
|
||||
|
||||
Args:
|
||||
query: 用户提出的问题或查询字符串
|
||||
|
||||
Returns:
|
||||
格式化后的相关文档内容
|
||||
"""
|
||||
try:
|
||||
documents = await pipeline.aretrieve(query)
|
||||
if not documents:
|
||||
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
||||
|
||||
context = pipeline.format_context(documents)
|
||||
return context
|
||||
except Exception as e:
|
||||
return f"检索过程中发生错误: {str(e)}"
|
||||
|
||||
return search_knowledge_base
|
||||
@@ -2,124 +2,17 @@
|
||||
Agent Tools - 所有工具统一定义
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from backend.app.logger import info
|
||||
|
||||
# ========== RAG ==========
|
||||
|
||||
_rag_pipeline = None
|
||||
|
||||
|
||||
def _get_rag_pipeline():
|
||||
global _rag_pipeline
|
||||
if _rag_pipeline is None:
|
||||
from backend.app.rag.pipeline import RAGPipeline
|
||||
_rag_pipeline = RAGPipeline(
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
use_rerank=True,
|
||||
return_parent_docs=True,
|
||||
)
|
||||
return _rag_pipeline
|
||||
|
||||
|
||||
@tool
|
||||
async def rag_search(query: str) -> str:
|
||||
"""
|
||||
检索知识库获取相关信息
|
||||
|
||||
Returns:
|
||||
包含检索结果和置信度的结构化回复,格式:
|
||||
- 内容:检索到的相关信息
|
||||
- 置信度评估:基于向量相似度、重排分数、LLM判断的综合评分
|
||||
"""
|
||||
info(f"[Tool] rag_search: {query[:30]}...")
|
||||
try:
|
||||
pipeline = _get_rag_pipeline()
|
||||
# 使用带置信度的检索
|
||||
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
|
||||
|
||||
if not result.content:
|
||||
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。"
|
||||
|
||||
# 构建包含置信度的回复
|
||||
confidence_desc = "高"
|
||||
if result.confidence < 0.4:
|
||||
confidence_desc = "低"
|
||||
elif result.confidence < 0.6:
|
||||
confidence_desc = "中"
|
||||
|
||||
response = f"""【RAG检索结果】
|
||||
{result.content}
|
||||
|
||||
【置信度评估】
|
||||
- 综合置信度:{result.confidence:.2f}({confidence_desc})
|
||||
- 向量相似度:{result.scores['embedding']:.2f}
|
||||
- 重排分数:{result.scores['rerank']:.2f}
|
||||
- LLM评估:{result.scores['llm']:.2f}
|
||||
|
||||
{'✅ 检索结果可信,可直接使用' if result.is_useful else '⚠️ 检索结果置信度较低,可能需要联网搜索补充'}"""
|
||||
|
||||
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
info(f"[Tool] rag_search 失败: {e}")
|
||||
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
|
||||
|
||||
|
||||
# ========== 联网搜索 ==========
|
||||
|
||||
@tool
|
||||
def web_search(query: str) -> str:
|
||||
"""联网搜索获取最新信息"""
|
||||
info(f"[Tool] web_search: {query[:30]}...")
|
||||
try:
|
||||
from backend.app.core.web_search import web_search as search_fn
|
||||
return search_fn(query, max_results=5)
|
||||
except Exception as e:
|
||||
info(f"[Tool] web_search 失败: {e}")
|
||||
return f"联网搜索失败: {str(e)}"
|
||||
|
||||
|
||||
# ========== 子图工具 ==========
|
||||
|
||||
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
|
||||
"""通用子图调用"""
|
||||
try:
|
||||
graph = builder_fn().compile()
|
||||
state = state_cls(user_query=query)
|
||||
result = await graph.ainvoke(state)
|
||||
return result.get("final_result", "执行完成")
|
||||
except Exception as e:
|
||||
info(f"[Tool] 子图调用失败: {e}")
|
||||
return f"执行失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def contact_lookup(query: str) -> str:
|
||||
"""查询通讯录"""
|
||||
from backend.app.subgraphs.contact.graph import build_contact_subgraph
|
||||
from backend.app.subgraphs.contact.state import ContactState
|
||||
return await _call_subgraph(build_contact_subgraph, ContactState, query)
|
||||
|
||||
|
||||
@tool
|
||||
async def dictionary_lookup(word: str) -> str:
|
||||
"""查询词典/翻译"""
|
||||
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
|
||||
from backend.app.subgraphs.dictionary.state import DictionaryState
|
||||
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
|
||||
|
||||
|
||||
@tool
|
||||
async def news_analysis(topic: str) -> str:
|
||||
"""分析新闻热点"""
|
||||
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
|
||||
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
|
||||
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
|
||||
|
||||
|
||||
# ========== 导出 ==========
|
||||
from .rag import rag_search
|
||||
from .web_search import web_search
|
||||
from .subgraph import contact_lookup, dictionary_lookup, news_analysis
|
||||
|
||||
ALL_TOOLS = [rag_search, web_search, contact_lookup, dictionary_lookup, news_analysis]
|
||||
|
||||
__all__ = [
|
||||
"rag_search",
|
||||
"web_search",
|
||||
"contact_lookup",
|
||||
"dictionary_lookup",
|
||||
"news_analysis",
|
||||
"ALL_TOOLS",
|
||||
]
|
||||
|
||||
8
backend/app/tools/base.py
Normal file
8
backend/app/tools/base.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
工具模块配置
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from backend.app.logger import info
|
||||
|
||||
__all__ = ["tool", "info"]
|
||||
70
backend/app/tools/rag.py
Normal file
70
backend/app/tools/rag.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
RAG 检索工具
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
_rag_pipeline = None
|
||||
|
||||
|
||||
def _get_rag_pipeline():
|
||||
"""获取或创建 RAG pipeline 单例"""
|
||||
global _rag_pipeline
|
||||
if _rag_pipeline is None:
|
||||
from backend.app.rag.pipeline import RAGPipeline
|
||||
_rag_pipeline = RAGPipeline(
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
use_rerank=True,
|
||||
return_parent_docs=True,
|
||||
)
|
||||
return _rag_pipeline
|
||||
|
||||
|
||||
def _format_confidence(result) -> str:
|
||||
"""格式化置信度描述"""
|
||||
if result.confidence < 0.4:
|
||||
return "低"
|
||||
elif result.confidence < 0.6:
|
||||
return "中"
|
||||
return "高"
|
||||
|
||||
|
||||
@tool
|
||||
async def rag_search(query: str) -> str:
|
||||
"""
|
||||
检索知识库获取相关信息
|
||||
|
||||
Returns:
|
||||
包含检索结果和置信度的结构化回复
|
||||
"""
|
||||
info(f"[Tool] rag_search: {query[:30]}...")
|
||||
try:
|
||||
pipeline = _get_rag_pipeline()
|
||||
result = await pipeline.aretrieve_with_confidence(query, original_query=query)
|
||||
|
||||
if not result.content:
|
||||
return "【RAG检索结果】\n未在知识库中找到相关内容。\n置信度:0.0\n建议:可尝试联网搜索获取信息。"
|
||||
|
||||
confidence_desc = _format_confidence(result)
|
||||
is_useful_note = "✅ 检索结果可信,可直接使用" if result.is_useful else "⚠️ 检索结果置信度较低,可能需要联网搜索补充"
|
||||
|
||||
response = f"""【RAG检索结果】
|
||||
{result.content}
|
||||
|
||||
【置信度评估】
|
||||
- 综合置信度:{result.confidence:.2f}({confidence_desc})
|
||||
- 向量相似度:{result.scores['embedding']:.2f}
|
||||
- 重排分数:{result.scores['rerank']:.2f}
|
||||
- LLM评估:{result.scores['llm']:.2f}
|
||||
|
||||
{is_useful_note}"""
|
||||
|
||||
info(f"[Tool] rag_search 完成: confidence={result.confidence:.3f}, is_useful={result.is_useful}")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
info(f"[Tool] rag_search 失败: {e}")
|
||||
return f"【RAG检索失败】\n错误:{str(e)}\n建议:请稍后重试或使用联网搜索"
|
||||
42
backend/app/tools/subgraph.py
Normal file
42
backend/app/tools/subgraph.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
子图工具 - 通讯录、词典、新闻分析等
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
async def _call_subgraph(builder_fn, state_cls, query: str) -> str:
|
||||
"""通用子图调用"""
|
||||
try:
|
||||
graph = builder_fn().compile()
|
||||
state = state_cls(user_query=query)
|
||||
result = await graph.ainvoke(state)
|
||||
return result.get("final_result", "执行完成")
|
||||
except Exception as e:
|
||||
info(f"[Tool] 子图调用失败: {e}")
|
||||
return f"执行失败: {str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def contact_lookup(query: str) -> str:
|
||||
"""查询通讯录"""
|
||||
from backend.app.subgraphs.contact.graph import build_contact_subgraph
|
||||
from backend.app.subgraphs.contact.state import ContactState
|
||||
return await _call_subgraph(build_contact_subgraph, ContactState, query)
|
||||
|
||||
|
||||
@tool
|
||||
async def dictionary_lookup(word: str) -> str:
|
||||
"""查询词典/翻译"""
|
||||
from backend.app.subgraphs.dictionary.graph import build_dictionary_subgraph
|
||||
from backend.app.subgraphs.dictionary.state import DictionaryState
|
||||
return await _call_subgraph(build_dictionary_subgraph, DictionaryState, word)
|
||||
|
||||
|
||||
@tool
|
||||
async def news_analysis(topic: str) -> str:
|
||||
"""分析新闻热点"""
|
||||
from backend.app.subgraphs.news_analysis.graph import build_news_analysis_subgraph
|
||||
from backend.app.subgraphs.news_analysis.state import NewsAnalysisState
|
||||
return await _call_subgraph(build_news_analysis_subgraph, NewsAnalysisState, topic)
|
||||
18
backend/app/tools/web_search.py
Normal file
18
backend/app/tools/web_search.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
联网搜索工具
|
||||
"""
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
@tool
|
||||
def web_search(query: str) -> str:
|
||||
"""联网搜索获取最新信息"""
|
||||
info(f"[Tool] web_search: {query[:30]}...")
|
||||
try:
|
||||
from backend.app.core.web_search import web_search as search_fn
|
||||
return search_fn(query, max_results=5)
|
||||
except Exception as e:
|
||||
info(f"[Tool] web_search 失败: {e}")
|
||||
return f"联网搜索失败: {str(e)}"
|
||||
Reference in New Issue
Block a user