docs(.gitignore/README/QUICKSTART): 更新文档和忽略配置 - 添加IDE配置、日志和数据文件到.gitignore - 重构QUICKSTART.md,提供Docker Compose和本地开发两种部署方式 - 更新README.md,优化项目介绍和架构说明 - 移除旧的agent.py和backend.py文件 ```
This commit is contained in:
8
app/__init__.py
Normal file
8
app/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
AI Agent 应用模块
|
||||
"""
|
||||
|
||||
from .agent import AIAgentService
|
||||
from .tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
|
||||
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
|
||||
87
app/agent.py
Normal file
87
app/agent.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""
|
||||
AI Agent 服务类 - 支持多模型动态切换
|
||||
接收外部传入的 checkpointer,不负责管理连接生命周期
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
# 本地模块
|
||||
from app.graph_builder import GraphBuilder
|
||||
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class AIAgentService:
|
||||
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
|
||||
|
||||
def __init__(self, checkpointer):
|
||||
"""
|
||||
初始化服务
|
||||
Args:
|
||||
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
|
||||
"""
|
||||
self.checkpointer = checkpointer
|
||||
self.graphs = {} # 存储不同模型对应的 graph 实例
|
||||
|
||||
def _create_zhipu_llm(self):
|
||||
"""创建智谱在线 LLM"""
|
||||
api_key = os.getenv("ZHIPUAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("ZHIPUAI_API_KEY not set in environment")
|
||||
return ChatZhipuAI(
|
||||
model="glm-4.7-flash",
|
||||
api_key=api_key,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
def _create_local_llm(self):
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
return ChatOpenAI(
|
||||
# 原来是 http://localhost:8000/v1
|
||||
# 改为 FRP 穿透后的公网地址
|
||||
base_url = "http://115.190.121.151:18000/v1",
|
||||
api_key=SecretStr(os.getenv("VLLM_LOCAL_KEY", "")),
|
||||
model="gemma-4-E2B-it",
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
model_configs = {
|
||||
"zhipu": self._create_zhipu_llm,
|
||||
"local": self._create_local_llm,
|
||||
}
|
||||
|
||||
for model_name, llm_creator in model_configs.items():
|
||||
try:
|
||||
llm = llm_creator()
|
||||
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
|
||||
graph = builder.compile(checkpointer=self.checkpointer)
|
||||
self.graphs[model_name] = graph
|
||||
print(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
|
||||
|
||||
if not self.graphs:
|
||||
raise RuntimeError("没有可用的模型,请检查配置")
|
||||
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu") -> str:
|
||||
"""处理用户消息,返回最终答案"""
|
||||
if model not in self.graphs:
|
||||
fallback_model = next(iter(self.graphs.keys()))
|
||||
print(f"警告: 模型 '{model}' 不可用,已切换到 '{fallback_model}'")
|
||||
model = fallback_model
|
||||
|
||||
graph = self.graphs[model]
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
input_state = {"messages": [HumanMessage(content=message)]}
|
||||
result = await graph.ainvoke(input_state, config=config)
|
||||
return result["messages"][-1].content
|
||||
115
app/backend.py
Normal file
115
app/backend.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
|
||||
采用依赖注入模式,优雅管理资源生命周期
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
from app.agent import AIAgentService
|
||||
|
||||
# PostgreSQL 连接字符串
|
||||
DB_URI = "postgresql://postgres:mysecretpassword@localhost:5432/langgraph_db?sslmode=disable"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理:创建并注入全局服务"""
|
||||
# 1. 创建数据库连接池并初始化表
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. 构建 AI Agent 服务
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
await agent_service.initialize()
|
||||
|
||||
# 3. 将服务实例存入 app.state
|
||||
app.state.agent_service = agent_service
|
||||
|
||||
# 应用运行中...
|
||||
yield
|
||||
|
||||
# 4. 关闭时自动清理数据库连接(async with 负责)
|
||||
print("🛑 应用关闭,数据库连接池已释放")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# CORS 中间件(允许前端跨域)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ========== Pydantic 模型 ==========
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
thread_id: str | None = None
|
||||
model: str = "zhipu"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str
|
||||
thread_id: str
|
||||
model_used: str
|
||||
|
||||
|
||||
# ========== 依赖注入函数 ==========
|
||||
def get_agent_service(request: Request) -> AIAgentService:
|
||||
"""从 app.state 中获取全局 AIAgentService 实例"""
|
||||
return request.app.state.agent_service
|
||||
|
||||
|
||||
# ========== HTTP 端点 ==========
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat_endpoint(
|
||||
request: ChatRequest,
|
||||
agent_service: AIAgentService = Depends(get_agent_service)
|
||||
):
|
||||
"""同步对话接口,支持模型选择"""
|
||||
if not request.message:
|
||||
raise HTTPException(status_code=400, detail="message required")
|
||||
|
||||
thread_id = request.thread_id or str(uuid.uuid4())
|
||||
reply = await agent_service.process_message(
|
||||
request.message, thread_id, request.model
|
||||
)
|
||||
actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
|
||||
return ChatResponse(reply=reply, thread_id=thread_id, model_used=actual_model)
|
||||
|
||||
|
||||
# ========== WebSocket 端点(可选) ==========
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
agent_service: AIAgentService = Depends(get_agent_service)
|
||||
):
|
||||
await websocket.accept()
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
message = data.get("message")
|
||||
thread_id = data.get("thread_id", str(uuid.uuid4()))
|
||||
model = data.get("model", "zhipu")
|
||||
if not message:
|
||||
await websocket.send_json({"error": "missing message"})
|
||||
continue
|
||||
reply = await agent_service.process_message(message, thread_id, model)
|
||||
actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
|
||||
await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model})
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
127
app/graph_builder.py
Normal file
127
app/graph_builder.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
LangGraph 状态图构建模块 - 完全面向对象风格,无嵌套函数
|
||||
"""
|
||||
|
||||
import operator
|
||||
import asyncio
|
||||
from typing import Literal, Annotated, Any
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_core.messages import AnyMessage, AIMessage, ToolMessage, SystemMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class MessageState(TypedDict):
|
||||
"""对话状态类型定义"""
|
||||
messages: Annotated[list[AnyMessage], operator.add]
|
||||
llm_calls: int
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
"""LangGraph 状态图构建器 - 所有节点均为类方法"""
|
||||
|
||||
def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict[str, Any]):
|
||||
"""
|
||||
初始化构建器
|
||||
|
||||
Args:
|
||||
llm: 大语言模型实例
|
||||
tools: 工具列表
|
||||
tools_by_name: 名称到工具函数的映射
|
||||
"""
|
||||
self.llm = llm
|
||||
self.tools = tools
|
||||
self.tools_by_name = tools_by_name
|
||||
self._llm_with_tools = llm.bind_tools(tools)
|
||||
self._prompt = self._create_prompt()
|
||||
self._chain = self._prompt | self._llm_with_tools
|
||||
|
||||
@staticmethod
|
||||
def _create_prompt() -> ChatPromptTemplate:
|
||||
"""创建系统提示模板(静态方法,无需访问实例)"""
|
||||
return ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=(
|
||||
"你是一个个人生活助手和数据分析助手。请说中文。"
|
||||
"当用户询问天气或温度时,使用get_current_temperature工具获取信息。"
|
||||
"当用户要求读文本文件时,请使用 read_local_file 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读PDF文件时,请使用 read_pdf_summary 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求读Excel文件时,请使用 read_excel_as_markdown 工具,只能读取 './user_docs' 目录下的文件。"
|
||||
"当用户要求抓取网页时,请使用 fetch_webpage_content 工具。"
|
||||
"重要:你的回答必须简洁、直接,不要包含任何关于思考过程的描述。"
|
||||
)),
|
||||
MessagesPlaceholder(variable_name="message")
|
||||
])
|
||||
|
||||
async def call_llm(self, state: MessageState) -> dict:
|
||||
"""
|
||||
LLM 调用节点(异步方法)
|
||||
注意:因为 self._chain.invoke 是同步方法,使用 run_in_executor 避免阻塞事件循环
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._chain.invoke({"message": state["messages"]})
|
||||
)
|
||||
return {
|
||||
"messages": [response],
|
||||
"llm_calls": state.get('llm_calls', 0) + 1
|
||||
}
|
||||
|
||||
async def call_tools(self, state: MessageState) -> dict:
|
||||
"""
|
||||
工具执行节点(异步方法)
|
||||
对于每个工具调用,在线程池中执行同步工具函数
|
||||
"""
|
||||
last_message = state['messages'][-1]
|
||||
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
|
||||
return {"messages": []}
|
||||
|
||||
results = []
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
for tool_call in last_message.tool_calls:
|
||||
tool_name = tool_call["name"]
|
||||
tool_args = tool_call["args"]
|
||||
tool_id = tool_call["id"]
|
||||
tool_func = self.tools_by_name.get(tool_name)
|
||||
|
||||
if tool_func is None:
|
||||
results.append(ToolMessage(content=f"Tool {tool_name} not found", tool_call_id=tool_id))
|
||||
continue
|
||||
|
||||
try:
|
||||
# 同步工具函数在线程池中执行
|
||||
observation = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: tool_func.invoke(tool_args)
|
||||
)
|
||||
results.append(ToolMessage(content=str(observation), tool_call_id=tool_id))
|
||||
except Exception as e:
|
||||
results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id))
|
||||
|
||||
return {"messages": results}
|
||||
|
||||
@staticmethod
|
||||
def should_continue(state: MessageState) -> Literal['tool_node', END]:
|
||||
"""
|
||||
条件边判断(静态方法)
|
||||
决定下一步是进入工具节点还是结束
|
||||
"""
|
||||
last_message = state["messages"][-1]
|
||||
if isinstance(last_message, AIMessage) and bool(last_message.tool_calls):
|
||||
return 'tool_node'
|
||||
return END
|
||||
|
||||
def build(self) -> StateGraph:
|
||||
"""
|
||||
构建未编译的状态图(返回 StateGraph 实例)
|
||||
图中节点直接使用实例方法 call_llm, call_tools
|
||||
"""
|
||||
builder = StateGraph(MessageState)
|
||||
builder.add_node("llm_call", self.call_llm)
|
||||
builder.add_node("tool_node", self.call_tools)
|
||||
builder.add_edge(START, "llm_call")
|
||||
builder.add_conditional_edges("llm_call", self.should_continue, ["tool_node", END])
|
||||
builder.add_edge("tool_node", "llm_call")
|
||||
return builder
|
||||
103
app/tools.py
Normal file
103
app/tools.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
工具定义模块 - 纯函数工具,无依赖 AIAgent 类
|
||||
"""
|
||||
|
||||
# 标准库
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# 第三方库
|
||||
import pandas as pd
|
||||
import pypdf
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
from langchain_core.tools import tool
|
||||
|
||||
|
||||
def _file_allow_check(filename: str) -> Path:
|
||||
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
|
||||
allowed_dir = Path("./user_docs").resolve()
|
||||
allowed_dir.mkdir(exist_ok=True)
|
||||
|
||||
file_path = (allowed_dir / filename).resolve()
|
||||
if not str(file_path).startswith(str(allowed_dir)):
|
||||
raise ValueError("错误:非法文件路径。")
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
@tool
|
||||
def get_current_temperature(location: str) -> str:
|
||||
"""获取指定地点的当前温度。"""
|
||||
return f'当前{location}的温度为25℃'
|
||||
|
||||
|
||||
@tool
|
||||
def read_local_file(filename: str) -> str:
|
||||
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
|
||||
try:
|
||||
file_path = _file_allow_check(filename)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
|
||||
except Exception as e:
|
||||
return f"读取文件时出错:{str(e)}"
|
||||
|
||||
|
||||
@tool
|
||||
def read_pdf_summary(filename: str) -> str:
|
||||
"""读取PDF文件并返回内容文本摘要。"""
|
||||
try:
|
||||
file_path = _file_allow_check(filename)
|
||||
text = ""
|
||||
with open(file_path, 'rb') as f:
|
||||
reader = pypdf.PdfReader(f)
|
||||
for page in reader.pages[:3]:
|
||||
text += page.extract_text()
|
||||
return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
|
||||
except Exception as e:
|
||||
return f"读取PDF出错:{e}"
|
||||
|
||||
|
||||
@tool
|
||||
def read_excel_as_markdown(filename: str) -> str:
|
||||
"""读取Excel文件,并将其主要数据转换为Markdown表格格式。"""
|
||||
try:
|
||||
file_path = _file_allow_check(filename)
|
||||
df = pd.read_excel(file_path)
|
||||
markdown_table = df.head(10).to_markdown(index=False)
|
||||
return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}"
|
||||
except Exception as e:
|
||||
return f"读取Excel出错:{e}"
|
||||
|
||||
|
||||
@tool
|
||||
def fetch_webpage_content(url: str) -> str:
|
||||
"""抓取给定URL的网页正文内容,并返回清晰的纯文本。"""
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
response.raise_for_status()
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
for script in soup(["script", "style"]):
|
||||
script.decompose()
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = '\n'.join(chunk for chunk in chunks if chunk)
|
||||
return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
|
||||
except Exception as e:
|
||||
return f"抓取网页时出错:{str(e)}"
|
||||
|
||||
|
||||
# 工具列表和映射(全局常量)
|
||||
AVAILABLE_TOOLS = [
|
||||
get_current_temperature,
|
||||
read_local_file,
|
||||
fetch_webpage_content,
|
||||
read_pdf_summary,
|
||||
read_excel_as_markdown
|
||||
]
|
||||
TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}
|
||||
Reference in New Issue
Block a user