修改引用逻辑,修改长期记忆bug

This commit is contained in:
2026-04-20 15:55:58 +08:00
parent 4e981e9dcf
commit 3143e0e4e6
39 changed files with 444 additions and 246 deletions

View File

@@ -16,7 +16,7 @@ from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.history import ThreadHistoryService
from app.logger import debug, info, warning, error
from app.logger import info, error
# 加载 .env 文件
load_dotenv()
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
    # 5. 关闭时自动清理数据库连接(async with 负责)
info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
# CORS 中间件(允许前端跨域)
@@ -65,14 +63,12 @@ app.add_middleware(
allow_headers=["*"],
)
# ========== 健康检查端点 ==========
@app.get("/health")
async def health_check():
    """Health-check endpoint polled by Docker and CI/CD monitoring."""
    status_payload = {"status": "ok", "service": "ai-agent-backend"}
    return status_payload
# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
message: str
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
reply: str
thread_id: str
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
total_tokens: int = 0
elapsed_time: float = 0.0
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
    """Dependency provider: return the global AIAgentService kept on app.state."""
    agent = request.app.state.agent_service
    return agent
def get_history_service(request: Request) -> ThreadHistoryService:
    """Dependency provider: return the global ThreadHistoryService kept on app.state."""
    history = request.app.state.history_service
    return history
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
@@ -135,7 +127,6 @@ async def chat_endpoint(
elapsed_time=elapsed_time
)
# ========== 历史查询接口 ==========
@app.get("/threads")
async def list_threads(
@@ -147,7 +138,6 @@ async def list_threads(
threads = await history_service.get_user_threads(user_id, limit)
return {"threads": threads}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
@@ -158,7 +148,6 @@ async def get_thread_messages(
messages = await history_service.get_thread_messages(thread_id)
return {"messages": messages}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
thread_id: str,
@@ -169,7 +158,6 @@ async def get_thread_summary(
summary = await history_service.get_thread_summary(thread_id)
return summary
# ========== 流式对话接口 ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
}
)
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
@@ -228,7 +215,6 @@ async def websocket_endpoint(
except WebSocketDisconnect:
pass
if __name__ == "__main__":
import uvicorn
    # 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突)