This commit is contained in:
142
app/agent.py
142
app/agent.py
@@ -4,6 +4,7 @@ AI Agent 服务类 - 支持多模型动态切换
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
@@ -41,8 +42,9 @@ class AIAgentService:
|
||||
api_key=api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
timeout=120.0, # 增加请求超时时间(秒),原为60秒
|
||||
max_retries=3, # 增加重试次数,原为2次
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
def _create_deepseek_llm(self):
|
||||
@@ -58,6 +60,7 @@ class AIAgentService:
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
def _create_local_llm(self):
|
||||
@@ -65,7 +68,7 @@ class AIAgentService:
|
||||
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
"http://localhost:8081/v1"
|
||||
"http://127.0.0.1:8081/v1"
|
||||
)
|
||||
|
||||
return ChatOpenAI(
|
||||
@@ -74,14 +77,15 @@ class AIAgentService:
|
||||
model="gemma-4-E2B-it",
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
model_configs = {
|
||||
"zhipu": self._create_zhipu_llm,
|
||||
"deepseek": self._create_deepseek_llm,
|
||||
"local": self._create_local_llm,
|
||||
"local": self._create_local_llm, # 本地模型作为第一个
|
||||
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
|
||||
"zhipu": self._create_zhipu_llm, # GLM-4.7 作为最后一个
|
||||
}
|
||||
|
||||
for model_name, llm_creator in model_configs.items():
|
||||
@@ -107,7 +111,7 @@ class AIAgentService:
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
|
||||
"""
|
||||
处理用户消息,返回包含回复、token统计和耗时的字典
|
||||
|
||||
@@ -156,6 +160,28 @@ class AIAgentService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
def _serialize_value(self, value):
    """Recursively convert LangChain objects into JSON-serializable data.

    Args:
        value: A value taken from a streamed graph chunk: a message-like
            object (anything exposing a ``content`` attribute), a dict,
            a list/tuple, or a leaf value.

    Returns:
        The same structure rebuilt from dicts, lists and leaf values.
        Message-like objects become a dict with the keys ``role``,
        ``content``, ``additional_kwargs`` and ``tool_calls``. Leaves that
        ``json.dumps`` rejects are replaced by ``str(leaf)``.
    """

    def convert(item):
        if hasattr(item, 'content'):
            # Message-like object. Its fields are converted as well:
            # additional_kwargs / tool_calls (and list-style content) may
            # hold nested objects that json.dumps cannot encode, and
            # returning them raw would defeat the purpose of this method.
            return {
                "role": getattr(item, 'type', 'message'),
                "content": convert(getattr(item, 'content', '')),
                "additional_kwargs": convert(getattr(item, 'additional_kwargs', {})),
                "tool_calls": convert(getattr(item, 'tool_calls', [])),
            }
        if isinstance(item, dict):
            return {key: convert(val) for key, val in item.items()}
        if isinstance(item, (list, tuple)):
            # Tuples are flattened to lists, which is what JSON would do anyway.
            return [convert(element) for element in item]
        try:
            # Probe the leaf: keep it as-is when JSON can encode it.
            json.dumps(item)
        except (TypeError, ValueError):
            # Fall back to the textual form for anything JSON rejects.
            return str(item)
        return item

    return convert(value)
|
||||
|
||||
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
|
||||
"""
|
||||
流式处理消息,返回异步生成器
|
||||
@@ -170,10 +196,9 @@ class AIAgentService:
|
||||
字典,包含事件类型和数据
|
||||
"""
|
||||
graph = self.graphs.get(model_name)
|
||||
|
||||
if not graph:
|
||||
warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型")
|
||||
model_name = next(iter(self.graphs.keys()))
|
||||
graph = self.graphs[model_name]
|
||||
raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
|
||||
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
@@ -182,36 +207,71 @@ class AIAgentService:
|
||||
input_state = {"messages": [{"role": "user", "content": message}]}
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
# 使用 astream_events 获取流式事件
|
||||
async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
|
||||
kind = event["event"]
|
||||
|
||||
# 聊天模型流式输出
|
||||
if kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
if content:
|
||||
yield {"type": "token", "content": content}
|
||||
|
||||
# 工具调用开始
|
||||
elif kind == "on_tool_start":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_start", "tool": tool_name}
|
||||
|
||||
# 工具调用结束
|
||||
elif kind == "on_tool_end":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_end", "tool": tool_name}
|
||||
|
||||
# 链结束,获取最终结果
|
||||
elif kind == "on_chain_end" and event["name"] == "LangGraph":
|
||||
output = event["data"]["output"]
|
||||
reply = output["messages"][-1].content if output.get("messages") else ""
|
||||
token_usage = output.get("last_token_usage", {})
|
||||
elapsed_time = output.get("last_elapsed_time", 0.0)
|
||||
async for chunk in graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
context=context,
|
||||
stream_mode=["messages", "updates", "custom"], # 组合多种模式,添加 custom
|
||||
version="v2", # 使用统一的v2格式
|
||||
subgraphs=True # 如果你使用了子图,请开启此项
|
||||
):
|
||||
chunk_type = chunk["type"]
|
||||
processed_event = {}
|
||||
|
||||
# 1. 处理 LLM Token 流 (实现打字机效果)
|
||||
if chunk_type == "messages":
|
||||
message_chunk, metadata = chunk["data"]
|
||||
|
||||
yield {
|
||||
"type": "done",
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
# 提取元数据
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
# 使用 getattr 安全地获取内容,因为 message_chunk 可能不是字符串
|
||||
token_content = getattr(message_chunk, 'content', str(message_chunk))
|
||||
|
||||
# 提取 DeepSeek reasoner 的思考过程 token
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
# [DEBUG] 临时添加:只在 reasoning_token 不为空时打印,方便你直观地看到它
|
||||
if reasoning_token:
|
||||
import logging
|
||||
logging.debug(f"💡 [Reasoning Token 捕获]: {repr(reasoning_token)}")
|
||||
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
"node": node_name,
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token,
|
||||
"metadata": metadata # 可选的元数据
|
||||
}
|
||||
|
||||
# 2. 处理状态更新 (节点执行完成)
|
||||
elif chunk_type == "updates":
|
||||
updates_data = chunk["data"]
|
||||
# 序列化 updates 中的所有数据
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
processed_event = {
|
||||
"type": "state_update",
|
||||
"data": serialized_data
|
||||
}
|
||||
# 为了兼容前端旧字段,也保留 messages 字段(可选)
|
||||
if "messages" in serialized_data:
|
||||
processed_event["messages"] = serialized_data["messages"]
|
||||
|
||||
# 3. 处理自定义数据 (如果需要)
|
||||
elif chunk_type == "custom":
|
||||
# 自定义事件同样需要序列化
|
||||
serialized_data = self._serialize_value(chunk["data"])
|
||||
processed_event = {
|
||||
"type": "custom",
|
||||
"data": serialized_data
|
||||
}
|
||||
|
||||
# 4. 其他类型(debug, tasks等)按需处理
|
||||
else:
|
||||
# 对于不需要的类型,直接跳过
|
||||
continue
|
||||
|
||||
# 确保事件有数据再发送
|
||||
if processed_event:
|
||||
yield processed_event
|
||||
@@ -25,7 +25,7 @@ load_dotenv()
|
||||
# 优先级:环境变量 DB_URI > Docker 内部服务名 > 本地开发地址
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:mysecretpassword@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,8 +16,9 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
|
||||
|
||||
# ========== Mem0 记忆层配置 ==========
|
||||
# Qdrant 向量数据库地址
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
|
||||
|
||||
# vLLM Embedding 服务地址 (用于 Mem0 的向量化)
|
||||
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL", "http://localhost:8082/v1")
|
||||
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
|
||||
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
|
||||
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
|
||||
@@ -16,6 +16,7 @@ from app.nodes import (
|
||||
should_continue
|
||||
)
|
||||
from app.memory import Mem0Client
|
||||
from app.nodes.finalize import finalize_node
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
@@ -57,6 +58,7 @@ class GraphBuilder:
|
||||
builder.add_node("llm_call", llm_call_node)
|
||||
builder.add_node("tool_node", tool_call_node)
|
||||
builder.add_node("summarize", summarize_node)
|
||||
builder.add_node("finalize", finalize_node)
|
||||
|
||||
# 添加边
|
||||
builder.add_edge(START, "retrieve_memory")
|
||||
@@ -67,10 +69,11 @@ class GraphBuilder:
|
||||
{
|
||||
"tool_node": "tool_node",
|
||||
"summarize": "summarize",
|
||||
'END': END
|
||||
"finalize": "finalize"
|
||||
}
|
||||
)
|
||||
builder.add_edge("tool_node", "llm_call")
|
||||
builder.add_edge("summarize", END)
|
||||
builder.add_edge("summarize", "finalize")
|
||||
builder.add_edge("finalize", END)
|
||||
|
||||
return builder
|
||||
return builder
|
||||
@@ -28,11 +28,14 @@ class ThreadHistoryService:
|
||||
try:
|
||||
# 查询 checkpoints 表获取用户的线程列表
|
||||
async with self.checkpointer.conn.cursor() as cur:
|
||||
# 查询每个线程的最新 checkpoint 和创建时间
|
||||
# 在较新的 LangGraph 版本中,AsyncPostgresSaver 创建的 checkpoints 表
|
||||
# 没有 created_at 列,而是使用 checkpoint_id 作为时间排序依据。
|
||||
# 我们可以直接按 thread_id 去重,并用 checkpoint_id 排序。
|
||||
# 另外,用户的 metadata 存储在 metadata JSONB 列中。
|
||||
query = """
|
||||
SELECT
|
||||
thread_id,
|
||||
MAX(created_at) as last_updated
|
||||
MAX(checkpoint_id) as last_updated
|
||||
FROM checkpoints
|
||||
WHERE metadata->>'user_id' = %s
|
||||
GROUP BY thread_id
|
||||
@@ -49,17 +52,20 @@ class ThreadHistoryService:
|
||||
# 获取该线程的状态
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state and state.values:
|
||||
messages = state.values.get("messages", [])
|
||||
summary = self._extract_summary(messages)
|
||||
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
|
||||
if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
|
||||
messages = state.checkpoint.get("channel_values", {}).get("messages", [])
|
||||
|
||||
threads.append({
|
||||
"thread_id": thread_id,
|
||||
"last_updated": row['last_updated'].isoformat() if row['last_updated'] else "",
|
||||
"summary": summary,
|
||||
"message_count": message_count
|
||||
})
|
||||
if messages:
|
||||
summary = self._extract_summary(messages)
|
||||
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
|
||||
|
||||
threads.append({
|
||||
"thread_id": thread_id,
|
||||
# checkpoint_id 是一个类似于 uuid 的字符串,其中可能包含时间戳信息,也可以直接作为唯一标识
|
||||
"last_updated": row['last_updated'] if row['last_updated'] else "",
|
||||
"summary": summary,
|
||||
"message_count": message_count
|
||||
})
|
||||
|
||||
return threads
|
||||
|
||||
@@ -80,10 +86,13 @@ class ThreadHistoryService:
|
||||
try:
|
||||
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
|
||||
|
||||
if state is None or not state.values:
|
||||
if state is None:
|
||||
return []
|
||||
|
||||
messages = state.values.get("messages", [])
|
||||
messages = state.checkpoint.get("channel_values", {}).get("messages", []) if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict) else []
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# 转换 LangChain 消息对象为字典
|
||||
result = []
|
||||
|
||||
@@ -3,142 +3,151 @@ Mem0 记忆层客户端封装模块
|
||||
负责 Mem0 的初始化、检索和存储
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Optional, List, Dict, Any
|
||||
from mem0 import AsyncMemory
|
||||
|
||||
# 本地模块
|
||||
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, VLLM_EMBEDDING_URL
|
||||
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||
from app.logger import info, warning, error
|
||||
|
||||
|
||||
class Mem0Client:
|
||||
"""Mem0 异步客户端封装类"""
|
||||
|
||||
|
||||
def __init__(self, llm_instance):
|
||||
"""
|
||||
初始化 Mem0 客户端
|
||||
|
||||
|
||||
Args:
|
||||
llm_instance: LangChain LLM 实例(用于事实提取)
|
||||
"""
|
||||
self.llm = llm_instance
|
||||
self.mem0: Optional[AsyncMemory] = None
|
||||
self._initialized = False
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
"""异步初始化 Mem0 客户端"""
|
||||
"""异步初始化 Mem0 客户端,并进行实际连接测试"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
# 检查 Qdrant 是否可达 (可选)
|
||||
import requests
|
||||
try:
|
||||
resp = requests.get(f"{QDRANT_URL}/collections", timeout=2)
|
||||
if resp.status_code == 200:
|
||||
info(f"✅ Qdrant 服务正常: {QDRANT_URL}")
|
||||
except Exception:
|
||||
warning(f"⚠️ 无法连接到 Qdrant: {QDRANT_URL},Mem0 将尝试自动连接")
|
||||
|
||||
try:
|
||||
# Mem0 配置
|
||||
config = {
|
||||
# 向量存储:复用 Qdrant 实例
|
||||
"vector_store": {
|
||||
"provider": "qdrant",
|
||||
"config": {
|
||||
"url": QDRANT_URL, # 直接使用完整 URL
|
||||
"collection_name": QDRANT_COLLECTION_NAME,
|
||||
"host": QDRANT_URL.split("://")[1].split(":")[0] if "://" in QDRANT_URL else "localhost",
|
||||
"port": int(QDRANT_URL.split(":")[-1]) if ":" in QDRANT_URL.split("://")[-1] else 6333,
|
||||
"embedding_model_dims": 768, # embeddinggemma-300m 输出 768 维
|
||||
"embedding_model_dims": 768,
|
||||
}
|
||||
},
|
||||
# 事实提取 LLM:直接复用传入的 LangChain 实例
|
||||
"llm": {
|
||||
"provider": "langchain",
|
||||
"config": {
|
||||
"model": self.llm # 直接传入 LangChain 模型实例
|
||||
"model": self.llm
|
||||
}
|
||||
},
|
||||
# Embedding:指向 vLLM 服务
|
||||
"embedder": {
|
||||
"provider": "openai",
|
||||
"embedding_dims": 768, # 关键:将维度参数提升到顶层
|
||||
"config": {
|
||||
"model": "google/embeddinggemma-300m",
|
||||
"api_key": "EMPTY",
|
||||
"api_base": VLLM_EMBEDDING_URL,
|
||||
# 注意:不要在此处传递 dimensions 参数,避免与 vLLM v0.7.2 不兼容
|
||||
}
|
||||
"model": "embeddinggemma-300M-Q8_0",
|
||||
"api_key": LLAMACPP_API_KEY,
|
||||
"openai_base_url": LLAMACPP_EMBEDDING_URL,
|
||||
},
|
||||
},
|
||||
"version": "v1.1"
|
||||
}
|
||||
|
||||
self.mem0 = AsyncMemory.from_config(config)
|
||||
self._initialized = True
|
||||
info(f"✅ Mem0 初始化成功 (Embedding: vLLM@8002, Vector: Qdrant, LLM: 复用现有实例)")
|
||||
info("✅ Mem0 配置加载成功,开始连接测试...")
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ Mem0 初始化失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# 实际连接测试:调用一次 search 确保 Qdrant 和 Embedding 都可达
|
||||
await asyncio.wait_for(
|
||||
self.mem0.search("ping", user_id="test", limit=1),
|
||||
timeout=60.0
|
||||
)
|
||||
info("✅ Mem0 实际连接测试成功,初始化完成")
|
||||
self._initialized = True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
error("❌ Mem0 连接测试超时 (10s),请检查 Qdrant 或 Embedding 服务响应")
|
||||
self.mem0 = None
|
||||
|
||||
self._initialized = False
|
||||
except Exception as e:
|
||||
error(f"❌ Mem0 初始化或连接测试失败: {e}")
|
||||
import traceback
|
||||
error(f"详细错误信息:\n{traceback.format_exc()}")
|
||||
self.mem0 = None
|
||||
self._initialized = False
|
||||
|
||||
async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
|
||||
"""
|
||||
检索相关记忆
|
||||
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
user_id: 用户 ID
|
||||
limit: 返回结果数量限制
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 记忆事实列表
|
||||
"""
|
||||
if not self.mem0:
|
||||
warning("⚠️ Mem0 未初始化,跳过记忆检索")
|
||||
return []
|
||||
|
||||
|
||||
try:
|
||||
memories = await self.mem0.search(query, user_id=user_id, limit=limit)
|
||||
|
||||
memories = await asyncio.wait_for(
|
||||
self.mem0.search(query, user_id=user_id, limit=limit),
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if memories and "results" in memories:
|
||||
facts = [m["memory"] for m in memories["results"] if m.get("memory")]
|
||||
if facts:
|
||||
info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆")
|
||||
return facts
|
||||
|
||||
|
||||
info("🔍 [记忆检索] 未找到相关记忆")
|
||||
return []
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
warning("⚠️ Mem0 检索超时 (30s),跳过本次记忆检索")
|
||||
return []
|
||||
except Exception as e:
|
||||
warning(f"⚠️ Mem0 检索失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool:
|
||||
"""
|
||||
添加记忆(自动提取事实并存储)
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}]
|
||||
user_id: 用户 ID
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
if not self.mem0:
|
||||
warning("⚠️ Mem0 未初始化,跳过记忆添加")
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
result = await self.mem0.add(
|
||||
messages,
|
||||
user_id=user_id,
|
||||
metadata={"type": "conversation"}
|
||||
await asyncio.wait_for(
|
||||
self.mem0.add(
|
||||
messages,
|
||||
user_id=user_id,
|
||||
metadata={"type": "conversation"}
|
||||
),
|
||||
timeout=60.0
|
||||
)
|
||||
info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
|
||||
info("📝 [记忆添加] 已提交给 Mem0 进行事实提取")
|
||||
return True
|
||||
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
error("❌ Mem0 记忆添加超时 (60s)")
|
||||
return False
|
||||
except Exception as e:
|
||||
error(f"❌ Mem0 记忆添加失败: {e}")
|
||||
return False
|
||||
return False
|
||||
@@ -7,6 +7,7 @@ from app.nodes.llm_call import create_llm_call_node
|
||||
from app.nodes.tool_call import create_tool_call_node
|
||||
from app.nodes.retrieve_memory import create_retrieve_memory_node
|
||||
from app.nodes.summarize import create_summarize_node
|
||||
from app.nodes.finalize import finalize_node
|
||||
|
||||
__all__ = [
|
||||
"should_continue",
|
||||
@@ -14,4 +15,5 @@ __all__ = [
|
||||
"create_tool_call_node",
|
||||
"create_retrieve_memory_node",
|
||||
"create_summarize_node",
|
||||
"finalize_node",
|
||||
]
|
||||
|
||||
47
app/nodes/finalize.py
Normal file
47
app/nodes/finalize.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
完成事件节点模块
|
||||
负责发送完成事件,包含token使用情况和耗时信息
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
from app.utils.logging import log_state_change
|
||||
from app.logger import info, error
|
||||
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """Emit the final "done" event on the custom stream.

    The event carries the ``last_token_usage`` and ``last_elapsed_time``
    values read from ``state`` (defaulting to ``{}`` and ``0.0``).

    Args:
        state: Current conversation state.
        config: Runtime configuration (not read by this node).

    Returns:
        An empty dict; this node makes no state update.
    """
    log_state_change("finalize", state, "进入")

    try:
        # Obtain the stream writer first, then build and send the payload.
        emit = get_stream_writer()
        done_payload = {
            "type": "done",
            "token_usage": state.get("last_token_usage", {}),
            "elapsed_time": state.get("last_elapsed_time", 0.0),
        }
        emit({"type": "custom", "data": done_payload})
        info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息")
    except Exception as exc:
        # Best effort: a failure to emit is logged and must not fail the graph run.
        error(f"❌ [完成事件] 发送完成事件时发生异常: {exc}")

    log_state_change("finalize", state, "离开")
    return {}
|
||||
@@ -32,15 +32,19 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
# 构建调用链
|
||||
prompt = create_system_prompt()
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
chain = prompt | RunnableLambda(print_llm_input) | llm_with_tools
|
||||
|
||||
async def call_llm(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历
|
||||
chain = prompt | llm_with_tools
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息
|
||||
|
||||
Returns:
|
||||
更新后的状态字典
|
||||
@@ -48,17 +52,28 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
log_state_change("llm_call", state, "进入")
|
||||
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
loop = asyncio.get_event_loop()
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: chain.invoke({
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chunks = []
|
||||
async for chunk in chain.astream(
|
||||
{
|
||||
"messages": state["messages"],
|
||||
"memory_context": memory_context
|
||||
})
|
||||
)
|
||||
},
|
||||
config=config
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
for chunk in chunks[1:]:
|
||||
response = response + chunk
|
||||
else:
|
||||
response = AIMessage(content="")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
@@ -85,13 +100,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
@@ -99,6 +108,12 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
result = {
|
||||
|
||||
@@ -24,20 +24,23 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def retrieve_memory(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆检索节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
包含 memory_context 的状态更新
|
||||
"""
|
||||
log_state_change("retrieve_memory", state, "进入")
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
# 从 metadata 中获取 user_id
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
|
||||
# 兼容 dict 和对象两种消息格式
|
||||
last_msg = state["messages"][-1]
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.state import MessagesState
|
||||
from app.logger import info
|
||||
|
||||
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'END']:
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
|
||||
"""
|
||||
决定下一步:工具调用、生成摘要还是结束
|
||||
|
||||
@@ -20,7 +20,7 @@ def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', '
|
||||
state: 当前对话状态
|
||||
|
||||
Returns:
|
||||
下一个节点名称或 END
|
||||
下一个节点名称
|
||||
"""
|
||||
last_message = state["messages"][-1]
|
||||
|
||||
@@ -40,9 +40,9 @@ def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', '
|
||||
else:
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
|
||||
return 'END'
|
||||
return 'finalize'
|
||||
|
||||
# 3. 其他情况(如只有用户消息)直接结束
|
||||
if ENABLE_GRAPH_TRACE:
|
||||
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
|
||||
return 'END'
|
||||
return 'finalize'
|
||||
|
||||
@@ -24,13 +24,15 @@ def create_summarize_node(mem0_client: Mem0Client):
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def summarize_conversation(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
记忆存储节点 - 使用 Mem0
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
重置计数器的状态更新
|
||||
@@ -42,7 +44,8 @@ def create_summarize_node(mem0_client: Mem0Client):
|
||||
debug("📝 [记忆添加] 对话过短,跳过")
|
||||
return {"turns_since_last_summary": 0}
|
||||
|
||||
user_id = runtime.context.user_id
|
||||
# 从 metadata 中获取 user_id
|
||||
user_id = config.get("metadata", {}).get("user_id", "default_user")
|
||||
|
||||
# 确保 Mem0 已初始化(懒加载)
|
||||
if not mem0_client._initialized:
|
||||
@@ -83,4 +86,4 @@ def create_summarize_node(mem0_client: Mem0Client):
|
||||
log_state_change("summarize", state, "离开")
|
||||
return {"turns_since_last_summary": 0}
|
||||
|
||||
return summarize_conversation
|
||||
return summarize_conversation
|
||||
@@ -7,6 +7,7 @@ import asyncio
|
||||
from typing import Any, Dict
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
# 本地模块
|
||||
from app.state import MessagesState, GraphContext
|
||||
@@ -25,13 +26,15 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
异步节点函数
|
||||
"""
|
||||
|
||||
async def call_tools(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
|
||||
Args:
|
||||
state: 当前对话状态
|
||||
runtime: LangGraph 运行时上下文
|
||||
config: 运行时配置
|
||||
|
||||
Returns:
|
||||
包含 ToolMessage 的状态更新
|
||||
@@ -62,6 +65,10 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
|
||||
continue
|
||||
|
||||
# 获取流式写入器并发送工具调用开始事件
|
||||
writer = get_stream_writer()
|
||||
writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}})
|
||||
|
||||
try:
|
||||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||||
@@ -77,9 +84,15 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
result_preview = str(observation).replace("\n", " ")
|
||||
debug(f" └─ ✅ 结果: {result_preview}")
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
|
||||
# 发送工具调用完成事件
|
||||
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}})
|
||||
except Exception as e:
|
||||
debug(f" └─ ❌ 异常: {e}")
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
|
||||
# 发送工具调用失败事件
|
||||
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}})
|
||||
|
||||
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
|
||||
|
||||
@@ -87,4 +100,4 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
|
||||
log_state_change("tool_node", {**state, **result}, "离开")
|
||||
return result
|
||||
|
||||
return call_tools
|
||||
return call_tools
|
||||
@@ -27,9 +27,10 @@ def create_system_prompt() -> ChatPromptTemplate:
|
||||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
"1. 回答必须简洁、直接。\n"
|
||||
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。例如:<think>这里是我的思考过程...</think>这里是最终回答。\n"
|
||||
"3. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"4. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)
|
||||
|
||||
return ChatPromptTemplate.from_messages([
|
||||
|
||||
@@ -25,10 +25,10 @@ def log_state_change(node_name: str, state: dict, prefix: str = "进入"):
|
||||
if last_msg:
|
||||
# 兼容 dict 和对象两种格式
|
||||
if isinstance(last_msg, dict):
|
||||
content_preview = str(last_msg.get("content", ""))[:100].replace("\n", " ")
|
||||
content_preview = str(last_msg.get("content", ""))[:10].replace("\n", " ")
|
||||
msg_type = last_msg.get("type", "unknown")
|
||||
else:
|
||||
content_preview = str(last_msg.content)[:100].replace("\n", " ")
|
||||
content_preview = str(last_msg.content)[:10].replace("\n", " ")
|
||||
msg_type = getattr(last_msg, 'type', 'unknown')
|
||||
last_info = f"{msg_type.upper()}: {content_preview}"
|
||||
info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")
|
||||
|
||||
Reference in New Issue
Block a user