This commit is contained in:
65
app/agent.py
65
app/agent.py
@@ -137,7 +137,10 @@ class AIAgentService:
|
||||
raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
|
||||
|
||||
graph = self.graphs[model]
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用
|
||||
}
|
||||
input_state = {"messages": [{"role": "user", "content": message}]}
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
@@ -152,3 +155,63 @@ class AIAgentService:
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
|
||||
"""
|
||||
流式处理消息,返回异步生成器
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
thread_id: 线程 ID
|
||||
model_name: 模型名称
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
字典,包含事件类型和数据
|
||||
"""
|
||||
graph = self.graphs.get(model_name)
|
||||
if not graph:
|
||||
warning(f"警告: 模型 '{model_name}' 不可用,使用默认模型")
|
||||
model_name = next(iter(self.graphs.keys()))
|
||||
graph = self.graphs[model_name]
|
||||
|
||||
config = {
|
||||
"configurable": {"thread_id": thread_id},
|
||||
"metadata": {"user_id": user_id}
|
||||
}
|
||||
input_state = {"messages": [{"role": "user", "content": message}]}
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
# 使用 astream_events 获取流式事件
|
||||
async for event in graph.astream_events(input_state, config=config, context=context, version="v2"):
|
||||
kind = event["event"]
|
||||
|
||||
# 聊天模型流式输出
|
||||
if kind == "on_chat_model_stream":
|
||||
content = event["data"]["chunk"].content
|
||||
if content:
|
||||
yield {"type": "token", "content": content}
|
||||
|
||||
# 工具调用开始
|
||||
elif kind == "on_tool_start":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_start", "tool": tool_name}
|
||||
|
||||
# 工具调用结束
|
||||
elif kind == "on_tool_end":
|
||||
tool_name = event["name"]
|
||||
yield {"type": "tool_end", "tool": tool_name}
|
||||
|
||||
# 链结束,获取最终结果
|
||||
elif kind == "on_chain_end" and event["name"] == "LangGraph":
|
||||
output = event["data"]["output"]
|
||||
reply = output["messages"][-1].content if output.get("messages") else ""
|
||||
token_usage = output.get("last_token_usage", {})
|
||||
elapsed_time = output.get("last_elapsed_time", 0.0)
|
||||
|
||||
yield {
|
||||
"type": "done",
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user