修改引用逻辑,修改长期记忆bug
This commit is contained in:
@@ -16,7 +16,7 @@ from pydantic import BaseModel
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from app.agent import AIAgentService
|
||||
from app.history import ThreadHistoryService
|
||||
from app.logger import debug, info, warning, error
|
||||
from app.logger import info, error
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
|
||||
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理:创建并注入全局服务"""
|
||||
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
|
||||
# 5. 关闭时自动清理数据库连接(async with 负责)
|
||||
info("🛑 应用关闭,数据库连接池已释放")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# CORS 中间件(允许前端跨域)
|
||||
@@ -65,14 +63,12 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ========== 健康检查端点 ==========
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点,用于 Docker 和 CI/CD 监控"""
|
||||
return {"status": "ok", "service": "ai-agent-backend"}
|
||||
|
||||
|
||||
# ========== Pydantic 模型 ==========
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
|
||||
model: str = "zhipu"
|
||||
user_id: str = "default_user"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str
|
||||
thread_id: str
|
||||
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
|
||||
total_tokens: int = 0
|
||||
elapsed_time: float = 0.0
|
||||
|
||||
|
||||
# ========== 依赖注入函数 ==========
|
||||
def get_agent_service(request: Request) -> AIAgentService:
|
||||
"""从 app.state 中获取全局 AIAgentService 实例"""
|
||||
return request.app.state.agent_service
|
||||
|
||||
|
||||
def get_history_service(request: Request) -> ThreadHistoryService:
|
||||
"""从 app.state 中获取全局 ThreadHistoryService 实例"""
|
||||
return request.app.state.history_service
|
||||
|
||||
|
||||
# ========== HTTP 端点 ==========
|
||||
@app.post("/chat", response_model=ChatResponse)
|
||||
async def chat_endpoint(
|
||||
@@ -135,7 +127,6 @@ async def chat_endpoint(
|
||||
elapsed_time=elapsed_time
|
||||
)
|
||||
|
||||
|
||||
# ========== 历史查询接口 ==========
|
||||
@app.get("/threads")
|
||||
async def list_threads(
|
||||
@@ -147,7 +138,6 @@ async def list_threads(
|
||||
threads = await history_service.get_user_threads(user_id, limit)
|
||||
return {"threads": threads}
|
||||
|
||||
|
||||
@app.get("/thread/{thread_id}/messages")
|
||||
async def get_thread_messages(
|
||||
thread_id: str,
|
||||
@@ -158,7 +148,6 @@ async def get_thread_messages(
|
||||
messages = await history_service.get_thread_messages(thread_id)
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@app.get("/thread/{thread_id}/summary")
|
||||
async def get_thread_summary(
|
||||
thread_id: str,
|
||||
@@ -169,7 +158,6 @@ async def get_thread_summary(
|
||||
summary = await history_service.get_thread_summary(thread_id)
|
||||
return summary
|
||||
|
||||
|
||||
# ========== 流式对话接口 ==========
|
||||
@app.post("/chat/stream")
|
||||
async def chat_stream_endpoint(
|
||||
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ========== WebSocket 端点(可选) ==========
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
@@ -228,7 +215,6 @@ async def websocket_endpoint(
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突)
|
||||
|
||||
Reference in New Issue
Block a user