Files
ailine/backend/app/backend.py
root 3bc9b19bab
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m3s
feat: 添加子图API端点和前端测试面板,包含确定取消继续交互
2026-04-25 19:38:22 +08:00

585 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
采用依赖注入模式,优雅管理资源生命周期
"""
import os
from .config import DB_URI, BACKEND_PORT
import uuid
import json
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from .agent.service import AIAgentService
from .agent.history import ThreadHistoryService
from .agent_subgraphs.common.human_review import (
ReviewManager,
InMemoryReviewStore,
ReviewStatus,
HumanReview
)
from .logger import info, error
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared services on startup and publish them on ``app.state``.

    All services are built inside ``AsyncPostgresSaver.from_conn_string`` so the
    database connection it opens is released when that ``async with`` exits,
    i.e. after the application has stopped serving. The ``get_*`` dependency
    functions read the services back from ``app.state``.
    """
    async with AsyncPostgresSaver.from_conn_string(DB_URI) as saver:
        # Initialise the checkpointer's tables (only the checkpointer's).
        await saver.setup()

        agent = AIAgentService(saver)
        await agent.initialize()
        history = ThreadHistoryService(saver)
        # Review store is InMemoryReviewStore — presumably process-local, so
        # pending reviews would not survive a restart (confirm in human_review).
        reviews = ReviewManager(InMemoryReviewStore())

        state = app.state
        state.agent_service = agent
        state.history_service = history
        state.review_manager = reviews

        yield  # the application serves requests while suspended here

    # Leaving the ``async with`` above has released the database connection.
    info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)

# Let the browser front end call this API cross-origin.
# NOTE(review): wildcard origins together with allow_credentials effectively lets
# any site make credentialed requests — consider an explicit origin list in production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# ========== 健康检查端点 ==========
@app.get("/health")
async def health_check():
    """Liveness probe for Docker and CI/CD monitoring; always reports ``ok``."""
    payload = {"status": "ok", "service": "ai-agent-backend"}
    return payload
# ========== Pydantic 模型 ==========
# Request body shared by POST /chat and POST /chat/stream.
class ChatRequest(BaseModel):
    message: str  # user input; both endpoints reject an empty string with HTTP 400
    thread_id: str | None = None  # conversation id; the endpoints generate a UUID4 when missing
    model: str = "zhipu"  # requested model key; /chat reports the key actually used in model_used
    user_id: str = "default_user"
# Response body of POST /chat.
class ChatResponse(BaseModel):
    reply: str
    thread_id: str  # the request's thread_id, or the one generated for this call
    model_used: str  # requested model if the agent service has it, otherwise a fallback model
    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0  # /chat fills this with input_tokens + output_tokens
    elapsed_time: float = 0.0  # as reported by AIAgentService.process_message (units not visible here)
# Request body of POST /reviews/{review_id}/approve, /reject and /modify.
class ReviewActionRequest(BaseModel):
    # Optional and unused: those endpoints take the review id from the URL path.
    # It used to be required, which forced clients to repeat the id in the body
    # (HTTP 422 when omitted) even though the value was never read.
    review_id: str = ""
    reviewer: str  # forwarded to ReviewManager as ``reviewer``
    comment: str = ""
    modified_content: str = ""  # replacement content; only /modify requires it to be non-empty
# API view of a HumanReview, built by review_to_response().
class ReviewResponse(BaseModel):
    review_id: str
    thread_id: str
    user_id: str
    status: str  # enum member name (review.status.name)
    content_to_review: str
    review_comment: str = ""
    modified_content: str = ""
    created_at: str  # ISO 8601 string from datetime.isoformat()
    reviewed_at: Optional[str] = None  # ISO 8601 string, or None when the review has no reviewed_at
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
    """Dependency: return the process-wide AIAgentService stored by ``lifespan``.

    NOTE: declares ``Request``, which FastAPI injects only for HTTP routes — it
    cannot be used as-is from a WebSocket route.
    """
    state = request.app.state
    return state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService:
    """Dependency: return the process-wide ThreadHistoryService stored by ``lifespan``."""
    state = request.app.state
    return state.history_service
def get_review_manager(request: Request) -> ReviewManager:
    """Dependency: return the process-wide ReviewManager stored by ``lifespan``."""
    state = request.app.state
    return state.review_manager
# ========== HTTP 端点 ==========
def _token_count(usage: dict, *keys: str) -> int:
    """Return the first non-None value found under ``keys`` in ``usage``, else 0.

    Providers report usage under different names (``prompt_tokens`` vs
    ``input_tokens``). A key that is present but ``None`` falls through to the
    next candidate instead of leaking ``None`` into the int response fields.
    """
    for key in keys:
        value = usage.get(key)
        if value is not None:
            return int(value)
    return 0


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Answer one chat message synchronously with the requested model.

    A new thread id (UUID4) is generated when the request has none. The
    response carries the reply, the thread id, the model key actually used and
    token/timing statistics taken from the agent result.

    Raises:
        HTTPException: 400 when ``message`` is empty.
    """
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")
    thread_id = request.thread_id or str(uuid.uuid4())
    result = await agent_service.process_message(
        request.message, thread_id, request.model, request.user_id
    )

    # ``or`` guards against the agent reporting these entries as None.
    token_usage = result.get("token_usage") or {}
    input_tokens = _token_count(token_usage, "prompt_tokens", "input_tokens")
    output_tokens = _token_count(token_usage, "completion_tokens", "output_tokens")
    elapsed_time = result.get("elapsed_time") or 0.0

    # Report the requested model if it is registered, otherwise the first one.
    # The default on next() avoids StopIteration (a RuntimeError inside a
    # coroutine) when no graphs are registered.
    graphs = agent_service.graphs
    actual_model = (
        request.model if request.model in graphs else next(iter(graphs), request.model)
    )
    return ChatResponse(
        reply=result["reply"],
        thread_id=thread_id,
        model_used=actual_model,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
        elapsed_time=elapsed_time
    )
# ========== 历史查询接口 ==========
@app.get("/threads")
async def list_threads(
    user_id: str = Query("default_user", description="用户 ID"),
    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """List the conversation threads of ``user_id`` (``limit`` is clamped to 1..200)."""
    return {"threads": await history_service.get_user_threads(user_id, limit)}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
    thread_id: str,
    user_id: str = Query("default_user", description="用户 ID"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return the full message history of ``thread_id``.

    NOTE(review): ``user_id`` is accepted but unused, so any caller can read
    any thread — add an ownership check if threads are private.
    """
    return {"messages": await history_service.get_thread_messages(thread_id)}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
    thread_id: str,
    user_id: str = Query("default_user", description="用户 ID"),
    history_service: ThreadHistoryService = Depends(get_history_service)
):
    """Return the history service's summary of ``thread_id`` unchanged.

    NOTE(review): ``user_id`` is accepted but unused (no ownership check).
    """
    return await history_service.get_thread_summary(thread_id)
# ========== 流式对话接口 ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
    request: ChatRequest,
    agent_service: AIAgentService = Depends(get_agent_service)
):
    """Stream the agent's answer as Server-Sent Events (SSE).

    Every chunk from ``AIAgentService.process_message_stream`` becomes one
    ``data: <json>`` event, and the stream always ends with ``data: [DONE]``.
    If the agent raises mid-stream, the error is logged and forwarded as a
    ``{"type": "error", ...}`` event before the terminator.

    Raises:
        HTTPException: 400 when ``message`` is empty.
    """
    if not request.message:
        raise HTTPException(status_code=400, detail="message required")
    thread_id = request.thread_id or str(uuid.uuid4())

    def as_sse(payload) -> str:
        # One SSE "data" event; ensure_ascii=False keeps non-ASCII text readable.
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator():
        try:
            async for chunk in agent_service.process_message_stream(
                request.message, thread_id, request.model, request.user_id
            ):
                yield as_sse(chunk)
        except Exception as e:
            error(f"流式响应异常: {e}")
            yield as_sse({'type': 'error', 'message': str(e)})
        yield "data: [DONE]\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",  # disable Nginx response buffering
    }
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers=sse_headers,
    )
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Chat over a WebSocket: one JSON request frame in, one JSON reply frame out.

    Each incoming frame must be a JSON object with ``message`` and optionally
    ``thread_id``, ``model`` and ``user_id``. The reply frame contains
    ``reply`` (the reply text, as in POST /chat), ``thread_id`` and
    ``model_used``. Problems are answered with ``{"error": ...}`` frames and the
    connection stays open until the client disconnects.

    The agent service is read from ``websocket.app.state`` instead of through
    ``Depends(get_agent_service)``. That dependency declares a ``Request``
    parameter, which FastAPI injects only for HTTP requests, so the injected
    version failed on every WebSocket connection.
    """
    agent_service: AIAgentService = websocket.app.state.agent_service
    await websocket.accept()
    try:
        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                await websocket.send_json({"error": "invalid json"})
                continue
            if not isinstance(data, dict):
                await websocket.send_json({"error": "payload must be a JSON object"})
                continue

            message = data.get("message")
            if not message:
                await websocket.send_json({"error": "missing message"})
                continue
            # ``or`` rather than a .get() default, so an explicit null also gets a
            # fresh id — same rule as the HTTP endpoints.
            thread_id = data.get("thread_id") or str(uuid.uuid4())
            model = data.get("model", "zhipu")
            user_id = data.get("user_id", "default_user")

            try:
                result = await agent_service.process_message(message, thread_id, model, user_id)
            except Exception as exc:
                # Report the failure to the client instead of dropping the socket.
                error(f"WebSocket 消息处理异常: {exc}")
                await websocket.send_json({"error": str(exc), "thread_id": thread_id})
                continue

            graphs = agent_service.graphs
            actual_model = model if model in graphs else next(iter(graphs), model)
            await websocket.send_json({
                "reply": result["reply"],
                "thread_id": thread_id,
                "model_used": actual_model,
            })
    except WebSocketDisconnect:
        pass
# ========== 审核相关端点 ==========
def review_to_response(review: HumanReview) -> ReviewResponse:
    """Map a HumanReview onto the ReviewResponse API schema.

    The status enum is exposed by member name and timestamps as ISO 8601
    strings; ``reviewed_at`` becomes None when the review has none.
    """
    reviewed_at = review.reviewed_at
    return ReviewResponse(
        review_id=review.review_id,
        thread_id=review.thread_id,
        user_id=review.user_id,
        content_to_review=review.content_to_review,
        review_comment=review.review_comment,
        modified_content=review.modified_content,
        status=review.status.name,
        created_at=review.created_at.isoformat(),
        reviewed_at=None if not reviewed_at else reviewed_at.isoformat(),
    )
@app.get("/reviews/pending", response_model=list[ReviewResponse])
async def get_pending_reviews(
    limit: int = Query(100, ge=1, le=500, description="返回数量限制"),
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Return up to ``limit`` reviews the manager reports as pending."""
    pending = review_manager.get_pending_reviews(limit)
    return list(map(review_to_response, pending))
@app.get("/reviews/{review_id}", response_model=ReviewResponse)
async def get_review(
    review_id: str,
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Return one review by id.

    Raises:
        HTTPException: 404 when the manager has no such review.
    """
    found = review_manager.get_review(review_id)
    if found:
        return review_to_response(found)
    raise HTTPException(status_code=404, detail="Review not found")
@app.get("/reviews/thread/{thread_id}", response_model=list[ReviewResponse])
async def get_thread_reviews(
    thread_id: str,
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Return every review attached to ``thread_id`` (empty if unsupported).

    ReviewManager does not expose a per-thread query, so this reaches into its
    ``store`` directly and falls back to an empty list when the store has no
    ``get_by_thread``. TODO: add that query to ReviewManager itself.
    """
    store = review_manager.store
    if hasattr(store, 'get_by_thread'):
        thread_reviews = store.get_by_thread(thread_id)
    else:
        thread_reviews = []
    return [review_to_response(item) for item in thread_reviews]
@app.post("/reviews/{review_id}/approve")
async def approve_review(
    review_id: str,
    request: ReviewActionRequest,
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Approve the review named in the URL path.

    Raises:
        HTTPException: 404 whenever ``ReviewManager.approve`` returns a falsy
            value. NOTE(review): that may also mean "not pending" rather than
            "not found" — confirm in ReviewManager.
    """
    approved = review_manager.approve(
        review_id=review_id,
        reviewer=request.reviewer,
        comment=request.comment,
    )
    if approved:
        return {"status": "success", "review_id": review_id}
    raise HTTPException(status_code=404, detail="Review not found")
@app.post("/reviews/{review_id}/reject")
async def reject_review(
    review_id: str,
    request: ReviewActionRequest,
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Reject the review named in the URL path.

    Raises:
        HTTPException: 404 whenever ``ReviewManager.reject`` returns a falsy
            value (possibly also for non-pending reviews — confirm).
    """
    rejected = review_manager.reject(
        review_id=review_id,
        reviewer=request.reviewer,
        comment=request.comment,
    )
    if rejected:
        return {"status": "success", "review_id": review_id}
    raise HTTPException(status_code=404, detail="Review not found")
@app.post("/reviews/{review_id}/modify")
async def modify_review(
    review_id: str,
    request: ReviewActionRequest,
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Resolve the review named in the URL path with replacement content.

    Raises:
        HTTPException: 400 when ``modified_content`` is empty; 404 whenever
            ``ReviewManager.modify`` returns a falsy value.
    """
    if not request.modified_content:
        raise HTTPException(status_code=400, detail="modified_content required")
    modified = review_manager.modify(
        review_id=review_id,
        reviewer=request.reviewer,
        modified_content=request.modified_content,
        comment=request.comment,
    )
    if modified:
        return {"status": "success", "review_id": review_id}
    raise HTTPException(status_code=404, detail="Review not found")
@app.post("/reviews/request")
async def request_review(
    thread_id: str,
    user_id: str,
    content: str,
    review_manager: ReviewManager = Depends(get_review_manager)
):
    """Open a review request by hand (testing aid).

    ``thread_id``, ``user_id`` and ``content`` are plain scalar parameters, so
    FastAPI reads them from the query string rather than a JSON body.
    """
    new_review_id = review_manager.request_review(thread_id, user_id, content)
    return {"status": "success", "review_id": new_review_id}
if __name__ == "__main__":
    import uvicorn

    # Port comes from config.BACKEND_PORT (per the original note: an environment
    # variable or the default 8079, chosen to avoid llama.cpp's port 8081).
    port = int(BACKEND_PORT)
    # This guard runs before the /subgraph/* routes further down this file are
    # defined, and uvicorn.run() blocks until shutdown. Serving the half-built
    # ``app`` object here would therefore leave those routes unregistered.
    # Passing an import string makes uvicorn import this module under its
    # package name (known via __spec__ when started with ``python -m``), which
    # executes the whole file before serving.
    target = f"{__spec__.name}:app" if __spec__ is not None else app
    uvicorn.run(target, host="0.0.0.0", port=port)
# ==================== 子图专用 API 端点 ====================
# 简化版本,直接调用各个子图,无需完整 agent_service
# 注意:这些是独立测试用的简化端点,方便前端直接调用
@app.get("/subgraph/dictionary/{action}")
async def dictionary_subgraph_api(
    action: str,
    query: str = "",
    user_id: str = "default"
):
    """Run one node of the dictionary subgraph directly (simplified test API).

    ``action`` selects the node:
    - ``query``: look up ``query`` as a word.
    - ``translate``: translate ``query``.
    - ``daily``: word of the day.
    - ``extract``: extract terms from ``query``.
    Any other value (e.g. ``auto``) lets ``parse_intent`` pick the action from
    ``query``.

    Returns ``{"success": True, "action", "result", "raw_data"}`` on success and
    ``{"success": False, "error"}`` when a node raises.

    NOTE(review): the node functions are called synchronously inside this async
    endpoint; if they do network I/O they block the event loop.
    """
    # Relative imports, matching the rest of this module. The previous absolute
    # ``backend.app...`` path only resolved when the directory containing
    # ``backend/`` was on sys.path.
    from .agent_subgraphs.dictionary import (
        DictionaryState,
        DictionaryAction,
        parse_intent,
        format_result
    )
    from .agent_subgraphs.dictionary.nodes import (
        query_word, translate_text, extract_terms, get_daily_word
    )

    state = DictionaryState(user_query=query, user_id=user_id)
    try:
        if action == "query":
            state.action = DictionaryAction.QUERY
            state.action_params = {"word": query}
            state = query_word(state)
        elif action == "translate":
            state.action = DictionaryAction.TRANSLATE
            state.source_text = query
            state = translate_text(state)
        elif action == "daily":
            state.action = DictionaryAction.DAILY_WORD
            state = get_daily_word(state)
        elif action == "extract":
            state.action = DictionaryAction.EXTRACT
            state.action_params = {"text": query}
            state = extract_terms(state)
        else:
            # Let the subgraph infer the action from the free-text query, then
            # run the matching node (none if the intent is unrecognised).
            state = parse_intent(state)
            if state.action == DictionaryAction.QUERY:
                state = query_word(state)
            elif state.action == DictionaryAction.TRANSLATE:
                state = translate_text(state)
            elif state.action == DictionaryAction.DAILY_WORD:
                state = get_daily_word(state)
            elif state.action == DictionaryAction.EXTRACT:
                state = extract_terms(state)

        state = format_result(state)
        return {
            "success": True,
            "action": str(state.action),
            "result": state.final_result,
            "raw_data": {
                "word_entry": vars(state.word_entry) if state.word_entry else None,
                "translated_text": state.translated_text,
                "extracted_terms": [vars(t) for t in state.extracted_terms],
                "daily_word": vars(state.daily_word) if state.daily_word else None
            }
        }
    except Exception as e:
        # Keep the JSON error contract for the test panel, but record the failure
        # instead of swallowing it silently.
        error(f"词典子图调用失败 (action={action}): {e}")
        return {"success": False, "error": str(e)}
@app.get("/subgraph/news/{action}")
async def news_subgraph_api(
    action: str,
    query: str = "",
    user_id: str = "default"
):
    """Run one node of the news-analysis subgraph directly (simplified test API).

    ``action`` selects the node:
    - ``query``: query news.
    - ``analyze``: analyse the URL given in ``query``.
    - ``keywords``: extract keywords.
    - ``report``: generate a report.
    Any other value (e.g. ``auto``) lets ``parse_intent`` choose from ``query``.

    Returns ``{"success": True, "action", "result", "raw_data"}`` on success and
    ``{"success": False, "error"}`` when a node raises.

    NOTE(review): the node functions are called synchronously inside this async
    endpoint; if they do network I/O they block the event loop.
    """
    # Relative imports, matching the rest of this module. The previous absolute
    # ``backend.app...`` path only resolved when the directory containing
    # ``backend/`` was on sys.path.
    from .agent_subgraphs.news_analysis import (
        NewsAnalysisState,
        NewsAction,
        parse_intent,
        format_result
    )
    from .agent_subgraphs.news_analysis.nodes import (
        query_news, analyze_url, extract_keywords, generate_report
    )

    state = NewsAnalysisState(user_query=query, user_id=user_id)
    try:
        if action == "query":
            state.action = NewsAction.QUERY_NEWS
            state = query_news(state)
        elif action == "analyze":
            state.action = NewsAction.ANALYZE_URL
            state.custom_urls = [query]
            state = analyze_url(state)
        elif action == "keywords":
            state.action = NewsAction.EXTRACT_KEYWORDS
            state = extract_keywords(state)
        elif action == "report":
            state.action = NewsAction.GENERATE_REPORT
            state = generate_report(state)
        else:
            # Let the subgraph infer the action from the free-text query, then
            # run the matching node (none if the intent is unrecognised).
            state = parse_intent(state)
            if state.action == NewsAction.QUERY_NEWS:
                state = query_news(state)
            elif state.action == NewsAction.ANALYZE_URL:
                # NOTE(review): this may overwrite URLs parse_intent extracted
                # with the raw query text — confirm intended.
                state.custom_urls = [query]
                state = analyze_url(state)
            elif state.action == NewsAction.EXTRACT_KEYWORDS:
                state = extract_keywords(state)
            elif state.action == NewsAction.GENERATE_REPORT:
                state = generate_report(state)

        state = format_result(state)
        return {
            "success": True,
            "action": str(state.action),
            "result": state.final_result,
            "raw_data": {
                "news_items": [vars(item) for item in state.news_items],
                "extracted_keywords": state.extracted_keywords,
                "report_content": state.report_content
            }
        }
    except Exception as e:
        # Keep the JSON error contract for the test panel, but record the failure
        # instead of swallowing it silently.
        error(f"资讯子图调用失败 (action={action}): {e}")
        return {"success": False, "error": str(e)}
@app.get("/subgraph/contact/{action}")
async def contact_subgraph_api(
    action: str,
    query: str = "",
    user_id: str = "default"
):
    """Run one node of the contact subgraph directly (simplified test API).

    ``action`` selects the node:
    - ``list``: list contacts.
    - ``add``: add a contact.
    - ``emails``: list emails.
    - ``draft``: generate an email draft.
    - ``sniff``: sniff contacts.
    Any other value (e.g. ``auto``) lets ``parse_intent`` choose from ``query``.

    Returns ``{"success": True, "action", "result", "raw_data"}`` on success and
    ``{"success": False, "error"}`` when a node raises.

    NOTE(review): ``add`` presumably mutates stored contacts, so this GET route
    may not be side-effect free. The node functions also run synchronously
    inside this async endpoint.
    """
    # Relative imports, matching the rest of this module. The previous absolute
    # ``backend.app...`` path only resolved when the directory containing
    # ``backend/`` was on sys.path.
    from .agent_subgraphs.contact import (
        ContactState,
        ContactAction,
        parse_intent,
        format_result
    )
    from .agent_subgraphs.contact.nodes import (
        list_contacts, add_contact, list_emails, generate_email_draft, sniff_contacts
    )

    state = ContactState(user_query=query, user_id=user_id)
    try:
        if action == "list":
            state.action = ContactAction.CONTACT_LIST
            state = list_contacts(state)
        elif action == "add":
            state.action = ContactAction.CONTACT_ADD
            state = add_contact(state)
        elif action == "emails":
            state.action = ContactAction.EMAIL_LIST
            state = list_emails(state)
        elif action == "draft":
            state.action = ContactAction.EMAIL_SEND
            state = generate_email_draft(state)
        elif action == "sniff":
            state.action = ContactAction.SNIFF_CONTACTS
            state = sniff_contacts(state)
        else:
            # Let the subgraph infer the action from the free-text query, then
            # run the matching node (none if the intent is unrecognised).
            state = parse_intent(state)
            if state.action == ContactAction.CONTACT_LIST:
                state = list_contacts(state)
            elif state.action == ContactAction.CONTACT_ADD:
                state = add_contact(state)
            elif state.action == ContactAction.EMAIL_LIST:
                state = list_emails(state)
            elif state.action == ContactAction.EMAIL_SEND:
                state = generate_email_draft(state)
            elif state.action == ContactAction.SNIFF_CONTACTS:
                state = sniff_contacts(state)

        state = format_result(state)
        return {
            "success": True,
            "action": str(state.action),
            "result": state.final_result,
            "raw_data": {
                "contacts": [vars(c) for c in state.contacts],
                "emails": [vars(e) for e in state.emails],
                "current_contact": vars(state.current_contact) if state.current_contact else None,
                "draft": {
                    "subject": state.draft_subject,
                    "recipient": state.draft_recipient,
                    "body": state.draft_body
                },
                "sniffed": [vars(c) for c in state.sniffed_contacts]
            }
        }
    except Exception as e:
        # Keep the JSON error contract for the test panel, but record the failure
        # instead of swallowing it silently.
        error(f"通讯录子图调用失败 (action={action}): {e}")
        return {"success": False, "error": str(e)}
@app.get("/subgraph/help")
async def subgraph_help_api():
    """Describe the simplified subgraph endpoints and the actions each accepts.

    ``auto`` is listed for discoverability: any action name an endpoint does
    not recognise is routed to intent parsing.
    """
    def entry(actions, endpoint):
        # One help record; key order matches the published response.
        return {"actions": actions, "endpoint": endpoint}

    return {
        "dictionary": entry(
            ["query", "translate", "daily", "extract", "auto"],
            "/subgraph/dictionary/{action}",
        ),
        "news": entry(
            ["query", "analyze", "keywords", "report", "auto"],
            "/subgraph/news/{action}",
        ),
        "contact": entry(
            ["list", "add", "emails", "draft", "sniff", "auto"],
            "/subgraph/contact/{action}",
        ),
    }