This commit is contained in:
94
app/agent.py
94
app/agent.py
@@ -11,8 +11,11 @@ from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
# 本地模块
|
||||
from app.graph_builder import GraphBuilder
|
||||
from app.graph_builder import GraphBuilder, GraphContext
|
||||
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from app.logger import debug, info, warning, error
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -20,13 +23,15 @@ load_dotenv()
|
||||
class AIAgentService:
|
||||
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
|
||||
|
||||
def __init__(self, checkpointer):
|
||||
def __init__(self, checkpointer: AsyncPostgresSaver, store: AsyncPostgresStore):
|
||||
"""
|
||||
初始化服务
|
||||
Args:
|
||||
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
|
||||
store: 已经初始化的 AsyncPostgresStore 实例
|
||||
"""
|
||||
self.checkpointer = checkpointer
|
||||
self.store = store
|
||||
self.graphs = {} # 存储不同模型对应的 graph 实例
|
||||
|
||||
def _create_zhipu_llm(self):
|
||||
@@ -39,49 +44,108 @@ class AIAgentService:
|
||||
api_key=api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
)
|
||||
|
||||
def _create_deepseek_llm(self):
|
||||
"""创建 DeepSeek LLM(使用 OpenAI 兼容 API)"""
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set in environment")
|
||||
return ChatOpenAI(
|
||||
base_url="https://api.deepseek.com",
|
||||
api_key=SecretStr(api_key),
|
||||
model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
)
|
||||
|
||||
def _create_local_llm(self):
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
"http://115.190.121.151:18000/v1"
|
||||
)
|
||||
|
||||
return ChatOpenAI(
|
||||
# 原来是 http://localhost:8000/v1
|
||||
# 改为 FRP 穿透后的公网地址
|
||||
base_url = "http://115.190.121.151:18000/v1",
|
||||
base_url=vllm_base_url,
|
||||
api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
|
||||
model="gemma-4-E2B-it",
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer 和 store)"""
|
||||
model_configs = {
|
||||
"zhipu": self._create_zhipu_llm,
|
||||
"deepseek": self._create_deepseek_llm,
|
||||
"local": self._create_local_llm,
|
||||
}
|
||||
|
||||
for model_name, llm_creator in model_configs.items():
|
||||
try:
|
||||
info(f"🔄 正在初始化模型 '{model_name}'...")
|
||||
llm = llm_creator()
|
||||
|
||||
# 测试 LLM 连接(可选,用于调试)
|
||||
if model_name == "local":
|
||||
debug(f" 测试 vLLM 连接: {os.getenv('VLLM_BASE_URL', '未设置')}")
|
||||
elif model_name == "deepseek":
|
||||
debug(f" 测试 DeepSeek API 连接: https://api.deepseek.com")
|
||||
|
||||
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
|
||||
graph = builder.compile(checkpointer=self.checkpointer)
|
||||
graph = builder.compile(checkpointer=self.checkpointer, store=self.store)
|
||||
self.graphs[model_name] = graph
|
||||
print(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
info(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
|
||||
import traceback
|
||||
error_detail = traceback.format_exc()
|
||||
warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
|
||||
debug(f" 详细错误:\n{error_detail}")
|
||||
|
||||
if not self.graphs:
|
||||
raise RuntimeError("没有可用的模型,请检查配置")
|
||||
raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
|
||||
"1. ZHIPUAI_API_KEY 未配置或无效\n"
|
||||
"2. DEEPSEEK_API_KEY 未配置或无效\n"
|
||||
"3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
|
||||
"4. 网络连接问题")
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
|
||||
"""处理用户消息,返回最终答案"""
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
|
||||
"""
|
||||
处理用户消息,返回包含回复、token统计和耗时的字典
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"reply": str, # AI 回复内容
|
||||
"token_usage": dict, # Token 使用详情
|
||||
"elapsed_time": float # 调用耗时(秒)
|
||||
}
|
||||
"""
|
||||
if model not in self.graphs:
|
||||
fallback_model = next(iter(self.graphs.keys()))
|
||||
print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
|
||||
warning(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
|
||||
model = fallback_model
|
||||
|
||||
graph = self.graphs[model]
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
input_state = {"messages": [HumanMessage(content=message)]}
|
||||
result = await graph.ainvoke(input_state, config=config)
|
||||
return result["messages"][-1].content
|
||||
context = GraphContext(user_id=user_id)
|
||||
|
||||
result = await graph.ainvoke(input_state, config=config, context=context)
|
||||
|
||||
reply = result["messages"][-1].content
|
||||
token_usage = result.get("last_token_usage", {})
|
||||
elapsed_time = result.get("last_elapsed_time", 0.0)
|
||||
|
||||
return {
|
||||
"reply": reply,
|
||||
"token_usage": token_usage,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@@ -7,17 +7,23 @@ import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
from app.agent import AIAgentService
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串(优先从环境变量读取,适配 Docker 和本地开发)
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# PostgreSQL 连接字符串配置
|
||||
# 优先级:环境变量 DB_URI > Docker 内部服务名 > 本地开发地址
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:mysecretpassword@postgres:5432/langgraph_db?sslmode=disable"
|
||||
"postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
|
||||
|
||||
@@ -25,11 +31,15 @@ DB_URI = os.getenv(
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理:创建并注入全局服务"""
|
||||
# 1. 创建数据库连接池并初始化表
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
async with (
|
||||
AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer,
|
||||
AsyncPostgresStore.from_conn_string(DB_URI) as store
|
||||
):
|
||||
await checkpointer.setup()
|
||||
await store.setup()
|
||||
|
||||
# 2. 构建 AI Agent 服务
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
agent_service = AIAgentService(checkpointer,store)
|
||||
await agent_service.initialize()
|
||||
|
||||
# 3. 将服务实例存入 app.state
|
||||
@@ -39,7 +49,7 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
# 4. 关闭时自动清理数据库连接(async with 负责)
|
||||
print("🛑 应用关闭,数据库连接池已释放")
|
||||
info("🛑 应用关闭,数据库连接池已释放")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
@@ -66,12 +76,17 @@ class ChatRequest(BaseModel):
|
||||
message: str
|
||||
thread_id: str | None = None
|
||||
model: str = "zhipu"
|
||||
user_id: str = "default_user"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str
|
||||
thread_id: str
|
||||
model_used: str
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
# ========== 依赖注入函数 ==========
|
||||
@@ -91,11 +106,27 @@ async def chat_endpoint(
|
||||
raise HTTPException(status_code=400, detail="message required")
|
||||
|
||||
thread_id = request.thread_id or str(uuid.uuid4())
|
||||
reply = await agent_service.process_message(
|
||||
request.message, thread_id, request.model
|
||||
result = await agent_service.process_message(
|
||||
request.message, thread_id, request.model, request.user_id
|
||||
)
|
||||
|
||||
# 提取 token 统计信息
|
||||
token_usage = result.get("token_usage", {})
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
elapsed_time = result.get("elapsed_time", 0.0)
|
||||
|
||||
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
|
||||
return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)
|
||||
|
||||
return ChatResponse(
|
||||
reply=result["reply"],
|
||||
thread_id=thread_id,
|
||||
model_used=actual_model,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
elapsed_time=elapsed_time
|
||||
)
|
||||
|
||||
|
||||
# ========== WebSocket 端点(可选) ==========
|
||||
@@ -111,10 +142,11 @@ async def websocket_endpoint(
|
||||
message = data.get("message")
|
||||
thread_id = data.get("thread_id", str(uuid.uuid4()))
|
||||
model = data.get("model", "zhipu")
|
||||
user_id = data.get("user_id", "default_user")
|
||||
if not message:
|
||||
await websocket.send_json({"error": "missing message"})
|
||||
continue
|
||||
reply = await agent_service.process_message(message, thread_id, model)
|
||||
reply = await agent_service.process_message(message, thread_id, model, user_id)
|
||||
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
|
||||
await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
|
||||
except WebSocketDisconnect:
|
||||
|
||||
@@ -4,19 +4,34 @@ LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
||||
|
||||
import operator
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Literal, Annotated, Any
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore
|
||||
from langgraph.runtime import Runtime
|
||||
from dataclasses import dataclass
|
||||
import uuid
|
||||
|
||||
# 本地模块
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
|
||||
class MessageState(TypedDict):
|
||||
class MessagesState(TypedDict):
|
||||
"""对话状态类型定义"""
|
||||
messages: Annotated[list[AnyMessage], operator.add]
|
||||
llm_calls: int
|
||||
memory_context:str
|
||||
last_token_usage: dict # 本次调用的 token 使用详情
|
||||
last_elapsed_time: float # 本次调用耗时(秒)
|
||||
|
||||
@dataclass
|
||||
class GraphContext:
|
||||
user_id: str
|
||||
# 可扩展更多上下文信息
|
||||
|
||||
class GraphBuilder:
|
||||
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
||||
@@ -42,33 +57,132 @@ class GraphBuilder:
|
||||
"""创建系统提示模板(静态方法,无需访问实例)"""
|
||||
return ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=(
|
||||
"你是一个个人生活助手和数据分析助手。请说中文。"
|
||||
"当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
|
||||
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
|
||||
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
|
||||
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
|
||||
"【用户背景信息】\n"
|
||||
"以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
|
||||
"{memory_context}\n"
|
||||
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
|
||||
"【可用工具与使用规则】\n"
|
||||
"- 获取温度/天气:`get_current_temperature`\n"
|
||||
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
|
||||
"- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
|
||||
"- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
|
||||
"- 抓取网页内容:`fetch_webpage_content`\n"
|
||||
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
|
||||
"【回答要求(必须遵守)】\n"
|
||||
"1. 回答必须简洁、直接,禁止描述任何思考过程或内心活动。\n"
|
||||
"2. 优先利用已知用户信息进行个性化回复。\n"
|
||||
"3. 若无信息可依,礼貌询问或提供通用帮助。"
|
||||
)),
|
||||
MessagesPlaceholder(variable_name="message")
|
||||
MessagesPlaceholder(variable_name="messages")
|
||||
])
|
||||
|
||||
async def call_llm(self, state: MessageState) -> dict:
|
||||
async def call_llm(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
|
||||
"""
|
||||
memory_context = state.get("memory_context", "暂无用户信息")
|
||||
|
||||
# 构建完整的输入消息列表(用于调试打印)
|
||||
system_prompt = self._prompt.messages[0] # SystemMessage
|
||||
if isinstance(system_prompt, SystemMessage):
|
||||
system_content = system_prompt.content.format(memory_context=memory_context)
|
||||
else:
|
||||
system_content = str(system_prompt.content)
|
||||
|
||||
input_messages = [SystemMessage(content=system_content)] + state["messages"]
|
||||
|
||||
# 打印发送给大模型的最终输入
|
||||
debug("\n" + "="*80)
|
||||
debug("📤 [LLM输入] 发送给大模型的完整消息:")
|
||||
debug(f" 总消息数: {len(input_messages)}")
|
||||
for i, msg in enumerate(input_messages):
|
||||
content_preview = str(msg.content) # 不截断,完整输出
|
||||
debug(f" [{i}] {msg.type.upper():10s}: {content_preview}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._chain.invoke({"message": state["messages"]})
|
||||
)
|
||||
return {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1
|
||||
}
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._chain.invoke({
|
||||
"messages": state["messages"],
|
||||
"memory_context": memory_context
|
||||
})
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取 token 用量(兼容不同 LLM 提供商的元数据格式)
|
||||
token_usage = {}
|
||||
input_tokens = 0
|
||||
output_tokens = 0
|
||||
|
||||
# 尝试从 response_metadata 中提取
|
||||
if hasattr(response, 'response_metadata') and response.response_metadata:
|
||||
meta = response.response_metadata
|
||||
if 'token_usage' in meta:
|
||||
token_usage = meta['token_usage']
|
||||
elif 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
|
||||
# 尝试从 additional_kwargs 中提取
|
||||
if not token_usage and hasattr(response, 'additional_kwargs'):
|
||||
add_kwargs = response.additional_kwargs
|
||||
if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']:
|
||||
token_usage = add_kwargs['llm_output']['token_usage']
|
||||
|
||||
# 提取具体的 token 数值
|
||||
if token_usage:
|
||||
input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
|
||||
output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
|
||||
|
||||
# 打印响应统计信息
|
||||
info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒")
|
||||
info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}")
|
||||
if token_usage:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
|
||||
# 打印 LLM 的完整输出
|
||||
debug("\n" + "="*80)
|
||||
debug("📥 [LLM输出] 大模型返回的完整响应:")
|
||||
debug(f" 消息类型: {response.type.upper()}")
|
||||
debug(f" 内容长度: {len(str(response.content))} 字符")
|
||||
debug("-"*80)
|
||||
debug(f"{response.content}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
return {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1,
|
||||
"last_token_usage": token_usage,
|
||||
"last_elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)")
|
||||
error(f" 错误类型: {type(e).__name__}")
|
||||
error(f" 错误信息: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug("="*80 + "\n")
|
||||
|
||||
# 返回一个友好的错误消息
|
||||
error_response = AIMessage(
|
||||
content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。"
|
||||
)
|
||||
return {
|
||||
"messages": [error_response],
|
||||
"llm_calls": state.get('llm_calls', 0),
|
||||
"last_token_usage": {},
|
||||
"last_elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def call_tools(self, state: MessageState) -> dict:
|
||||
async def call_tools(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
对于每个工具调用,在线程池中执行同步工具函数
|
||||
@@ -91,11 +205,15 @@ class GraphBuilder:
|
||||
continue
|
||||
|
||||
try:
|
||||
# 同步工具函数在线程池中执行
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: tool_func.invoke(tool_args)
|
||||
)
|
||||
# 修复闭包问题:将变量作为默认参数传入 lambda
|
||||
# 如果工具支持异步 (ainvoke),优先使用异步调用
|
||||
if hasattr(tool_func, 'ainvoke'):
|
||||
observation = await tool_func.ainvoke(tool_args)
|
||||
else:
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值
|
||||
)
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
except Exception as e:
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
@@ -103,25 +221,101 @@ class GraphBuilder:
|
||||
return {"messages": results}
|
||||
|
||||
@staticmethod
|
||||
def should_continue(state: MessageState) -> Literal['tool_node', END]:
|
||||
"""
|
||||
条件边判断(静态方法)
|
||||
决定下一步是进入工具节点还是结束
|
||||
"""
|
||||
def should_continue(state: MessagesState) -> Literal['tool_node', 'save_memory', 'END']:
|
||||
"""决定下一步:工具调用、保存记忆还是结束"""
|
||||
last_message = state["messages"][-1]
|
||||
if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
|
||||
|
||||
# 1. 如果需要调用工具,优先进入工具节点
|
||||
if isinstance(last_message, AIMessage) and last_message.tool_calls:
|
||||
return 'tool_node'
|
||||
return END
|
||||
|
||||
# 2. 如果是 AI 的最终回复,可以考虑进入记忆保存节点(可增加判断逻辑)
|
||||
# 这里简单处理:只要没有工具调用,且是 AI 消息,就尝试保存记忆。
|
||||
if isinstance(last_message, AIMessage):
|
||||
return 'save_memory'
|
||||
|
||||
# 3. 其他情况(如只有用户消息)直接结束
|
||||
return 'END'
|
||||
|
||||
async def retrieve_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""搜索并返回长期记忆"""
|
||||
user_id = runtime.context.user_id
|
||||
namespace = ("memories", user_id)
|
||||
query = str(state["messages"][-1].content)
|
||||
|
||||
debug(f"\n{'='*60}")
|
||||
debug(f"🔎 [记忆检索] 开始检索")
|
||||
debug(f" ├─ 用户ID: {user_id}")
|
||||
debug(f" ├─ 命名空间: {namespace}")
|
||||
debug(f" ├─ 查询内容: '{query}'")
|
||||
debug(f" └─ 消息总数: {len(state['messages'])}")
|
||||
|
||||
try:
|
||||
memories = await runtime.store.asearch(namespace, query=query)
|
||||
debug(f"✅ [记忆检索] 检索完成,找到 {len(memories)} 条相关记忆")
|
||||
|
||||
if memories:
|
||||
memory_text = "\n".join([m.value["data"] for m in memories])
|
||||
debug(f"📚 [记忆内容]")
|
||||
for i, memory in enumerate(memories, 1):
|
||||
debug(f" [{i}] {memory.value['data']}")
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": memory_text}
|
||||
else:
|
||||
debug(f"⚠️ [记忆检索] 未找到相关记忆")
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": ""}
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ [记忆检索] 检索失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
debug(f"{'='*60}\n")
|
||||
return {"memory_context": ""}
|
||||
|
||||
async def save_memory(self, state: MessagesState, runtime: Runtime[GraphContext]) -> dict:
|
||||
"""尝试从对话中提取并保存长期记忆"""
|
||||
# 获取最后一条用户消息(通常是要记住的内容的来源)
|
||||
user_messages = [msg for msg in state["messages"] if msg.type == "human"]
|
||||
if not user_messages:
|
||||
return {}
|
||||
|
||||
last_user_msg = user_messages[-1].content.lower()
|
||||
|
||||
# 简单触发逻辑:包含"记住"或"保存"等关键词
|
||||
if any(keyword in last_user_msg for keyword in ["记住", "保存", "别忘了"]):
|
||||
# 提取记忆内容(这里仅作示例,实际可用 LLM 提取)
|
||||
memory_content = f"用户说过:{last_user_msg}"
|
||||
user_id = runtime.context.user_id
|
||||
namespace = ("memories", user_id)
|
||||
await runtime.store.aput(namespace, str(uuid.uuid4()), {"data": memory_content})
|
||||
info(f"✅ 长期记忆已保存:{memory_content}")
|
||||
|
||||
return {}
|
||||
|
||||
def build(self) -> StateGraph:
|
||||
"""
|
||||
构建未编译的状态图(返回 StateGraph 实例)
|
||||
图中节点直接使用实例方法 call_llm, call_tools
|
||||
"""
|
||||
builder = StateGraph(MessageState)
|
||||
builder = StateGraph(MessagesState,context_schema=GraphContext)
|
||||
builder.add_node("retrieve_memory", self.retrieve_memory)
|
||||
builder.add_node("llm_call", self.call_llm)
|
||||
builder.add_node("tool_node", self.call_tools)
|
||||
builder.add_edge(START, "llm_call")
|
||||
builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
|
||||
builder.add_node("save_memory", self.save_memory)
|
||||
|
||||
builder.add_edge(START, "retrieve_memory")
|
||||
builder.add_edge("retrieve_memory", "llm_call")
|
||||
builder.add_conditional_edges(
|
||||
"llm_call",
|
||||
self.should_continue,
|
||||
{
|
||||
"tool_node": "tool_node",
|
||||
"save_memory": "save_memory",
|
||||
'END': END
|
||||
}
|
||||
)
|
||||
builder.add_edge("tool_node", "llm_call")
|
||||
builder.add_edge("save_memory", END)
|
||||
|
||||
return builder
|
||||
55
app/logger.py
Normal file
55
app/logger.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
统一的日志模块 - 基于环境变量控制日志级别
|
||||
类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 先加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 从环境变量读取日志级别,默认 INFO
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
|
||||
# 根据环境变量控制是否显示详细调试信息
|
||||
DEBUG_MODE = os.getenv("DEBUG", "false").lower() == "true"
|
||||
|
||||
# 创建统一的日志器
|
||||
logger = logging.getLogger("ai_agent")
|
||||
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
|
||||
|
||||
# 避免重复添加 handler
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
# 重要:handler 也需要设置级别,否则可能继承根 logger 的级别
|
||||
handler.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def debug(msg: Any, *args, **kwargs):
|
||||
"""调试日志,仅在 DEBUG 环境变量为 true 时打印"""
|
||||
if DEBUG_MODE:
|
||||
logger.debug(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def info(msg: Any, *args, **kwargs):
|
||||
"""信息日志"""
|
||||
logger.info(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def warning(msg: Any, *args, **kwargs):
|
||||
"""警告日志"""
|
||||
logger.warning(msg, *args, **kwargs)
|
||||
|
||||
|
||||
def error(msg: Any, *args, **kwargs):
|
||||
"""错误日志"""
|
||||
logger.error(msg, *args, **kwargs)
|
||||
Reference in New Issue
Block a user