Add long-term storage and streaming checks

2026-04-17 01:26:05 +08:00
parent 602d551fd1
commit 404efde282
37 changed files with 794 additions and 2095 deletions

View File

@@ -4,6 +4,7 @@ AI Agent service class - supports dynamic switching across multiple models
"""
import os
import json
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
@@ -41,8 +42,9 @@ class AIAgentService:
api_key=api_key,
temperature=0.1,
max_tokens=4096,
timeout=60.0, # request timeout (seconds)
max_retries=2, # automatic retries on failure
timeout=120.0, # increased request timeout (seconds); was 60s
max_retries=3, # increased retry count; was 2
streaming=True, # make sure streaming output is enabled
)
def _create_deepseek_llm(self):
@@ -58,6 +60,7 @@ class AIAgentService:
max_tokens=4096,
timeout=60.0, # request timeout (seconds)
max_retries=2, # automatic retries on failure
streaming=True, # make sure streaming output is enabled
)
def _create_local_llm(self):
@@ -65,7 +68,7 @@ class AIAgentService:
# vLLM service address: env var takes priority; works for Docker, FRP tunneling, and local dev
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
"http://localhost:8081/v1"
"http://127.0.0.1:8081/v1"
)
return ChatOpenAI(
@@ -74,14 +77,15 @@ class AIAgentService:
model="gemma-4-E2B-it",
timeout=60.0, # request timeout (seconds)
max_retries=2, # automatic retries on failure
streaming=True, # make sure streaming output is enabled
)
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
model_configs = {
"zhipu": self._create_zhipu_llm,
"deepseek": self._create_deepseek_llm,
"local": self._create_local_llm,
"local": self._create_local_llm, # 本地模型作为第一个
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
"zhipu": self._create_zhipu_llm, # GLM-4.7 作为最后一个
}
for model_name, llm_creator in model_configs.items():
@@ -107,7 +111,7 @@ class AIAgentService:
return self
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
"""
Process a user message; returns a dict with the reply, token usage, and elapsed time
@@ -156,6 +160,28 @@ class AIAgentService:
"elapsed_time": elapsed_time
}
def _serialize_value(self, value):
"""递归将 LangChain 对象转换为可 JSON 序列化的格式"""
if hasattr(value, 'content'):
# LangChain 消息对象
msg_type = getattr(value, 'type', 'message')
return {
"role": msg_type,
"content": getattr(value, 'content', ''),
"additional_kwargs": getattr(value, 'additional_kwargs', {}),
"tool_calls": getattr(value, 'tool_calls', [])
}
elif isinstance(value, dict):
return {k: self._serialize_value(v) for k, v in value.items()}
elif isinstance(value, (list, tuple)):
return [self._serialize_value(item) for item in value]
else:
try:
json.dumps(value)
return value
except (TypeError, ValueError):
return str(value)
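As a quick illustration (not part of this commit) of what the helper produces — `service` stands for an AIAgentService instance:

import json
from langchain_core.messages import AIMessage

update = {"llm_call": {"messages": [AIMessage(content="hi")]}}
serialized = service._serialize_value(update)
# -> {"llm_call": {"messages": [{"role": "ai", "content": "hi",
#                                "additional_kwargs": {}, "tool_calls": []}]}}
json.dumps(serialized)  # safe now; raw AIMessage objects are not JSON-serializable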
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
"""
Stream-process a message; returns an async generator
@@ -170,10 +196,9 @@ class AIAgentService:
Dicts containing the event type and data
"""
graph = self.graphs.get(model_name)
if not graph:
warning(f"Warning: model '{model_name}' unavailable; falling back to the default model")
model_name = next(iter(self.graphs.keys()))
graph = self.graphs[model_name]
raise ValueError(f"Model '{model_name}' not found or not initialized")
config = {
"configurable": {"thread_id": thread_id},
@@ -182,36 +207,71 @@ class AIAgentService:
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
# Use astream_events to consume streaming events
async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
kind = event["event"]
# Chat-model token stream
if kind == "on_chat_model_stream":
content = event["data"]["chunk"].content
if content:
yield {"type": "token", "content": content}
# Tool call started
elif kind == "on_tool_start":
tool_name = event["name"]
yield {"type": "tool_start", "tool": tool_name}
# Tool call finished
elif kind == "on_tool_end":
tool_name = event["name"]
yield {"type": "tool_end", "tool": tool_name}
# Chain finished: collect the final result
elif kind == "on_chain_end" and event["name"] == "LangGraph":
output = event["data"]["output"]
reply = output["messages"][-1].content if output.get("messages") else ""
token_usage = output.get("last_token_usage", {})
elapsed_time = output.get("last_elapsed_time", 0.0)
yield {
"type": "done",
"reply": reply,
"token_usage": token_usage,
"elapsed_time": elapsed_time
}
async for stream_mode, data in graph.astream(
input_state,
config=config,
context=context,
stream_mode=["messages", "updates", "custom"], # combined modes; "custom" carries stream-writer events
):
# With a list stream_mode, astream yields (mode, data) tuples.
# Note: version="v2" belongs to astream_events, and subgraphs=True would change
# the shape to (namespace, mode, data) triples, so neither is passed here.
chunk_type = stream_mode
processed_event = {}
# 1. LLM token stream (typewriter effect)
if chunk_type == "messages":
message_chunk, metadata = data
# Extract metadata
node_name = metadata.get("langgraph_node", "unknown")
# Read content via getattr, since message_chunk may not be a plain string
token_content = getattr(message_chunk, 'content', str(message_chunk))
# Extract DeepSeek reasoner "thinking" tokens
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
# [DEBUG] Temporary: log only when reasoning_token is non-empty, so it is easy to spot
if reasoning_token:
import logging
logging.debug(f"💡 [Reasoning token captured]: {repr(reasoning_token)}")
processed_event = {
"type": "llm_token",
"node": node_name,
"token": token_content,
"reasoning_token": reasoning_token,
"metadata": metadata # optional metadata
}
# 2. State updates (a node finished executing)
elif chunk_type == "updates":
updates_data = data
# Serialize everything inside the update
serialized_data = self._serialize_value(updates_data)
processed_event = {
"type": "state_update",
"data": serialized_data
}
# Keep the legacy "messages" field for older frontends (optional)
if "messages" in serialized_data:
processed_event["messages"] = serialized_data["messages"]
# 3. Custom events (emitted via get_stream_writer)
elif chunk_type == "custom":
# Custom events need serialization too
serialized_data = self._serialize_value(data)
processed_event = {
"type": "custom",
"data": serialized_data
}
# 4. Other modes (debug, tasks, ...) are skipped
else:
continue
# Only emit events that actually carry data
if processed_event:
yield processed_event
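For context, a minimal sketch (not part of this commit) of how a FastAPI route could relay these events as Server-Sent Events; the route path and the agent_service wiring are assumptions:

import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
agent_service = None  # assumed: an initialized AIAgentService, set during app startup

@app.get("/chat/stream")
async def chat_stream(message: str, thread_id: str, model: str = "local"):
    async def event_source():
        async for event in agent_service.process_message_stream(message, thread_id, model):
            # every event is JSON-safe thanks to _serialize_value
            yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
    return StreamingResponse(event_source(), media_type="text/event-stream")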

View File

@@ -25,7 +25,7 @@ load_dotenv()
# Priority: DB_URI env var > Docker internal service name > local dev address
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:mysecretpassword@ai-postgres:5432/langgraph_db?sslmode=disable"
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)

View File

@@ -16,8 +16,9 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# ========== Mem0 memory-layer configuration ==========
# Qdrant vector database address
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
# vLLM embedding service address (used for Mem0 vectorization)
VLLM_EMBEDDING_URL = os.getenv("VLLM_EMBEDDING_URL", "http://localhost:8082/v1")
# llama.cpp embedding service address (used for Mem0 vectorization)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")

View File

@@ -16,6 +16,7 @@ from app.nodes import (
should_continue
)
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
@@ -57,6 +58,7 @@ class GraphBuilder:
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
builder.add_node("finalize", finalize_node)
# Add edges
builder.add_edge(START, "retrieve_memory")
@@ -67,10 +69,11 @@ class GraphBuilder:
{
"tool_node": "tool_node",
"summarize": "summarize",
'END': END
"finalize": "finalize"
}
)
builder.add_edge("tool_node", "llm_call")
builder.add_edge("summarize", END)
builder.add_edge("summarize", "finalize")
builder.add_edge("finalize", END)
return builder
return builder

View File

@@ -28,11 +28,14 @@ class ThreadHistoryService:
try:
# Query the checkpoints table for the user's thread list
async with self.checkpointer.conn.cursor() as cur:
# Query each thread's latest checkpoint and creation time
# In newer LangGraph versions, the checkpoints table created by AsyncPostgresSaver
# has no created_at column; checkpoint_id serves as the time-ordering key instead.
# We can simply deduplicate by thread_id and sort by checkpoint_id.
# Also, the user's metadata lives in the metadata JSONB column.
query = """
SELECT
thread_id,
MAX(created_at) as last_updated
MAX(checkpoint_id) as last_updated
FROM checkpoints
WHERE metadata->>'user_id' = %s
GROUP BY thread_id
@@ -49,17 +52,20 @@ class ThreadHistoryService:
# 获取该线程的状态
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
if state and state.values:
messages = state.values.get("messages", [])
summary = self._extract_summary(messages)
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
messages = state.checkpoint.get("channel_values", {}).get("messages", [])
threads.append({
"thread_id": thread_id,
"last_updated": row['last_updated'].isoformat() if row['last_updated'] else "",
"summary": summary,
"message_count": message_count
})
if messages:
summary = self._extract_summary(messages)
message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
threads.append({
"thread_id": thread_id,
# checkpoint_id is a uuid-like string that may embed a timestamp; it also works as a unique identifier
"last_updated": row['last_updated'] if row['last_updated'] else "",
"summary": summary,
"message_count": message_count
})
return threads
@@ -80,10 +86,13 @@ class ThreadHistoryService:
try:
state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
if state is None or not state.values:
if state is None:
return []
messages = state.values.get("messages", [])
messages = state.checkpoint.get("channel_values", {}).get("messages", []) if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict) else []
if not messages:
return []
# 转换 LangChain 消息对象为字典
result = []

View File

@@ -3,142 +3,151 @@ Mem0 memory-layer client wrapper module
Handles Mem0 initialization, retrieval, and storage
"""
import os
import asyncio
from typing import Optional, List, Dict, Any
from mem0 import AsyncMemory
# Local modules
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, VLLM_EMBEDDING_URL
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.logger import info, warning, error
class Mem0Client:
"""Mem0 异步客户端封装类"""
def __init__(self, llm_instance):
"""
Initialize the Mem0 client
Args:
llm_instance: LangChain LLM instance (used for fact extraction)
"""
self.llm = llm_instance
self.mem0: Optional[AsyncMemory] = None
self._initialized = False
async def initialize(self):
"""异步初始化 Mem0 客户端"""
"""异步初始化 Mem0 客户端,并进行实际连接测试"""
if self._initialized:
return
try:
# Check that Qdrant is reachable (optional)
import requests
try:
resp = requests.get(f"{QDRANT_URL}/collections", timeout=2)
if resp.status_code == 200:
info(f"✅ Qdrant service OK: {QDRANT_URL}")
except Exception:
warning(f"⚠️ Cannot reach Qdrant: {QDRANT_URL}; Mem0 will try to connect on its own")
try:
# Mem0 configuration
config = {
# Vector store: reuse the Qdrant instance
"vector_store": {
"provider": "qdrant",
"config": {
"url": QDRANT_URL, # 直接使用完整 URL
"collection_name": QDRANT_COLLECTION_NAME,
"host": QDRANT_URL.split("://")[1].split(":")[0] if "://" in QDRANT_URL else "localhost",
"port": int(QDRANT_URL.split(":")[-1]) if ":" in QDRANT_URL.split("://")[-1] else 6333,
"embedding_model_dims": 768, # embeddinggemma-300m 输出 768 维
"embedding_model_dims": 768,
}
},
# Fact-extraction LLM: reuse the injected LangChain instance directly
"llm": {
"provider": "langchain",
"config": {
"model": self.llm # 直接传入 LangChain 模型实例
"model": self.llm
}
},
# Embedder: points at the vLLM service
"embedder": {
"provider": "openai",
"embedding_dims": 768, # 关键:将维度参数提升到顶层
"config": {
"model": "google/embeddinggemma-300m",
"api_key": "EMPTY",
"api_base": VLLM_EMBEDDING_URL,
# Note: do not pass a dimensions parameter here, to avoid incompatibility with vLLM v0.7.2
}
"model": "embeddinggemma-300M-Q8_0",
"api_key": LLAMACPP_API_KEY,
"openai_base_url": LLAMACPP_EMBEDDING_URL,
},
},
"version": "v1.1"
}
self.mem0 = AsyncMemory.from_config(config)
self._initialized = True
info(f"✅ Mem0 初始化成功 (Embedding: vLLM@8002, Vector: Qdrant, LLM: 复用现有实例)")
info("✅ Mem0 配置加载成功,开始连接测试...")
except Exception as e:
error(f"❌ Mem0 初始化失败: {e}")
import traceback
traceback.print_exc()
# Real connectivity test: one search call to verify that Qdrant and the embedder are both reachable
await asyncio.wait_for(
self.mem0.search("ping", user_id="test", limit=1),
timeout=60.0
)
info("✅ Mem0 实际连接测试成功,初始化完成")
self._initialized = True
except asyncio.TimeoutError:
error("❌ Mem0 连接测试超时 (10s),请检查 Qdrant 或 Embedding 服务响应")
self.mem0 = None
self._initialized = False
except Exception as e:
error(f"❌ Mem0 初始化或连接测试失败: {e}")
import traceback
error(f"详细错误信息:\n{traceback.format_exc()}")
self.mem0 = None
self._initialized = False
async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
"""
Retrieve relevant memories
Args:
query: query text
user_id: user ID
limit: maximum number of results
Returns:
List[str]: list of memory facts
"""
if not self.mem0:
warning("⚠️ Mem0 未初始化,跳过记忆检索")
return []
try:
memories = await self.mem0.search(query, user_id=user_id, limit=limit)
memories = await asyncio.wait_for(
self.mem0.search(query, user_id=user_id, limit=limit),
timeout=30.0
)
if memories and "results" in memories:
facts = [m["memory"] for m in memories["results"] if m.get("memory")]
if facts:
info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆")
return facts
info("🔍 [记忆检索] 未找到相关记忆")
return []
except asyncio.TimeoutError:
warning("⚠️ Mem0 检索超时 (30s),跳过本次记忆检索")
return []
except Exception as e:
warning(f"⚠️ Mem0 检索失败: {e}")
return []
async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool:
"""
Add memories (facts are extracted and stored automatically)
Args:
messages: message list in the form [{"role": "user/assistant/system", "content": "..."}]
user_id: user ID
Returns:
bool: whether the operation succeeded
"""
if not self.mem0:
warning("⚠️ Mem0 未初始化,跳过记忆添加")
return False
try:
result = await self.mem0.add(
messages,
user_id=user_id,
metadata={"type": "conversation"}
await asyncio.wait_for(
self.mem0.add(
messages,
user_id=user_id,
metadata={"type": "conversation"}
),
timeout=60.0
)
info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取")
info("📝 [记忆添加] 已提交给 Mem0 进行事实提取")
return True
except asyncio.TimeoutError:
error("❌ Mem0 记忆添加超时 (60s)")
return False
except Exception as e:
error(f"❌ Mem0 记忆添加失败: {e}")
return False
return False
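A minimal end-to-end usage sketch of this client (the model name, URL, and ids are placeholders, not part of this commit):

import asyncio
from langchain_openai import ChatOpenAI
from app.memory import Mem0Client

async def main():
    llm = ChatOpenAI(model="deepseek-chat", base_url="https://api.deepseek.com", api_key="...")  # placeholder LLM
    client = Mem0Client(llm)
    await client.initialize()  # runs the connectivity test shown above
    await client.add_memories(
        [{"role": "user", "content": "I prefer concise answers."}],
        user_id="u123",
    )
    print(await client.search_memories("answer style", user_id="u123"))

asyncio.run(main())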

View File

@@ -7,6 +7,7 @@ from app.nodes.llm_call import create_llm_call_node
from app.nodes.tool_call import create_tool_call_node
from app.nodes.retrieve_memory import create_retrieve_memory_node
from app.nodes.summarize import create_summarize_node
from app.nodes.finalize import finalize_node
__all__ = [
"should_continue",
@@ -14,4 +15,5 @@ __all__ = [
"create_tool_call_node",
"create_retrieve_memory_node",
"create_summarize_node",
"finalize_node",
]

app/nodes/finalize.py (new file, 47 lines)
View File

@@ -0,0 +1,47 @@
"""
Finalize-event node module
Sends the completion event, including token usage and elapsed time
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# Local modules
from app.state import MessagesState, GraphContext
from app.utils.logging import log_state_change
from app.logger import info, error
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
Finalize node - emits the completion event, including token usage and elapsed time
Args:
state: current conversation state
config: runtime configuration
Returns:
Empty dict (terminal node, no state update)
"""
log_state_change("finalize", state, "进入")
try:
# Get the stream writer and emit the completion event
writer = get_stream_writer()
writer({
"type": "custom",
"data": {
"type": "done",
"token_usage": state.get("last_token_usage", {}),
"elapsed_time": state.get("last_elapsed_time", 0.0)
}
})
info("🏁 [完成事件] 已发送完成事件包含token使用情况和耗时信息")
except Exception as e:
error(f"❌ [完成事件] 发送完成事件时发生异常: {e}")
log_state_change("finalize", state, "离开")
return {}
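A note for consumers: the writer payload above already has the shape {"type": "custom", "data": ...}, and process_message_stream wraps whatever arrives on the "custom" channel once more, so the event actually yielded to the frontend is double-nested, roughly (token counts are example values):

{
    "type": "custom",
    "data": {
        "type": "custom",
        "data": {
            "type": "done",
            "token_usage": {"prompt_tokens": 123, "completion_tokens": 45},
            "elapsed_time": 1.87
        }
    }
}

The same applies to the tool_start/tool_end events emitted in the tool node below; frontends (or a later cleanup) may want to flatten one level.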

View File

@@ -32,15 +32,19 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
# Build the invocation chain
prompt = create_system_prompt()
llm_with_tools = llm.bind_tools(tools)
chain = prompt | RunnableLambda(print_llm_input) | llm_with_tools
async def call_llm(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
# Simplified chain (the print_llm_input RunnableLambda is dropped); iterated below with astream
chain = prompt | llm_with_tools
from langchain_core.runnables.config import RunnableConfig
async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
LLM call node (async)
Args:
state: current conversation state
runtime: LangGraph runtime context
config: config auto-injected by LangChain/LangGraph, carrying callbacks etc.
Returns:
Updated state dict
@@ -48,17 +52,28 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
log_state_change("llm_call", state, "进入")
memory_context = state.get("memory_context", "暂无用户信息")
loop = asyncio.get_event_loop()
start_time = time.time()
try:
response = await loop.run_in_executor(
None,
lambda: chain.invoke({
# Reverted: manually astream and concatenate every chunk into the final response.
# LangGraph automatically observes all tokens produced along the way.
chunks = []
async for chunk in chain.astream(
{
"messages": state["messages"],
"memory_context": memory_context
})
)
},
config=config
):
chunks.append(chunk)
# Merge all chunks into the final AIMessage
if chunks:
response = chunks[0]
for chunk in chunks[1:]:
response = response + chunk
else:
response = AIMessage(content="")
elapsed_time = time.time() - start_time
@@ -85,13 +100,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
if token_usage:
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
# Log response statistics
info(f"⏱️ [LLM stats] call took {elapsed_time:.2f}s")
info(f"📊 [LLM stats] token usage: input={input_tokens}, output={output_tokens}, total={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM stats] detailed usage: {token_usage}")
# Log the LLM's full output
debug("\n" + "="*80)
debug("📥 [LLM output] full response returned by the model:")
@@ -99,6 +108,12 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
debug(f" 内容长度: {len(str(response.content))} 字符")
debug("-"*80)
debug(f"{response.content}")
# Log response statistics
info(f"⏱️ [LLM stats] call took {elapsed_time:.2f}s")
info(f"📊 [LLM stats] token usage: input={input_tokens}, output={output_tokens}, total={input_tokens + output_tokens}")
if token_usage:
debug(f"📋 [LLM stats] detailed usage: {token_usage}")
debug("="*80 + "\n")
result = {

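For reference, the chunk concatenation in call_llm works because LangChain message chunks implement __add__, merging content (and tool-call deltas) as they accumulate; a standalone sketch:

from langchain_core.messages import AIMessageChunk

merged = AIMessageChunk(content="Hel") + AIMessageChunk(content="lo")
print(merged.content)  # "Hello"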
View File

@@ -24,20 +24,23 @@ def create_retrieve_memory_node(mem0_client: Mem0Client):
Async node function
"""
async def retrieve_memory(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
from langchain_core.runnables.config import RunnableConfig
async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
Memory-retrieval node - uses Mem0
Args:
state: current conversation state
runtime: LangGraph runtime context
config: runtime configuration
Returns:
State update containing memory_context
"""
log_state_change("retrieve_memory", state, "进入")
user_id = runtime.context.user_id
# 从 metadata 中获取 user_id
user_id = config.get("metadata", {}).get("user_id", "default_user")
# 兼容 dict 和对象两种消息格式
last_msg = state["messages"][-1]

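This lookup only succeeds if the caller places user_id in the run config's metadata; a sketch of the expected config (ids are placeholders):

config = {
    "configurable": {"thread_id": "t1"},
    "metadata": {"user_id": "u123"},  # read back via config.get("metadata", {}).get("user_id")
}
result = await graph.ainvoke(input_state, config=config)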
View File

@@ -12,7 +12,7 @@ from app.state import MessagesState
from app.logger import info
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'END']:
def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']:
"""
Decide the next step: tool call, summarization, or finish
@@ -20,7 +20,7 @@ def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', '
state: current conversation state
Returns:
Next node name or END
Next node name
"""
last_message = state["messages"][-1]
@@ -40,9 +40,9 @@ def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', '
else:
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程")
return 'END'
return 'finalize'
# 3. Otherwise (e.g. only user messages so far), finish directly
if ENABLE_GRAPH_TRACE:
info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程")
return 'END'
return 'finalize'

View File

@@ -24,13 +24,15 @@ def create_summarize_node(mem0_client: Mem0Client):
Async node function
"""
async def summarize_conversation(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
from langchain_core.runnables.config import RunnableConfig
async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
Memory-storage node - uses Mem0
Args:
state: current conversation state
runtime: LangGraph runtime context
config: runtime configuration
Returns:
State update that resets the counter
@@ -42,7 +44,8 @@ def create_summarize_node(mem0_client: Mem0Client):
debug("📝 [记忆添加] 对话过短,跳过")
return {"turns_since_last_summary": 0}
user_id = runtime.context.user_id
# Read user_id from the metadata
user_id = config.get("metadata", {}).get("user_id", "default_user")
# Make sure Mem0 is initialized (lazy loading)
if not mem0_client._initialized:
@@ -83,4 +86,4 @@ def create_summarize_node(mem0_client: Mem0Client):
log_state_change("summarize", state, "离开")
return {"turns_since_last_summary": 0}
return summarize_conversation
return summarize_conversation

View File

@@ -7,6 +7,7 @@ import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# Local modules
from app.state import MessagesState, GraphContext
@@ -25,13 +26,15 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
Async node function
"""
async def call_tools(state: MessagesState, runtime: Runtime[GraphContext]) -> Dict[str, Any]:
from langchain_core.runnables.config import RunnableConfig
async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
"""
Tool-execution node (async)
Args:
state: current conversation state
runtime: LangGraph runtime context
config: runtime configuration
Returns:
State update containing ToolMessage entries
@@ -62,6 +65,10 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id))
continue
# Get the stream writer and emit the tool-start event
writer = get_stream_writer()
writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}})
try:
# Fix the closure pitfall: pass loop variables into the lambda as default arguments
# Prefer async invocation (ainvoke) when the tool supports it
@@ -77,9 +84,15 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
result_preview = str(observation).replace("\n", " ")
debug(f" └─ ✅ 结果: {result_preview}")
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
# Emit the tool-finished event
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}})
except Exception as e:
debug(f" └─ ❌ 异常: {e}")
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
# Emit the tool-failed event
writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}})
info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage")
@@ -87,4 +100,4 @@ def create_tool_call_node(tools_by_name: Dict[str, Any]):
log_state_change("tool_node", {**state, **result}, "离开")
return result
return call_tools
return call_tools

View File

@@ -27,9 +27,10 @@ def create_system_prompt() -> ChatPromptTemplate:
"- 抓取网页内容:`fetch_webpage_content`\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动\n"
"2. 优先利用已知用户信息进行个性化回复\n"
"3. 若无信息可依,礼貌询问或提供通用帮助。"
"1. 回答必须简洁、直接。\n"
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。例如:<think>这里是我的思考过程...</think>这里是最终回答\n"
"3. 优先利用已知用户信息进行个性化回复。\n"
"4. 若无信息可依,礼貌询问或提供通用帮助。"
)
return ChatPromptTemplate.from_messages([

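Consumers will presumably want to split the reasoning from the visible answer; a minimal sketch (not part of this commit):

import re

def split_think(text: str) -> tuple[str, str]:
    """Split '<think>...</think>answer' into (reasoning, answer)."""
    m = re.match(r"<think>(.*?)</think>(.*)", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip(), m.group(2).strip()
    return "", text.strip()

reasoning, answer = split_think("<think>Compare the two dates...</think>42 days.")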
View File

@@ -25,10 +25,10 @@ def log_state_change(node_name: str, state: dict, prefix: str = "enter"):
if last_msg:
# Support both dict and object formats
if isinstance(last_msg, dict):
content_preview = str(last_msg.get("content", ""))[:100].replace("\n", " ")
content_preview = str(last_msg.get("content", ""))[:10].replace("\n", " ")
msg_type = last_msg.get("type", "unknown")
else:
content_preview = str(last_msg.content)[:100].replace("\n", " ")
content_preview = str(last_msg.content)[:10].replace("\n", " ")
msg_type = getattr(last_msg, 'type', 'unknown')
last_info = f"{msg_type.upper()}: {content_preview}"
info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}")