diff --git a/.gitignore b/.gitignore
index 9e8ca1d..4120caa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,16 +7,14 @@
 /*
 # 2. 放行需要的文件夹及其内容
-!app/
-!app/**
+!backend/
+!backend/**
 !frontend/
 !frontend/**
 !scripts/
 !scripts/**
 !rag_indexer/
 !rag_indexer/**
-!rag_core/
-!rag_core/**
 !docker/
 !docker/**
 !.gitea/
diff --git a/backend/app/__init__.py b/backend/app/__init__.py
new file mode 100644
index 0000000..c32b8e6
--- /dev/null
+++ b/backend/app/__init__.py
@@ -0,0 +1,8 @@
+"""
+AI Agent 应用模块
+"""
+
+from .agent import AIAgentService
+from .graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
+
+__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
diff --git a/backend/app/agent/__init__.py b/backend/app/agent/__init__.py
new file mode 100644
index 0000000..a5e2eba
--- /dev/null
+++ b/backend/app/agent/__init__.py
@@ -0,0 +1,7 @@
+"""
+Agent 子模块
+"""
+
+from .service import AIAgentService
+
+__all__ = ["AIAgentService"]
diff --git a/backend/app/agent/history.py b/backend/app/agent/history.py
new file mode 100644
index 0000000..e619da9
--- /dev/null
+++ b/backend/app/agent/history.py
@@ -0,0 +1,185 @@
+"""
+历史对话查询模块
+利用 LangGraph 的 checkpointer 获取对话历史和摘要
+"""
+
+from typing import List, Dict, Any
+from ..logger import error  # 保持兼容,或者替换为 logger
+
+class ThreadHistoryService:
+    """线程历史查询服务"""
+
+    def __init__(self, checkpointer):
+        self.checkpointer = checkpointer
+
+    async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
+        """
+        获取指定用户的所有线程摘要信息
+
+        Args:
+            user_id: 用户 ID
+            limit: 返回数量限制
+
+        Returns:
+            线程列表,每个包含 thread_id, last_updated, summary, message_count
+        """
+        try:
+            # 查询 checkpoints 表获取用户的线程列表
+            async with self.checkpointer.conn.cursor() as cur:
+                # 在较新的 LangGraph 版本中,AsyncPostgresSaver 创建的 checkpoints 表
+                # 没有 created_at 列,而是使用 checkpoint_id 作为时间排序依据。
+                # 我们可以直接按 thread_id 去重,并用 checkpoint_id 排序。
+                # 另外,用户的 metadata 存储在 metadata JSONB 列中。
+                query = """
+                    SELECT
+                        thread_id,
+                        MAX(checkpoint_id) as last_updated
+                    FROM checkpoints
+                    WHERE metadata->>'user_id' = %s
+                    GROUP BY thread_id
+                    ORDER BY last_updated DESC
+                    LIMIT %s
+                """
+                await cur.execute(query, (user_id, limit))
+                rows = await cur.fetchall()
+
+                threads = []
+                for row in rows:
+                    thread_id = row['thread_id']
+
+                    # 获取该线程的状态
+                    state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
+
+                    if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict):
+                        messages = state.checkpoint.get("channel_values", {}).get("messages", [])
+
+                        if messages:
+                            summary = self._extract_summary(messages)
+                            message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
+
+                            threads.append({
+                                "thread_id": thread_id,
+                                # checkpoint_id 是一个类似于 uuid 的字符串,其中可能包含时间戳信息,也可以直接作为唯一标识
+                                "last_updated": row['last_updated'] if row['last_updated'] else "",
+                                "summary": summary,
+                                "message_count": message_count
+                            })
+
+            return threads
+
+        except Exception as e:
+            error(f"获取用户线程列表失败 (user_id={user_id}): {e}")
+            return []
+
+    async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]:
+        """
+        获取指定线程的完整消息历史
+
+        Args:
+            thread_id: 线程 ID
+
+        Returns:
+            消息列表,格式 [{"role": "user/assistant", "content": "..."}]
+        """
+        try:
+            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
+
+            if state is None:
+                return []
+
+            messages = state.checkpoint.get("channel_values", {}).get("messages", []) if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict) else []
+
+            if not messages:
+                return []
+
+            # 转换 LangChain 消息对象为字典
+            result = []
+            for msg in messages:
+                # 跳过 system 消息
+                if hasattr(msg, 'type') and msg.type == "system":
+                    continue
+
+                if hasattr(msg, 'type'):
+                    role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type
+                    result.append({
+                        "role": role,
+                        "content": msg.content
+                    })
+                elif isinstance(msg, dict):
+                    role = msg.get("role", msg.get("type", "unknown"))
+                    if role in ["human", "user"]:
+                        role = "user"
+                    elif role in ["ai", "assistant"]:
+                        role = "assistant"
+                    result.append({
+                        "role": role,
+                        "content": msg.get("content", "")
+                    })
+
+            return result
+
+        except Exception as e:
+            error(f"获取线程消息历史失败: {e}")
+            return []
+
+    async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]:
+        """
+        获取线程摘要(用于历史列表展示)
+
+        Args:
+            thread_id: 线程 ID
+
+        Returns:
+            包含摘要信息的字典
+        """
+        try:
+            state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
+
+            if state is None or not state.values:
+                return {"thread_id": thread_id, "summary": "空对话", "message_count": 0}
+
+            messages = state.values.get("messages", [])
+            summary = self._extract_summary(messages)
+            message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]])
+
+            # 获取最后更新时间
+            last_updated = ""
+            if state.metadata and "created_at" in state.metadata:
+                last_updated = state.metadata["created_at"].isoformat()
+
+            return {
+                "thread_id": thread_id,
+                "summary": summary,
+                "message_count": message_count,
+                "last_updated": last_updated
+            }
+
+        except Exception as e:
+            error(f"获取线程摘要失败: {e}")
+            return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0}
+
+    def _extract_summary(self, messages: List) -> str:
+        """
+        从消息列表中提取摘要
+
+        策略:
+        1. 如果有 summarize 节点生成的 summary,优先使用
+        2. 否则使用第一条用户消息的前 50 字
+        """
+        # 查找是否有 summary 字段
+        for msg in messages:
+            if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'):
+                return msg.additional_kwargs['summary']
+            elif isinstance(msg, dict) and msg.get('summary'):
+                return msg['summary']
+
+        # 使用第一条用户消息作为摘要
+        for msg in messages:
+            if hasattr(msg, 'type') and msg.type == "human":
+                content = msg.content
+                return content[:50] + "..." if len(content) > 50 else content
+            elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]:
+                content = msg.get("content", "")
+                return content[:50] + "..." if len(content) > 50 else content
+
+        return "空对话"
\ No newline at end of file
diff --git a/backend/app/agent/llm_factory.py b/backend/app/agent/llm_factory.py
new file mode 100644
index 0000000..e0ca134
--- /dev/null
+++ b/backend/app/agent/llm_factory.py
@@ -0,0 +1,57 @@
+# app/llm_factory.py
+import os
+from ..config import ZHIPUAI_API_KEY, DEEPSEEK_API_KEY, VLLM_BASE_URL, LLAMACPP_API_KEY
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_openai import ChatOpenAI
+from pydantic import SecretStr
+
+class LLMFactory:
+    @staticmethod
+    def create_zhipu():
+        api_key = ZHIPUAI_API_KEY
+        if not api_key:
+            raise ValueError("ZHIPUAI_API_KEY not set")
+        return ChatZhipuAI(
+            model="glm-4.7-flash",
+            api_key=api_key,
+            temperature=0.1,
+            max_tokens=4096,
+            timeout=120.0,
+            max_retries=3,
+            streaming=True,
+        )
+
+    @staticmethod
+    def create_deepseek():
+        api_key = DEEPSEEK_API_KEY
+        if not api_key:
+            raise ValueError("DEEPSEEK_API_KEY not set")
+        return ChatOpenAI(
+            base_url="https://api.deepseek.com",
+            api_key=SecretStr(api_key),
+            model="deepseek-reasoner",
+            temperature=0.1,
+            max_tokens=4096,
+            timeout=60.0,
+            max_retries=2,
+            streaming=True,
+        )
+
+    @staticmethod
+    def create_local():
+        base_url = VLLM_BASE_URL
+        return ChatOpenAI(
+            base_url=base_url,
+            api_key=SecretStr(LLAMACPP_API_KEY),
+            model="gemma-4-E4B-it",
+            timeout=60.0,
+            max_retries=2,
+            streaming=True,
+        )
+
+    # 模型创建器映射
+    CREATORS = {
+        "local": create_local,
+        "deepseek": create_deepseek,
+        "zhipu": create_zhipu,
+    }
\ No newline at end of file
diff --git a/backend/app/agent/prompts.py b/backend/app/agent/prompts.py
new file mode 100644
index 0000000..8b05050
--- /dev/null
+++ b/backend/app/agent/prompts.py
@@ -0,0 +1,37 @@
+# app/prompts.py
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+
+def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
+    """
+    创建系统提示模板,可选择动态注入工具描述。
+    """
+    tools_section = ""
+    if tools:
+        tool_descs = []
+        for tool in tools:
+            # 提取工具名称和描述的第一行
+            name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
+            desc = (tool.description or "").split('\n')[0]
+            tool_descs.append(f"- {name}: {desc}")
+        tools_section = "\n".join(tool_descs)
+
+    system_template = (
+        "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
+        "【用户背景信息】\n"
+        "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n"
+        "{memory_context}\n"
+        "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
+        "【可用工具与使用规则】\n"
+        f"{tools_section}\n"
+        "工具调用时请直接返回所需参数,无需额外说明。\n\n"
+        "【回答要求(必须遵守)】\n"
+        "1. 回答必须简洁、直接。\n"
+        "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。\n"
+        "3. 优先利用已知用户信息进行个性化回复。\n"
+        "4. 若无信息可依,礼貌询问或提供通用帮助。"
+    )
+
+    return ChatPromptTemplate.from_messages([
+        ("system", system_template),
+        MessagesPlaceholder(variable_name="messages")
+    ])
\ No newline at end of file
diff --git a/backend/app/agent/rag_initializer.py b/backend/app/agent/rag_initializer.py
new file mode 100644
index 0000000..b637fc8
--- /dev/null
+++ b/backend/app/agent/rag_initializer.py
@@ -0,0 +1,23 @@
+# app/rag_initializer.py
+from ..rag.tools import create_rag_tool_sync
+from rag_core import create_parent_retriever
+from ..logger import info, warning
+
+async def init_rag_tool(local_llm_creator):
+    """初始化 RAG 工具,失败返回 None"""
+    try:
+        info("🔄 正在初始化 RAG 检索系统...")
+        retriever = create_parent_retriever(
+            collection_name="rag_documents",
+            search_k=5,
+        )
+        rewrite_llm = local_llm_creator()
+        rag_tool = create_rag_tool_sync(
+            retriever, rewrite_llm,
+            num_queries=3, rerank_top_n=5
+        )
+        info("✅ RAG 检索工具初始化成功")
+        return rag_tool
+    except Exception as e:
+        warning(f"⚠️ RAG 检索工具初始化失败: {e}")
+        return None
\ No newline at end of file
diff --git a/backend/app/agent/service.py b/backend/app/agent/service.py
new file mode 100644
index 0000000..2146daf
--- /dev/null
+++ b/backend/app/agent/service.py
@@ -0,0 +1,154 @@
+"""
+AI Agent 服务类 - 支持多模型动态切换
+接收外部传入的 checkpointer,不负责管理连接生命周期
+"""
+
+import json
+
+# 本地模块
+from ..graph.graph_builder import GraphBuilder, GraphContext
+from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
+from .llm_factory import LLMFactory
+from .rag_initializer import init_rag_tool
+from ..logger import info, warning
+
+class AIAgentService:
+    def __init__(self, checkpointer):
+        self.checkpointer = checkpointer
+        self.graphs = {}
+        self.tools = AVAILABLE_TOOLS.copy()
+        self.tools_by_name = TOOLS_BY_NAME.copy()
+
+    async def initialize(self):
+        # 1. 初始化 RAG 工具(如果需要)
+        rag_tool = await init_rag_tool(LLMFactory.create_local)
+        if rag_tool:
+            self.tools.append(rag_tool)
+            self.tools_by_name[rag_tool.name] = rag_tool
+
+        # 2. 构建各模型的 Graph
+        for name, creator in LLMFactory.CREATORS.items():
+            try:
+                info(f"🔄 初始化模型 '{name}'...")
+                llm = creator()
+                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
+                graph = builder.compile(checkpointer=self.checkpointer)
+                self.graphs[name] = graph
+                info(f"✅ 模型 '{name}' 初始化成功")
+            except Exception as e:
+                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
+
+        if not self.graphs:
+            raise RuntimeError("没有可用的模型")
+        return self
+
+    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
+        """处理用户消息,返回包含回复、token统计和耗时的字典"""
+        if model not in self.graphs:
+            # 回退到第一个可用模型
+            available = list(self.graphs.keys())
+            if not available:
+                raise RuntimeError("没有可用的模型")
+            warning(f"模型 '{model}' 不可用,已回退到 '{available[0]}'")
+            model = available[0]
+
+        graph = self.graphs[model]
+        config = {
+            "configurable": {"thread_id": thread_id},
+            "metadata": {"user_id": user_id}
+        }
+        input_state = {"messages": [{"role": "user", "content": message}]}
+        context = GraphContext(user_id=user_id)
+
+        result = await graph.ainvoke(input_state, config=config, context=context)
+
+        reply = result["messages"][-1].content
+        token_usage = result.get("last_token_usage", {})
+        elapsed_time = result.get("last_elapsed_time", 0.0)
+
+        return {
+            "reply": reply,
+            "token_usage": token_usage,
+            "elapsed_time": elapsed_time
+        }
+
+    def _serialize_value(self, value):
+        """递归将 LangChain 对象转换为可 JSON 序列化的格式"""
+        if hasattr(value, 'content'):
+            msg_type = getattr(value, 'type', 'message')
+            return {
+                "role": msg_type,
+                "content": getattr(value, 'content', ''),
+                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
+                "tool_calls": getattr(value, 'tool_calls', [])
+            }
+        elif isinstance(value, dict):
+            return {k: self._serialize_value(v) for k, v in value.items()}
+        elif isinstance(value, (list, tuple)):
+            return [self._serialize_value(item) for item in value]
+        else:
+            try:
+                json.dumps(value)
+                return value
+            except (TypeError, ValueError):
+                return str(value)
+
+    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
+        """流式处理消息,返回异步生成器"""
+        graph = self.graphs.get(model_name)
+        if not graph:
+            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
+
+        config = {
+            "configurable": {"thread_id": thread_id},
+            "metadata": {"user_id": user_id}
+        }
+        input_state = {"messages": [{"role": "user", "content": message}]}
+        context = GraphContext(user_id=user_id)
+
+        async for chunk in graph.astream(
+            input_state,
+            config=config,
+            context=context,
+            stream_mode=["messages", "updates", "custom"],
+            version="v2",
+            subgraphs=True
+        ):
+            chunk_type = chunk["type"]
+            processed_event = {}
+
+            if chunk_type == "messages":
+                message_chunk, metadata = chunk["data"]
+                node_name = metadata.get("langgraph_node", "unknown")
+                token_content = getattr(message_chunk, 'content', str(message_chunk))
+                reasoning_token = ""
+                if hasattr(message_chunk, 'additional_kwargs'):
+                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
+
+                processed_event = {
+                    "type": "llm_token",
+                    "node": node_name,
+                    "token": token_content,
+                    "reasoning_token": reasoning_token,
+                    "metadata": metadata
+                }
+            elif chunk_type == "updates":
+                updates_data = chunk["data"]
+                serialized_data = self._serialize_value(updates_data)
+                processed_event = {
+                    "type": "state_update",
+                    "data": serialized_data
+                }
+                if "messages" in serialized_data:
+                    processed_event["messages"] = serialized_data["messages"]
+            elif chunk_type == "custom":
+                serialized_data = self._serialize_value(chunk["data"])
+                processed_event = {
+                    "type": "custom",
+                    "data": serialized_data
+                }
+            else:
+                continue
+
+            if processed_event:
+                yield processed_event
\ No newline at end of file
diff --git a/backend/app/backend.py b/backend/app/backend.py
new file mode 100644
index 0000000..e7544b7
--- /dev/null
+++ b/backend/app/backend.py
@@ -0,0 +1,212 @@
+"""
+FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆
+采用依赖注入模式,优雅管理资源生命周期
+"""
+
+import os
+from .config import DB_URI, BACKEND_PORT
+import uuid
+import json
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
+from .agent.service import AIAgentService
+from .agent.history import ThreadHistoryService
+from .logger import info, error
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """应用生命周期管理:创建并注入全局服务"""
+    # 1. 创建数据库连接池并初始化表(仅 checkpointer)
+    async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
+        await checkpointer.setup()
+
+        # 2. 构建 AI Agent 服务
+        agent_service = AIAgentService(checkpointer)
+        await agent_service.initialize()
+
+        # 3. 创建历史查询服务
+        history_service = ThreadHistoryService(checkpointer)
+
+        # 4. 将服务实例存入 app.state
+        app.state.agent_service = agent_service
+        app.state.history_service = history_service
+
+        # 应用运行中...
+        yield
+
+        # 5. 关闭时自动清理数据库连接(async with 负责)
+        info("🛑 应用关闭,数据库连接池已释放")
+
+app = FastAPI(lifespan=lifespan)
+
+# CORS 中间件(允许前端跨域)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# ========== 健康检查端点 ==========
+@app.get("/health")
+async def health_check():
+    """健康检查端点,用于 Docker 和 CI/CD 监控"""
+    return {"status": "ok", "service": "ai-agent-backend"}
+
+# ========== Pydantic 模型 ==========
+class ChatRequest(BaseModel):
+    message: str
+    thread_id: str | None = None
+    model: str = "zhipu"
+    user_id: str = "default_user"
+
+class ChatResponse(BaseModel):
+    reply: str
+    thread_id: str
+    model_used: str
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+    elapsed_time: float = 0.0
+
+# ========== 依赖注入函数 ==========
+def get_agent_service(request: Request) -> AIAgentService:
+    """从 app.state 中获取全局 AIAgentService 实例"""
+    return request.app.state.agent_service
+
+def get_history_service(request: Request) -> ThreadHistoryService:
+    """从 app.state 中获取全局 ThreadHistoryService 实例"""
+    return request.app.state.history_service
+
+# ========== HTTP 端点 ==========
+@app.post("/chat", response_model=ChatResponse)
+async def chat_endpoint(
+    request: ChatRequest,
+    agent_service: AIAgentService = Depends(get_agent_service)
+):
+    """同步对话接口,支持模型选择"""
+    if not request.message:
+        raise HTTPException(status_code=400, detail="message required")
+
+    thread_id = request.thread_id or str(uuid.uuid4())
+    result = await agent_service.process_message(
+        request.message, thread_id, request.model, request.user_id
+    )
+
+    # 提取 token 统计信息
+    token_usage = result.get("token_usage", {})
+    input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0))
+    output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0))
+    elapsed_time = result.get("elapsed_time", 0.0)
+
+    actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
+
+    return ChatResponse(
+        reply=result["reply"],
+        thread_id=thread_id,
+        model_used=actual_model,
+        input_tokens=input_tokens,
+        output_tokens=output_tokens,
+        total_tokens=input_tokens + output_tokens,
+        elapsed_time=elapsed_time
+    )
+
+# ========== 历史查询接口 ==========
+@app.get("/threads")
+async def list_threads(
+    user_id: str = Query("default_user", description="用户 ID"),
+    limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
+    history_service: ThreadHistoryService = Depends(get_history_service)
+):
+    """获取当前用户的对话历史列表"""
+    threads = await history_service.get_user_threads(user_id, limit)
+    return {"threads": threads}
+
+@app.get("/thread/{thread_id}/messages")
+async def get_thread_messages(
+    thread_id: str,
+    user_id: str = Query("default_user", description="用户 ID"),
+    history_service: ThreadHistoryService = Depends(get_history_service)
+):
+    """获取指定线程的完整消息历史"""
+    messages = await history_service.get_thread_messages(thread_id)
+    return {"messages": messages}
+
+@app.get("/thread/{thread_id}/summary")
+async def get_thread_summary(
+    thread_id: str,
+    user_id: str = Query("default_user", description="用户 ID"),
+    history_service: ThreadHistoryService = Depends(get_history_service)
+):
+    """获取指定线程的摘要信息"""
+    summary = await history_service.get_thread_summary(thread_id)
+    return summary
+
+# ========== 流式对话接口 ==========
+@app.post("/chat/stream")
+async def chat_stream_endpoint(
+    request: ChatRequest,
+    agent_service: AIAgentService = Depends(get_agent_service)
+):
+    """流式对话接口(SSE)"""
+    if not request.message:
+        raise HTTPException(status_code=400, detail="message required")
+
+    thread_id = request.thread_id or str(uuid.uuid4())
+
+    async def event_generator():
+        try:
+            async for chunk in agent_service.process_message_stream(
+                request.message, thread_id, request.model, request.user_id
+            ):
+                yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+        except Exception as e:
+            error(f"流式响应异常: {e}")
+            yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
+    return StreamingResponse(
+        event_generator(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "X-Accel-Buffering": "no",  # 禁用 Nginx 缓冲
+        }
+    )
+
+# ========== WebSocket 端点(可选) ==========
+@app.websocket("/ws")
+async def websocket_endpoint(
+    websocket: WebSocket,
+    agent_service: AIAgentService = Depends(get_agent_service)
+):
+    await websocket.accept()
+    try:
+        while True:
+            data = await websocket.receive_json()
+            message = data.get("message")
+            thread_id = data.get("thread_id", str(uuid.uuid4()))
+            model = data.get("model", "zhipu")
+            user_id = data.get("user_id", "default_user")
+            if not message:
+                await websocket.send_json({"error": "missing message"})
+                continue
+            result = await agent_service.process_message(message, thread_id, model, user_id)
+            actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys()))
+            await websocket.send_json({"reply": result["reply"], "thread_id": thread_id, "model_used": actual_model})
+    except WebSocketDisconnect:
+        pass
+
+if __name__ == "__main__":
+    import uvicorn
+    # 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突)
+    port = int(BACKEND_PORT)
+    uvicorn.run(app, host="0.0.0.0", port=port)
diff --git a/backend/app/config.py b/backend/app/config.py
new file mode 100644
index 0000000..5b3ea11
--- /dev/null
+++ b/backend/app/config.py
@@ -0,0 +1,50 @@
+"""
+环境变量集中管理模块
+所有配置项统一定义,避免散落在各个文件中
+"""
+
+import os
+
+
+# ========== Graph 执行追踪配置 ==========
+# 是否启用 Graph 流转追踪(通过环境变量控制)
+ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
+
+# ========== 记忆提取配置 ==========
+# 记忆提取间隔:每 N 轮对话生成一次摘要
+MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
+
+# ========== Mem0 记忆层配置 ==========
+# Qdrant 向量数据库地址
+QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
+QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
+
+# ========== llm 配置 ==========
+# LLM 模型配置
+VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
+LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
+
+# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
+LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
+LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
+
+# ========== 后端服务配置 ==========
+# 数据库连接字符串
+DB_URI = os.getenv(
+    "DB_URI",
+    "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
+)
+# 后端服务端口
+BACKEND_PORT = int(os.getenv("BACKEND_PORT", "8079"))
+
+# ========== 日志配置 ==========
+LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
+DEBUG = os.getenv("DEBUG", "false").lower() == "true"
+
+# ========== Reranker 服务配置 ==========
+LLAMACPP_RERANKER_URL = os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083")
+
+# ========== 第三方 API 密钥 ==========
+ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY", "")
+DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
diff --git a/backend/app/graph/__init__.py b/backend/app/graph/__init__.py
new file mode 100644
index 0000000..a4c7afc
--- /dev/null
+++ b/backend/app/graph/__init__.py
@@ -0,0 +1,8 @@
+"""
+Graph 子模块
+"""
+
+from .graph_builder import GraphBuilder
+from .state import MessagesState, GraphContext
+
+__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]
diff --git a/backend/app/graph/graph_builder.py b/backend/app/graph/graph_builder.py
new file mode 100644
index 0000000..21abe76
--- /dev/null
+++ b/backend/app/graph/graph_builder.py
@@ -0,0 +1,83 @@
+"""
+LangGraph 状态图构建模块 - 精简版,仅负责组装图
+所有节点逻辑已拆分到独立模块
+"""
+
+from langchain_core.language_models import BaseLLM
+from langgraph.graph import StateGraph, START, END
+from .state import MessagesState, GraphContext
+from ..nodes import (
+    should_continue,
+    create_llm_call_node,
+    create_tool_call_node,
+    create_retrieve_memory_node,
+    create_summarize_node,
+    finalize_node,
+)
+from ..nodes.memory_trigger import memory_trigger_node, set_mem0_client
+from ..memory import Mem0Client
+
+
+class GraphBuilder:
+    """LangGraph 状态图构建器 - 仅负责组装图"""
+
+    def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
+        """
+        初始化构建器
+
+        Args:
+            llm: 大语言模型实例
+            tools: 工具列表
+            tools_by_name: 名称到工具函数的映射
+        """
+        self.llm = llm
+        self.tools = tools
+        self.tools_by_name = tools_by_name
+
+        # ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化)
+        self.mem0_client = Mem0Client(llm)
+
+    def build(self) -> StateGraph:
+        """
+        构建未编译的状态图
+
+        Returns:
+            StateGraph 实例
+        """
+        # 注入全局客户端
+        set_mem0_client(self.mem0_client)
+
+        builder = StateGraph(MessagesState, context_schema=GraphContext)
+
+        # ⭐ 通过工厂函数创建节点(依赖注入)
+        retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
+        llm_call_node = create_llm_call_node(self.llm, self.tools)
+        tool_call_node = create_tool_call_node(self.tools_by_name)
+        summarize_node = create_summarize_node(self.mem0_client)
+
+        # 添加节点
+        builder.add_node("retrieve_memory", retrieve_memory_node)
+        builder.add_node("memory_trigger", memory_trigger_node)
+        builder.add_node("llm_call", llm_call_node)
+        builder.add_node("tool_node", tool_call_node)
+        builder.add_node("summarize", summarize_node)
+        builder.add_node("finalize", finalize_node)
+
+        # 添加边
+        builder.add_edge(START, "retrieve_memory")
+        builder.add_edge("retrieve_memory", "memory_trigger")
+        builder.add_edge("memory_trigger", "llm_call")
+        builder.add_conditional_edges(
+            "llm_call",
+            should_continue,
+            {
+                "tool_node": "tool_node",
+                "summarize": "summarize",
+                "finalize": "finalize"
+            }
+        )
+        builder.add_edge("tool_node", "llm_call")
+        builder.add_edge("summarize", "finalize")
+        builder.add_edge("finalize", END)
+
+        return builder
\ No newline at end of file
diff --git a/backend/app/graph/graph_tools.py b/backend/app/graph/graph_tools.py
new file mode 100644
index 0000000..1cc1e17
--- /dev/null
+++ b/backend/app/graph/graph_tools.py
@@ -0,0 +1,95 @@
+"""
+工具定义模块 - 纯函数工具,无依赖 AIAgent 类
+"""
+
+# 标准库
+from pathlib import Path
+
+# 第三方库
+import pandas as pd
+import pypdf
+import requests
+from bs4 import BeautifulSoup
+from langchain_core.tools import tool
+
+def _file_allow_check(filename: str) -> Path:
+    """检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
+    allowed_dir = Path("./user_docs").resolve()
+    allowed_dir.mkdir(exist_ok=True)
+
+    file_path = (allowed_dir / filename).resolve()
+    if not str(file_path).startswith(str(allowed_dir)):
+        raise ValueError("错误:非法文件路径。")
+
+    if not file_path.exists():
+        raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。")
+
+    return file_path
+
+@tool
+def get_current_temperature(location: str) -> str:
+    """获取指定地点的当前温度。"""
+    return f'当前{location}的温度为25℃'
+
+@tool
+def read_local_file(filename: str) -> str:
+    """读取用户指定名称的本地文本文件内容并返回摘要。"""
+    try:
+        file_path = _file_allow_check(filename)
+        with open(file_path, 'r', encoding='utf-8') as f:
+            content = f.read()
+        return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..."
+    except Exception as e:
+        return f"读取文件时出错:{str(e)}"
+
+@tool
+def read_pdf_summary(filename: str) -> str:
+    """读取PDF文件并返回内容文本摘要。"""
+    try:
+        file_path = _file_allow_check(filename)
+        text = ""
+        with open(file_path, 'rb') as f:
+            reader = pypdf.PdfReader(f)
+            for page in reader.pages[:3]:
+                text += page.extract_text()
+        return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..."
+    except Exception as e:
+        return f"读取PDF出错:{e}"
+
+@tool
+def read_excel_as_markdown(filename: str) -> str:
+    """读取Excel文件,并将其主要数据转换为Markdown表格格式。"""
+    try:
+        file_path = _file_allow_check(filename)
+        df = pd.read_excel(file_path)
+        markdown_table = df.head(10).to_markdown(index=False)
+        return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}"
+    except Exception as e:
+        return f"读取Excel出错:{e}"
+
+@tool
+def fetch_webpage_content(url: str) -> str:
+    """抓取给定URL的网页正文内容,并返回清晰的纯文本。"""
+    try:
+        response = requests.get(url, timeout=10)
+        response.raise_for_status()
+        soup = BeautifulSoup(response.text, 'html.parser')
+        for script in soup(["script", "style"]):
+            script.decompose()
+        text = soup.get_text()
+        lines = (line.strip() for line in text.splitlines())
+        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
+        text = '\n'.join(chunk for chunk in chunks if chunk)
+        return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..."
+    except Exception as e:
+        return f"抓取网页时出错:{str(e)}"
+
+# 工具列表和映射(全局常量)
+AVAILABLE_TOOLS = [
+    get_current_temperature,
+    read_local_file,
+    fetch_webpage_content,
+    read_pdf_summary,
+    read_excel_as_markdown
+]
+TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS}
diff --git a/backend/app/graph/retrieve_memory.py b/backend/app/graph/retrieve_memory.py
new file mode 100644
index 0000000..36453b0
--- /dev/null
+++ b/backend/app/graph/retrieve_memory.py
@@ -0,0 +1,76 @@
+"""
+记忆检索节点模块
+负责从 Mem0 检索相关长期记忆
+"""
+
+from typing import Any, Dict
+
+# 本地模块
+from .state import MessagesState
+from ..memory.mem0_client import Mem0Client
+from ..utils.logging import log_state_change
+from ..logger import debug
+
+def create_retrieve_memory_node(mem0_client: Mem0Client):
+    """
+    工厂函数:创建记忆检索节点
+
+    Args:
+        mem0_client: Mem0 客户端实例
+
+    Returns:
+        异步节点函数
+    """
+
+    from langchain_core.runnables.config import RunnableConfig
+
+    async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
+        """
+        记忆检索节点 - 使用 Mem0
+
+        Args:
+            state: 当前对话状态
+            config: 运行时配置
+
+        Returns:
+            包含 memory_context 的状态更新
+        """
+        log_state_change("retrieve_memory", state, "进入")
+
+        # 从 metadata 中获取 user_id
+        user_id = config.get("metadata", {}).get("user_id", "default_user")
+
+        # 兼容 dict 和对象两种消息格式
+        last_msg = state["messages"][-1]
+        if isinstance(last_msg, dict):
+            query = str(last_msg.get("content", ""))
+        else:
+            query = str(last_msg.content)
+        memory_text_parts = []
+
+        # 确保 Mem0 已初始化(懒加载)
+        if not mem0_client._initialized:
+            await mem0_client.initialize()
+
+        if mem0_client.mem0:
+            try:
+                # 异步调用 Mem0 语义检索
+                facts = await mem0_client.search_memories(query, user_id=user_id, limit=5)
+
+                if facts:
+                    memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts))
+                else:
+                    debug("🔍 [记忆检索] 未找到相关记忆")
+            except Exception as e:
+                from app.logger import warning
+                warning(f"⚠️ Mem0 检索失败: {e}")
+        else:
+            from app.logger import warning
+            warning("⚠️ Mem0 未初始化,跳过记忆检索")
+
+        memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息"
+        result = {"memory_context": memory_context}
+        log_state_change("retrieve_memory", {**state, **result}, "离开")
+        return result
+
+    return retrieve_memory
diff --git a/backend/app/graph/state.py b/backend/app/graph/state.py
new file mode 100644
index 0000000..2fd214e
--- /dev/null
+++ b/backend/app/graph/state.py
@@ -0,0 +1,25 @@
+"""
+LangGraph 状态定义模块
+包含 MessagesState 和 GraphContext
+"""
+
+import operator
+from typing import Annotated
+from typing_extensions import TypedDict
+from dataclasses import dataclass
+from langchain_core.messages import AnyMessage
+
+class MessagesState(TypedDict):
+    """对话状态类型定义"""
+    messages: Annotated[list[AnyMessage], operator.add]
+    llm_calls: int
+    memory_context: str
+    last_token_usage: dict  # 本次调用的 token 使用详情
+    last_elapsed_time: float  # 本次调用耗时(秒)
+    turns_since_last_summary: int  # 距离上次生成摘要的轮数
+
+@dataclass
+class GraphContext:
+    """图执行上下文"""
+    user_id: str
+    # 可扩展更多上下文信息
diff --git a/backend/app/logger.py b/backend/app/logger.py
new file mode 100644
index 0000000..777aeec
--- /dev/null
+++ b/backend/app/logger.py
@@ -0,0 +1,56 @@
+"""
+统一的日志模块 - 基于环境变量控制日志级别
+类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息
+"""
+
+import os
+from .config import LOG_LEVEL, DEBUG
+import logging
+from typing import Any
+from dotenv import load_dotenv
+
+# 先加载环境变量
+load_dotenv()
+
+# 从环境变量读取日志级别,默认 INFO
+
+
+# 根据环境变量控制是否显示详细调试信息
+DEBUG_MODE = DEBUG
+
+# 创建统一的日志器
+logger = logging.getLogger("ai_agent")
+logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
+
+# 避免重复添加 handler
+if not logger.handlers:
+    handler = logging.StreamHandler()
+    # 重要:handler 也需要设置级别,否则可能继承根 logger 的级别
+    handler.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
+    formatter = logging.Formatter(
+        fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S"
+    )
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+
+
+def debug(msg: Any, *args, **kwargs):
+    """调试日志,仅在 DEBUG 环境变量为 true 时打印"""
+    if DEBUG_MODE:
+        logger.debug(msg, *args, **kwargs)
+
+
+def info(msg: Any, *args, **kwargs):
+    """信息日志"""
+    logger.info(msg, *args, **kwargs)
+
+
+def warning(msg: Any, *args, **kwargs):
+    """警告日志"""
+    logger.warning(msg, *args, **kwargs)
+
+
+def error(msg: Any, *args, **kwargs):
+    """错误日志"""
+    logger.error(msg, *args, **kwargs)
diff --git a/backend/app/memory/__init__.py b/backend/app/memory/__init__.py
new file mode 100644
index 0000000..ba1f389
--- /dev/null
+++ b/backend/app/memory/__init__.py
@@ -0,0 +1,7 @@
+"""
+Mem0 记忆层模块
+"""
+
+from .mem0_client import Mem0Client
+
+__all__ = ["Mem0Client"]
diff --git a/backend/app/memory/mem0_client.py b/backend/app/memory/mem0_client.py
new file mode 100644
index 0000000..4a54ec2
--- /dev/null
+++ b/backend/app/memory/mem0_client.py
@@ -0,0 +1,146 @@
+from ..config import LLM_API_KEY
+from ..config import VLLM_BASE_URL
+import time
+"""
+Mem0 记忆层客户端封装模块
+负责 Mem0 的初始化、检索和存储
+"""
+
+import asyncio
+from typing import Optional, List, Dict
+from mem0 import AsyncMemory
+
+from ..config import (
+    QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY,
+    VLLM_BASE_URL, LLM_API_KEY, LLM_MODEL,
+    LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
+)
+from ..logger import info, warning, error
+
+class Mem0Client:
+    """Mem0 异步客户端封装类"""
+
+    def __init__(self, llm_instance):
+        """
+        初始化 Mem0 客户端
+
+        Args:
+            llm_instance: LangChain LLM 实例(用于事实提取)
+        """
+        self.llm = llm_instance
+        self.mem0: Optional[AsyncMemory] = None
+        self._initialized = False
+
+    async def initialize(self):
+        """异步初始化 Mem0 客户端,并进行实际连接测试"""
+        if self._initialized:
+            return
+
+        try:
+            # Mem0 配置
+            config = {
+                "vector_store": {
+                    "provider": "qdrant",
+                    "config": {
+                        "url": QDRANT_URL,  # 直接使用完整 URL
+                        "api_key": QDRANT_API_KEY,
+                        "collection_name": QDRANT_COLLECTION_NAME,
+                        "embedding_model_dims": 1024,
+                    }
+                },
+                "llm": {
+                    "provider": "openai",
+                    "config": {
+                        "model": LLM_MODEL,  # fix: was the literal string "LLM_MODEL" (placeholder, not a model name) — confirm LLM_MODEL exists in config
+                        "api_key": LLM_API_KEY,
+                        "openai_base_url": VLLM_BASE_URL,
+                        "temperature": 0.1,
+                        "max_tokens": 2000,
+                    }
+                },
+                "embedder": {
+                    "provider": "openai",
+                    "config": {
+                        "model": "Qwen3-Embedding-0.6B-Q8_0",
+                        "api_key": LLAMACPP_API_KEY,
+                        "openai_base_url": LLAMACPP_EMBEDDING_URL,
+                    },
+                },
+                "version": "v1.1"
+            }
+
+            self.mem0 = AsyncMemory.from_config(config)
+            info("✅ Mem0 配置加载成功,开始连接测试...")
+
+            # 实际连接测试:调用一次 search 确保 Qdrant 和 Embedding 都可达
+            await asyncio.wait_for(
+                self.mem0.search("ping", user_id="test", limit=1),
+                timeout=60.0
+            )
+            info("✅ Mem0 实际连接测试成功,初始化完成")
+            self._initialized = True
+
+        except asyncio.TimeoutError:
+            error("❌ Mem0 连接测试超时 (60s),请检查 Qdrant 或 Embedding 服务响应")
+            self.mem0 = None
+            self._initialized = False
+        except Exception as e:
+            error(f"❌ Mem0 初始化或连接测试失败: {e}")
+            import traceback
+            error(f"详细错误信息:\n{traceback.format_exc()}")
+            self.mem0 = None
+            self._initialized = False
+
+    async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]:
+        """
+        检索相关记忆
+
+        Args:
+            query: 查询文本
+            user_id: 用户 ID
+            limit: 返回结果数量限制
+
+        Returns:
+            List[str]: 记忆事实列表
+        """
+        if not self.mem0:
+            warning("⚠️ Mem0 未初始化,跳过记忆检索")
+            return []
+
+        try:
+            memories = await asyncio.wait_for(
+                self.mem0.search(query, user_id=user_id, limit=limit),
+                timeout=30.0
+            )
+
+            if memories and "results" in memories:
+                facts = [m["memory"] for m in memories["results"] if m.get("memory")]
+                if facts:
+                    info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆")
+                    return facts
+
+            info("🔍 [记忆检索] 未找到相关记忆")
+            return []
+
+        except asyncio.TimeoutError:
+            warning("⚠️ Mem0 检索超时 (30s),跳过本次记忆检索")
+            return []
+        except Exception as e:
+            warning(f"⚠️ Mem0 检索失败: {e}")
+            return []
+
+    async def add_memories(self, messages, user_id):
+        if not self.mem0:
+            return False
+        try:
+            start = time.time()
+            info(f"📝 开始 Mem0 add,消息数: {len(messages)}")
+            await asyncio.wait_for(
+                self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
+                timeout=60.0
+            )
+            info(f"✅ Mem0 add 完成,耗时: {time.time() - start:.2f}s")
+            return True
+        except asyncio.TimeoutError:
+            error(f"❌ Mem0 记忆添加超时 (60s),已等待 {time.time() - start:.2f}s")
+            return False
\ No newline at end of file
diff --git a/backend/app/nodes/__init__.py b/backend/app/nodes/__init__.py
new file mode 100644
index 0000000..345883a
--- /dev/null
+++ b/backend/app/nodes/__init__.py
@@ -0,0 +1,19 @@
+"""
+节点模块 - 导出所有 LangGraph 节点函数
+"""
+
+from .router import should_continue
+from .llm_call import create_llm_call_node
+from .tool_call import create_tool_call_node
+from ..graph.retrieve_memory import create_retrieve_memory_node
+from .summarize import create_summarize_node
+from .finalize import finalize_node
+
+__all__ = [
+    "should_continue",
+    "create_llm_call_node",
+    "create_tool_call_node",
+    "create_retrieve_memory_node",
+    "create_summarize_node",
+    "finalize_node",
+]
diff --git a/backend/app/nodes/finalize.py b/backend/app/nodes/finalize.py
new file mode 100644
index 0000000..2203a0b
--- /dev/null
+++ b/backend/app/nodes/finalize.py
@@ -0,0 +1,45 @@
+"""
+完成事件节点模块
+负责发送完成事件,包含token使用情况和耗时信息
+"""
+
+from typing import Any, Dict +from langgraph.config import get_stream_writer + +# 本地模块 +from ..graph.state import MessagesState +from ..utils.logging import log_state_change +from ..logger import info, error + +from langchain_core.runnables.config import RunnableConfig + +async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + """ + 完成事件节点 - 发送完成事件,包含token使用情况和耗时信息 + + Args: + state: 当前对话状态 + config: 运行时配置 + + Returns: + 空字典(完成节点,无状态更新) + """ + log_state_change("finalize", state, "进入") + + try: + # 获取流式写入器并发送完成事件 + writer = get_stream_writer() + writer({ + "type": "custom", + "data": { + "type": "done", + "token_usage": state.get("last_token_usage", {}), + "elapsed_time": state.get("last_elapsed_time", 0.0) + } + }) + info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息") + except Exception as e: + error(f"❌ [完成事件] 发送完成事件时发生异常: {e}") + + log_state_change("finalize", state, "离开") + return {} \ No newline at end of file diff --git a/backend/app/nodes/llm_call.py b/backend/app/nodes/llm_call.py new file mode 100644 index 0000000..f413d85 --- /dev/null +++ b/backend/app/nodes/llm_call.py @@ -0,0 +1,150 @@ +""" +LLM 调用节点模块 +负责调用大语言模型并处理响应 +""" + +import time +from typing import Any, Dict +from langchain_core.language_models import BaseLLM +from langchain_core.messages import AIMessage + +# 本地模块 +from ..graph.state import MessagesState +from ..agent.prompts import create_system_prompt +from ..utils.logging import log_state_change +from ..logger import debug, info, error + +def create_llm_call_node(llm: BaseLLM, tools: list): + """ + 工厂函数:创建 LLM 调用节点 + + Args: + llm: LangChain LLM 实例 + tools: 工具列表 + + Returns: + 异步节点函数 + """ + # 构建调用链 + prompt = create_system_prompt(tools) + llm_with_tools = llm.bind_tools(tools) + + # 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历 + chain = prompt | llm_with_tools + + from langchain_core.runnables.config import RunnableConfig + + async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + """ + 
LLM 调用节点(异步方法) + + Args: + state: 当前对话状态 + config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息 + + Returns: + 更新后的状态字典 + """ + log_state_change("llm_call", state, "进入") + + memory_context = state.get("memory_context", "暂无用户信息") + start_time = time.time() + + try: + # 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。 + # LangGraph 会自动监听这期间产生的所有 token。 + chunks = [] + async for chunk in chain.astream( + { + "messages": state["messages"], + "memory_context": memory_context + }, + config=config + ): + chunks.append(chunk) + + # 将所有 chunk 合并成最终的 AIMessage + if chunks: + response = chunks[0] + for chunk in chunks[1:]: + response = response + chunk + else: + response = AIMessage(content="") + + elapsed_time = time.time() - start_time + + # 提取 token 用量(兼容不同 LLM 提供商的元数据格式) + token_usage = {} + input_tokens = 0 + output_tokens = 0 + + # 尝试从 response_metadata 中提取 + if hasattr(response, 'response_metadata') and response.response_metadata: + meta = response.response_metadata + if 'token_usage' in meta: + token_usage = meta['token_usage'] + elif 'usage' in meta: + token_usage = meta['usage'] + + # 尝试从 additional_kwargs 中提取 + if not token_usage and hasattr(response, 'additional_kwargs'): + add_kwargs = response.additional_kwargs + if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']: + token_usage = add_kwargs['llm_output']['token_usage'] + + # 提取具体的 token 数值 + if token_usage: + input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0)) + output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0)) + + # 打印 LLM 的完整输出 + debug("\n" + "="*80) + debug("📥 [LLM输出] 大模型返回的完整响应:") + debug(f" 消息类型: {response.type.upper()}") + debug(f" 内容长度: {len(str(response.content))} 字符") + debug("-"*80) + debug(f"{response.content}") + + # 打印响应统计信息 + info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒") + info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}") + if token_usage: + debug(f"📋 
[LLM统计] 详细用量: {token_usage}") + debug("="*80 + "\n") + + result = { + "messages": [response], + "llm_calls": state.get('llm_calls', 0) + 1, + "last_token_usage": token_usage, + "last_elapsed_time": elapsed_time, + "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器 + } + + log_state_change("llm_call", {**state, **result}, "离开") + return result + + except Exception as e: + elapsed_time = time.time() - start_time + error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)") + error(f" 错误类型: {type(e).__name__}") + error(f" 错误信息: {str(e)}") + import traceback + traceback.print_exc() + debug("="*80 + "\n") + + # 返回一个友好的错误消息 + error_response = AIMessage( + content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。" + ) + error_result = { + "messages": [error_response], + "llm_calls": state.get('llm_calls', 0), + "last_token_usage": {}, + "last_elapsed_time": elapsed_time, + "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器 + } + + log_state_change("llm_call", state, "离开(异常)") + return error_result + + return call_llm diff --git a/backend/app/nodes/memory_trigger.py b/backend/app/nodes/memory_trigger.py new file mode 100644 index 0000000..be5a65c --- /dev/null +++ b/backend/app/nodes/memory_trigger.py @@ -0,0 +1,38 @@ +from typing import Any, Dict +from langchain_core.runnables.config import RunnableConfig +from ..graph.state import MessagesState +from ..memory.mem0_client import Mem0Client +from ..logger import info + +# 全局变量,在 GraphBuilder 中注入 +_mem0_client: Mem0Client = None + +def set_mem0_client(client: Mem0Client): + global _mem0_client + _mem0_client = client + +async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" + if _mem0_client is None: + return {} + + messages = state.get("messages", []) + if not messages: + return {} + + last_msg = messages[-1] + content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg) + + # 触发词(可自行扩展) + 
trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"] + if any(word in content for word in trigger_words): + user_id = config.get("metadata", {}).get("user_id", "default_user") + # 确保 Mem0 已初始化 + if not _mem0_client._initialized: + await _mem0_client.initialize() + # 将用户消息作为事实来源提交给 Mem0 + info(f"📌 检测到记忆指令,已主动触发 Mem0 存储") + mem0_messages = [{"role": "user", "content": content}] + await _mem0_client.add_memories(mem0_messages, user_id=user_id) + + return {} # 不修改状态 \ No newline at end of file diff --git a/backend/app/nodes/router.py b/backend/app/nodes/router.py new file mode 100644 index 0000000..6fab51a --- /dev/null +++ b/backend/app/nodes/router.py @@ -0,0 +1,48 @@ +""" +路由决策节点 +根据当前状态决定下一步走向 +""" + +from typing import Literal +from langchain_core.messages import AIMessage + +# 本地模块 +from ..config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL +from ..graph.state import MessagesState +from ..logger import info + + +def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']: + """ + 决定下一步:工具调用、生成摘要还是结束 + + Args: + state: 当前对话状态 + + Returns: + 下一个节点名称 + """ + last_message = state["messages"][-1] + + # 1. 如果需要调用工具,优先进入工具节点 + if isinstance(last_message, AIMessage) and last_message.tool_calls: + if ENABLE_GRAPH_TRACE: + info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'") + return 'tool_node' + + # 2. 如果是 AI 的最终回复,判断是否达到摘要生成阈值 + if isinstance(last_message, AIMessage): + turns = state.get("turns_since_last_summary", 0) + if turns >= MEMORY_SUMMARIZE_INTERVAL: + if ENABLE_GRAPH_TRACE: + info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'") + return 'summarize' + else: + if ENABLE_GRAPH_TRACE: + info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程") + return 'finalize' + + # 3. 
其他情况(如只有用户消息)直接结束 + if ENABLE_GRAPH_TRACE: + info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程") + return 'finalize' diff --git a/backend/app/nodes/summarize.py b/backend/app/nodes/summarize.py new file mode 100644 index 0000000..2c9856b --- /dev/null +++ b/backend/app/nodes/summarize.py @@ -0,0 +1,87 @@ +""" +记忆存储节点模块 +负责将对话历史提交给 Mem0 进行事实提取和存储 +""" + +from typing import Any, Dict + +# 本地模块 +from ..graph.state import MessagesState +from ..memory.mem0_client import Mem0Client +from ..utils.logging import log_state_change +from ..logger import debug, info, error, warning + +def create_summarize_node(mem0_client: Mem0Client): + """ + 工厂函数:创建记忆存储节点 + + Args: + mem0_client: Mem0 客户端实例 + + Returns: + 异步节点函数 + """ + + from langchain_core.runnables.config import RunnableConfig + + async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + """ + 记忆存储节点 - 使用 Mem0 + + Args: + state: 当前对话状态 + config: 运行时配置 + + Returns: + 重置计数器的状态更新 + """ + log_state_change("summarize", state, "进入") + + messages = state["messages"] + if len(messages) < 4: + debug("📝 [记忆添加] 对话过短,跳过") + return {"turns_since_last_summary": 0} + + # 从 metadata 中获取 user_id + user_id = config.get("metadata", {}).get("user_id", "default_user") + + # 确保 Mem0 已初始化(懒加载) + if not mem0_client._initialized: + await mem0_client.initialize() + + # 将整个对话历史转换为 Mem0 需要的消息格式 + mem0_messages = [] + for msg in messages: + # 兼容 dict 和对象两种格式 + if isinstance(msg, dict): + msg_type = msg.get("type", "") + msg_content = msg.get("content", "") + else: + msg_type = getattr(msg, 'type', '') + msg_content = getattr(msg, 'content', '') + + if msg_type == "human": + mem0_messages.append({"role": "user", "content": msg_content}) + elif msg_type == "ai": + mem0_messages.append({"role": "assistant", "content": msg_content}) + elif msg_type == "tool": + mem0_messages.append({"role": "system", "content": f"[工具返回] {msg_content}"}) + + if mem0_client.mem0: + try: + # 异步调用 Mem0 自动提取并存储事实 + success = await 
mem0_client.add_memories( + mem0_messages, + user_id=user_id + ) + if success: + info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取") + except Exception as e: + error(f"❌ Mem0 记忆添加失败: {e}") + else: + warning("⚠️ Mem0 未初始化,跳过记忆添加") + + log_state_change("summarize", state, "离开") + return {"turns_since_last_summary": 0} + + return summarize_conversation \ No newline at end of file diff --git a/backend/app/nodes/tool_call.py b/backend/app/nodes/tool_call.py new file mode 100644 index 0000000..69ad2e6 --- /dev/null +++ b/backend/app/nodes/tool_call.py @@ -0,0 +1,101 @@ +""" +工具执行节点模块 +负责执行 AI 调用的工具函数 +""" + +import asyncio +from typing import Any, Dict +from langchain_core.messages import AIMessage, ToolMessage +from langgraph.config import get_stream_writer + +# 本地模块 +from ..graph.state import MessagesState +from ..utils.logging import log_state_change +from ..logger import debug, info + +def create_tool_call_node(tools_by_name: Dict[str, Any]): + """ + 工厂函数:创建工具执行节点 + + Args: + tools_by_name: 名称到工具函数的映射字典 + + Returns: + 异步节点函数 + """ + + from langchain_core.runnables.config import RunnableConfig + + async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + """ + 工具执行节点(异步方法) + + Args: + state: 当前对话状态 + config: 运行时配置 + + Returns: + 包含 ToolMessage 的状态更新 + """ + log_state_change("tool_node", state, "进入") + + last_message = state['messages'][-1] + if not isinstance(last_message, AIMessage) or not last_message.tool_calls: + log_state_change("tool_node", state, "离开(无工具调用)") + return {"messages": []} + + results = [] + loop = asyncio.get_event_loop() + + info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具") + + for tool_call in last_message.tool_calls: + tool_name = tool_call["name"] + tool_args = tool_call["args"] + tool_id = tool_call["id"] + tool_func = tools_by_name.get(tool_name) + + debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}") + + if tool_func is None: + err_msg = f"Tool {tool_name} not found" + debug(f" └─ ❌ {err_msg}") + 
results.append(ToolMessage(content=err_msg, tool_call_id=tool_id)) + continue + + # 获取流式写入器并发送工具调用开始事件 + writer = get_stream_writer() + writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}}) + + try: + # 修复闭包问题:将变量作为默认参数传入 lambda + # 如果工具支持异步 (ainvoke),优先使用异步调用 + if hasattr(tool_func, 'ainvoke'): + observation = await tool_func.ainvoke(tool_args) + else: + observation = await loop.run_in_executor( + None, + lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值 + ) + + # 字符打印 + result_preview = str(observation).replace("\n", " ") + debug(f" └─ ✅ 结果: {result_preview}") + results.append(ToolMessage(content=str(observation), tool_call_id=tool_id)) + + # 发送工具调用完成事件 + writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}}) + except Exception as e: + debug(f" └─ ❌ 异常: {e}") + results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id)) + + # 发送工具调用失败事件 + writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}}) + + info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage") + + result = {"messages": results} + log_state_change("tool_node", {**state, **result}, "离开") + return result + + return call_tools \ No newline at end of file diff --git a/backend/app/rag/README.md b/backend/app/rag/README.md new file mode 100644 index 0000000..a91d8f6 --- /dev/null +++ b/backend/app/rag/README.md @@ -0,0 +1,391 @@ +# 在线 RAG 检索与生成系统 (Online RAG Retriever) + +该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。 + +## 🎯 核心架构 + +### 技术栈 + +| 组件 | 技术选型 | 版本 | 说明 | +|:-----|:---------|:-----|:-----| +| **基础检索** | `Qdrant` | 1.17+ | HNSW 稠密向量检索 | +| **混合检索** | `Qdrant` + `BM25` | 内置 | 稠密 + 稀疏向量融合 | +| **查询改写** | `LangChain` | 内置 | `MultiQueryGenerator` 多路改写 | +| **RRF 融合** | 自实现 | - | `reciprocal_rank_fusion` 倒数排名融合 | +| **重排序** | `llama.cpp` | 本地服务 | OpenAI 兼容 Rerank API | +| **编排框架** | `asyncio` | Python 3.10+ | 异步并行检索 | + +### 
检索流水线 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ 用户提问 │ +└──────────────────────┬──────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ MultiQueryGenerator │ +│ 多路查询改写 (num_queries=3) │ +│ "如何申请项目资金?" → ["项目资金申请流程", "经费申请步骤"] │ +└──────────────────────┬──────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 并行检索 (asyncio.gather) │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ 查询1 检索 │ │ 查询2 检索 │ │ 查询3 检索 │ │ +│ │ (k=20) │ │ (k=20) │ │ (k=20) │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +└──────────────────────┬──────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ reciprocal_rank_fusion (RRF) │ +│ RRF_score(d) = Σ 1/(k + rank_q(d)) (k=60) │ +│ 融合多路检索结果,去重排序 │ +└──────────────────────┬──────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ LLaMaCPPReranker │ +│ 远程重排序 (bge-reranker-v2-m3) │ +│ 返回 Top-N (top_n=5) 最相关文档 │ +└──────────────────────┬──────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ 返回增强上下文 │ +│ format_context() → 格式化输出 │ +└─────────────────────────────────────────────────────────────┘ +``` + +### 技术特性 + +- ✅ **多路查询改写**:通过 LLM 将单一问题改写为多个不同角度的查询 +- ✅ **RRF 融合算法**:Reciprocal Rank Fusion,无需评分归一化的融合算法 +- ✅ **远程重排序**:使用 llama.cpp 服务的 OpenAI 兼容 Rerank API +- ✅ **混合检索支持**:稠密向量 + BM25 稀疏向量混合检索 +- ✅ **异步并行检索**:多路查询并行执行,提升检索速度 +- ✅ **优雅降级**:重排序器不可用时自动降级到基础融合结果 + +## 📂 架构与文件结构 + +``` +app/rag/ +├── __init__.py +├── retriever.py # Qdrant 基础检索与混合检索 +├── reranker.py # llama.cpp 远程重排序器 +├── query_transform.py # 多路查询改写生成器 +├── fusion.py # RRF 倒数排名融合算法 +├── pipeline.py # RAG 流水线编排 +└── tools.py # LangChain Tool 封装 +``` + +## 🎯 演进路线与算法详解 (Roadmap) + +### Level 1: 基础向量搜索 (Basic Similarity Search) + +- **核心算法**: 
近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。 +- **优缺点**: 速度极快。但只能捕捉"语义相似",如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生"幻觉"匹配)。 +- **实现指南**: + - 使用 `rag_indexer.embedders.LlamaCppEmbedder` 作为嵌入模型 + - 使用 `app/rag/retriever.py` 中的 `create_base_retriever` 创建基础检索器 + - 配置 `search_kwargs={"k": 20}` 进行初步召回 + +```python +from app.rag.retriever import create_base_retriever + +retriever = create_base_retriever( + collection_name="rag_documents", + embeddings=embeddings, + search_kwargs={"k": 20} +) +docs = retriever.invoke("什么是 RAG?") +``` + +### Level 2: 混合检索与重排序 (Hybrid Search + Reranker) + +混合检索旨在结合向量的"语义泛化"与关键词的"精准匹配",随后利用重排序模型过滤噪声。 + +**1. 基础召回 (混合检索)** + +- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。 +- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_hybrid_retriever` 函数,配置 `dense_k=10` 和 `sparse_k=10`,总召回 20 条结果。 + +```python +from app.rag.retriever import create_hybrid_retriever + +retriever = create_hybrid_retriever( + collection_name="rag_documents", + embeddings=embeddings, + dense_k=10, + sparse_k=10, + score_threshold=0.3 +) +``` + +**2. 二次精排 (Cross-Encoder)** + +- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将"用户问题 + 检索到的单例文档"拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。 +- **实现指南**: + - 使用 `app/rag/reranker.py` 中的 `LLaMaCPPReranker` 类,加载 `bge-reranker-v2-m3` 模型 + - 设置 `top_n=5` 保留最相关的 5 条结果 + +```python +from app.rag.reranker import LLaMaCPPReranker + +reranker = LLaMaCPPReranker( + base_url="http://127.0.0.1:8083", + api_key="your-api-key", + top_n=5 +) +sorted_docs = reranker.compress_documents(documents, query) +``` + +### Level 3: RAG-Fusion (多路改写与倒数排名融合) + +RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。 + +**1. 
多路查询改写** + +- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。 +- **实现指南**: 使用 `app/rag/query_transform.py` 中的 `MultiQueryGenerator` 类,配置 `num_queries=3` 生成 3 个不同角度的查询。 + +```python +from app.rag.query_transform import MultiQueryGenerator + +generator = MultiQueryGenerator(llm=llm, num_queries=3) +queries = await generator.agenerate("如何申请项目资金?") +# 返回:["如何申请项目资金?", "项目资金申请流程是什么?", "申请项目经费需要哪些步骤?"] +``` + +**2. 倒数排名融合 (RRF)** + +- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 `RRF_score(d) = Σ 1/(k + rank_q(d))`,有效避免某一极端检索结果主导全局。 +- **实现指南**: 使用 `app/rag/fusion.py` 中的 `reciprocal_rank_fusion` 函数,配置 `k=60` 实现倒数排名融合。 + +```python +from app.rag.fusion import reciprocal_rank_fusion + +# 多个查询的检索结果 +doc_lists = [result1, result2, result3] +fused_docs = reciprocal_rank_fusion(doc_lists, k=60) +``` + +### Level 4: Agentic RAG / Self-RAG (智能体与自我反思) + +- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:"这是闲聊?还是需要查知识库?"。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。 +- **实现指南**: 使用 `app/rag/tools.py` 中的 `search_knowledge_base` 工具,将其绑定到 LangGraph 状态机中。 + +- **示意图**: + +``` +┌──────────┐ ┌──────────────┐ ┌──────────┐ ┌──────── +│ User │────>│ LangGraph │────>│ RAG_Tool │────>│ Qdrant │ +│ │ │ Agent │ │ │ │ │ +│ "公司报 │ │ 思考: 这是 │ │ ToolCall │ │ RAG- │ +│ 销流程?"│ │ 内部规章问题 │ │ search_ │ │ Fusion │ +│ │ │ 需要查资料 │ │ knowledge│ │ & 混合 │ +│ │<────│ 资料充分, │<────│ 返回最相 │<────│ 检索 │ +│ "根据知 │ │ 开始撰写回答 │ │ 关5条规定 │ │ Cross- │ +│ 识库规定 │ │ │ │ │ │ Encoder│ +│ ..." 
│ │ │ │ │ │ 重排 │ +└────────── └────────────── └──────────┘ └────────┘ +``` + +### Level 5: GraphRAG 集成 (基于图和关系的 RAG) + +- **核心原理**: 结合知识图谱的结构化关系和向量检索的语义相似度,解决跨文档复杂关系推理问题。 +- **实现指南**: + - 使用 `langchain_community.graphs` 模块构建知识图谱 + - 配置本地大模型(如 `Gemma-4-E4B`)用于实体关系抽取 + - 实现混合检索逻辑,结合向量相似度和图路径分析 + +```python +from langchain_community.graphs import Neo4jGraph +from langchain_experimental.graph_transformers import LLMGraphTransformer + +# 实体关系抽取 +transformer = LLMGraphTransformer(llm=local_llm) +graph_documents = transformer.convert_to_graph_documents(documents) + +# 存储到图数据库 +graph = Neo4jGraph(url="bolt://localhost:7687") +graph.add_graph_documents(graph_documents) +``` + +## 🔧 核心组件详解 + +### 1. 检索器 (retriever.py) + +提供基于 Qdrant 的向量检索能力。 + +**基础检索器**: +```python +from app.rag.retriever import create_base_retriever + +retriever = create_base_retriever( + collection_name="rag_documents", + embeddings=embeddings, + search_kwargs={"k": 20} +) +``` + +**混合检索器**: +```python +from app.rag.retriever import create_hybrid_retriever + +retriever = create_hybrid_retriever( + collection_name="rag_documents", + embeddings=embeddings, + dense_k=10, + sparse_k=10, + score_threshold=0.3 +) +``` + +### 2. 多路查询改写 (query_transform.py) + +通过 LLM 将用户问题改写为多个不同版本,扩大搜索面。 + +```python +from app.rag.query_transform import MultiQueryGenerator + +generator = MultiQueryGenerator(llm=llm, num_queries=3) +queries = await generator.agenerate("如何申请项目资金?") +``` + +### 3. RRF 融合算法 (fusion.py) + +Reciprocal Rank Fusion 算法,公式:`RRF_score(d) = Σ 1/(k + rank_q(d))` + +```python +from app.rag.fusion import reciprocal_rank_fusion + +# 多个查询的检索结果 +doc_lists = [result1, result2, result3] +fused_docs = reciprocal_rank_fusion(doc_lists, k=60) +``` + +### 4. 
重排序器 (reranker.py)

使用 llama.cpp 服务的 OpenAI 兼容 Rerank API 对检索结果重排序。

```python
from app.rag.reranker import LLaMaCPPReranker

reranker = LLaMaCPPReranker(
    base_url="http://127.0.0.1:8083",
    api_key="your-api-key",
    top_n=5
)
sorted_docs = reranker.compress_documents(documents, query)
```

### 5. RAG 流水线 (pipeline.py)

组合上述组件的完整检索流水线。

```python
from app.rag.pipeline import RAGPipeline

pipeline = RAGPipeline(
    retriever=retriever,
    llm=llm,
    num_queries=3,
    rerank_top_n=5,
)

# 异步检索
docs = await pipeline.aretrieve("如何申请项目资金?")

# 格式化上下文
context = pipeline.format_context(docs)
```

## 🔄 与 Agent 系统集成

### 封装为 LangChain Tool

```python
from langchain_core.tools import tool
from app.rag.pipeline import RAGPipeline

@tool
def search_knowledge_base(query: str) -> str:
    """搜索知识库获取相关信息"""
    docs = pipeline.retrieve(query)
    return pipeline.format_context(docs)
```

### 绑定到 LangGraph

```python
from app.graph.graph_builder import GraphBuilder

# 将 RAG 工具添加到工具列表
tools = AVAILABLE_TOOLS + [search_knowledge_base]

# 构建图
builder = GraphBuilder(llm, tools, tools_by_name)
graph = builder.build().compile(checkpointer=checkpointer)
```

## ⚙️ 环境配置

| 变量名 | 说明 | 默认值 |
|:-------|:-----|:-------|
| `QDRANT_URL` | Qdrant 向量数据库地址 | `http://127.0.0.1:6333` |
| `QDRANT_API_KEY` | Qdrant API 密钥 | - |
| `LLAMACPP_RERANKER_URL` | llama.cpp 重排序服务地址 | `http://127.0.0.1:8083` |
| `LLAMACPP_API_KEY` | llama.cpp API 密钥 | - |

## 🚀 快速开始

```python
# 1. 初始化嵌入模型
from rag_indexer.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()

# 2. 创建检索器
from app.rag.retriever import create_base_retriever
retriever = create_base_retriever(
    collection_name="rag_documents",
    embeddings=embeddings,
    search_kwargs={"k": 20}
)

# 3.
创建 RAG 流水线 +from app.rag.pipeline import RAGPipeline +pipeline = RAGPipeline( + retriever=retriever, + llm=llm, + num_queries=3, + rerank_top_n=5, +) + +# 4. 执行检索 +docs = pipeline.retrieve("如何申请项目资金?") + +# 5. 格式化上下文 +context = pipeline.format_context(docs) +print(context) +``` + +## 📊 检索策略对比 + +| 策略 | 优点 | 缺点 | 适用场景 | +|:-----|:-----|:-----|:---------| +| **基础向量检索** | 速度快,语义理解好 | 专有名词匹配差 | 通用问答 | +| **混合检索** | 语义 + 关键词匹配 | 需要配置稀疏向量 | 专业术语查询 | +| **多路改写 + RRF** | 搜索面广,结果稳定 | 延迟略高 | 复杂问题 | +| **重排序** | 精度高 | 依赖额外模型 | 最终精排 | + +## 🤝 与 rag_indexer 集成 + +- **向量存储**:共享 Qdrant 集合,确保嵌入模型一致 +- **文档存储**:使用 PostgreSQL 存储父块,通过 UUID 映射 +- **集合名称**:默认使用 `rag_documents` 集合 + +详见 [rag_indexer/README.md](../../rag_indexer/README.md) diff --git a/backend/app/rag/__init__.py b/backend/app/rag/__init__.py new file mode 100644 index 0000000..4f86e0b --- /dev/null +++ b/backend/app/rag/__init__.py @@ -0,0 +1,69 @@ +""" +RAG 检索与生成模块 + +提供在线检索与生成功能,包括: +- 基础向量检索(稠密向量 / 混合检索) +- 重排序(Cross-Encoder) +- 多路查询改写(Multi-Query) +- RRF 融合(Reciprocal Rank Fusion) +- 完整的 RAG 流水线 +- Agent 工具封装 + +固定流水线: + 用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 + +示例用法: + >>> from app.rag.rag import RAGPipeline, create_rag_tool + >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig + >>> from langchain_openai import ChatOpenAI + >>> + >>> # 获取基础检索器(如父子块检索器) + >>> config = IndexBuilderConfig(collection_name="my_docs") + >>> builder = IndexBuilder(config) + >>> retriever = builder.retriever + >>> + >>> # 创建 LLM 和流水线 + >>> llm = ChatOpenAI(model="gpt-3.5-turbo") + >>> pipeline = RAGPipeline(retriever=retriever, llm=llm) + >>> + >>> # 检索 + >>> docs = await pipeline.aretrieve("什么是 RAG?") + >>> context = pipeline.format_context(docs) + >>> + >>> # 创建 Agent 工具 + >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm) +""" + +from .retriever import ( + create_base_retriever, + create_hybrid_retriever, + create_qdrant_client, +) +from .reranker import LLaMaCPPReranker +from .query_transform import 
MultiQueryGenerator +from .fusion import reciprocal_rank_fusion +from .pipeline import RAGPipeline +from .tools import create_rag_tool_sync + + +__all__ = [ + # 检索器工厂函数 + "create_base_retriever", + "create_hybrid_retriever", + "create_qdrant_client", + + # 重排序器 + "LLaMaCPPReranker", + + # 查询改写生成器 + "MultiQueryGenerator", + + # 融合算法 + "reciprocal_rank_fusion", + + # 主流水线 + "RAGPipeline", + + # 工具创建(供 Agent 使用) + "create_rag_tool_sync", +] \ No newline at end of file diff --git a/backend/app/rag/fusion.py b/backend/app/rag/fusion.py new file mode 100644 index 0000000..ddf8f42 --- /dev/null +++ b/backend/app/rag/fusion.py @@ -0,0 +1,36 @@ +# rag/fusion.py + +from typing import List, Dict +from langchain_core.documents import Document + +def reciprocal_rank_fusion( + doc_lists: List[List[Document]], + k: int = 60 +) -> List[Document]: + """ + 对多个检索结果列表进行 RRF 融合。 + + Args: + doc_lists: 多个检索结果列表,每个列表来自一个查询 + k: RRF 常数,通常设为 60 + + Returns: + 融合后按 RRF 得分降序排列的文档列表 + """ + # 使用文档内容作为唯一标识(如果内容相同但 metadata 不同,视为同一文档) + # 更好的做法是用 docstore 的 ID,这里简化处理:用内容 hash + doc_to_score: Dict[str, float] = {} + doc_map: Dict[str, Document] = {} + + for docs in doc_lists: + for rank, doc in enumerate(docs, start=1): + # 生成唯一标识符(内容+来源组合,避免不同文件相同内容混淆) + doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}" + if doc_id not in doc_map: + doc_map[doc_id] = doc + score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank) + doc_to_score[doc_id] = score + + # 按得分排序 + sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True) + return [doc_map[doc_id] for doc_id in sorted_ids] \ No newline at end of file diff --git a/backend/app/rag/pipeline.py b/backend/app/rag/pipeline.py new file mode 100644 index 0000000..41f4186 --- /dev/null +++ b/backend/app/rag/pipeline.py @@ -0,0 +1,91 @@ +# rag/pipeline.py + +import asyncio +import os +from ..config import LLAMACPP_RERANKER_URL, LLAMACPP_API_KEY +from typing import List +from langchain_core.documents import Document 
+from langchain_core.language_models import BaseLanguageModel + +from .reranker import LLaMaCPPReranker +from .query_transform import MultiQueryGenerator +from .fusion import reciprocal_rank_fusion + +class RAGPipeline: + """ + 固定流程的 RAG 检索流水线: + 多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档 + """ + + def __init__( + self, + retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例) + llm: BaseLanguageModel, + num_queries: int = 3, + rerank_top_n: int = 5, + ): + """ + Args: + retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法 + llm: 用于生成多路查询的语言模型 + num_queries: 生成的查询变体数量 + rerank_top_n: 最终返回的文档数量 + rerank_model: 重排序模型名称 + """ + self.retriever = retriever + self.llm = llm + self.num_queries = num_queries + self.rerank_top_n = rerank_top_n + + # 初始化组件 + self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) + self.reranker = LLaMaCPPReranker( + base_url=LLAMACPP_RERANKER_URL, + api_key=LLAMACPP_API_KEY, + top_n=rerank_top_n, + ) + + async def aretrieve(self, query: str) -> List[Document]: + """ + 异步执行完整检索流程 + """ + # Step 1: 生成多路查询 + queries = await self.query_generator.agenerate(query) + # 包含原始查询,确保至少有一条 + if query not in queries: + queries.insert(0, query) + else: + # 如果原始查询已在列表中,将其移至首位 + queries.remove(query) + queries.insert(0, query) + + # Step 2: 并行检索(每个查询获取文档列表) + tasks = [self.retriever.ainvoke(q) for q in queries] + doc_lists = await asyncio.gather(*tasks) + + # Step 3: RRF 融合 + fused_docs = reciprocal_rank_fusion(doc_lists) + + # Step 4: 重排序 + try: + final_docs = self.reranker.compress_documents(fused_docs, query) + except Exception: + # 若重排序器不可用,直接返回融合后的前 N 条 + final_docs = fused_docs[:self.rerank_top_n] + + return final_docs + + def retrieve(self, query: str) -> List[Document]: + """同步检索入口(内部调用异步方法)""" + return asyncio.run(self.aretrieve(query)) + + def format_context(self, documents: List[Document]) -> str: + """将文档列表格式化为上下文字符串""" + if not documents: + return "" + + parts = [] + for i, doc in enumerate(documents, 1): + source = 
doc.metadata.get("source", "未知来源") + parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n") + return "\n".join(parts) \ No newline at end of file diff --git a/backend/app/rag/query_transform.py b/backend/app/rag/query_transform.py new file mode 100644 index 0000000..38f9fd1 --- /dev/null +++ b/backend/app/rag/query_transform.py @@ -0,0 +1,43 @@ +# rag/query_transform.py + +from typing import List +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import PromptTemplate + +MULTI_QUERY_PROMPT = PromptTemplate.from_template( + """你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。 +这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。 + +原始问题: {question} + +请生成 {num_queries} 个不同版本的查询,每个版本一行。 +确保每个版本都是独立、完整的查询语句。 + +生成 {num_queries} 个查询:""" +) + +class MultiQueryGenerator: + """多路查询生成器(不依赖 LangChain 的 MultiQueryRetriever)""" + + def __init__(self, llm: BaseLanguageModel, num_queries: int = 3): + self.llm = llm + self.num_queries = num_queries + self.prompt = MULTI_QUERY_PROMPT + + def generate(self, query: str) -> List[str]: + """同步生成多个查询变体""" + prompt_str = self.prompt.format(num_queries=self.num_queries, question=query) + response = self.llm.invoke(prompt_str) + # 处理响应内容,按行分割并去除空行和首尾空白 + lines = response.content.strip().split('\n') + queries = [line.strip() for line in lines if line.strip()] + # 确保至少返回原始查询 + return queries[:self.num_queries] if queries else [query] + + async def agenerate(self, query: str) -> List[str]: + """异步生成多个查询变体""" + prompt_str = self.prompt.format(num_queries=self.num_queries, question=query) + response = await self.llm.ainvoke(prompt_str) + lines = response.content.strip().split('\n') + queries = [line.strip() for line in lines if line.strip()] + return queries[:self.num_queries] if queries else [query] \ No newline at end of file diff --git a/backend/app/rag/reranker.py b/backend/app/rag/reranker.py new file mode 100644 index 0000000..925e283 --- /dev/null +++ b/backend/app/rag/reranker.py @@ -0,0 +1,75 @@ +""" 
+重排序器模块 (适配版) +使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder +""" +import requests +from typing import List +from langchain_core.documents import Document + +class LLaMaCPPReranker: + """使用远程 llama.cpp 服务对检索结果重排序。""" + + def __init__(self, + base_url: str, + api_key: str, + top_n: int = 5, + timeout: int = 60): + """ + 初始化远程重排序器 + + Args: + base_url: llama.cpp 服务的地址和端口,默认为环境变量 LLAMACPP_RERANKER_URL 或 "http://127.0.0.1:8083"。 + top_n: 返回前 N 个结果。 + api_key: API 密钥,默认为环境变量 LLAMACPP_API_KEY 或 "huang1998"。 + timeout: 请求超时时间(秒)。 + """ + self.base_url = base_url + self.api_key = api_key + self.top_n = top_n + self.timeout = timeout + self.endpoint = f"{self.base_url}/rerank" + + def compress_documents( + self, documents: List[Document], query: str + ) -> List[Document]: + """ + 对文档进行重排序 + + Args: + documents: 待排序的文档列表 + query: 查询字符串 + + Returns: + 排序后的文档列表 + """ + if not documents: + return [] + + # 准备请求体 + # 根据 llama.cpp 的 OpenAI 兼容性,文档是一个字符串列表 + payload = { + "model": "bge-reranker-v2-m3", + "query": query, + "documents": [doc.page_content for doc in documents], + "top_n": self.top_n + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + + try: + response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout) + response.raise_for_status() # 检查请求是否成功 + results = response.json() + + # 解析返回结果 + # 返回格式: {"results": [{"index": 0, "document": "...", "relevance_score": 0.8}, ...]} + # 按相关性得分降序排列 + sorted_indices = [item["index"] for item in results["results"]] + sorted_docs = [documents[idx] for idx in sorted_indices] + return sorted_docs + + except Exception as e: + print(f"警告: 远程重排序过程出错,将使用原始排序。错误: {e}") + return documents[:self.top_n] \ No newline at end of file diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py new file mode 100644 index 0000000..483c8b9 --- /dev/null +++ b/backend/app/rag/retriever.py @@ -0,0 +1,199 @@ +""" +Qdrant 向量检索器模块 + +提供基于 Qdrant 
的基础向量检索和混合检索(Dense + Sparse)功能。 + +核心原理: +- 基础检索:将查询文本转换为向量,在 Qdrant 中进行近似最近邻(ANN)搜索, + 使用余弦相似度返回最相似的 k 个文档。 +- 混合检索:结合稠密向量检索(语义相似)和 BM25 稀疏向量检索(关键词匹配), + 通过加权或分数融合提高召回精度。 + +使用示例: + >>> from rag_core import LlamaCppEmbedder + >>> embedder = LlamaCppEmbedder() + >>> embeddings = embedder.as_langchain_embeddings() + >>> + >>> # 创建基础检索器 + >>> retriever = create_base_retriever( + ... collection_name="my_docs", + ... embeddings=embeddings, + ... search_kwargs={"k": 10} + ... ) + >>> + >>> # 执行检索 + >>> docs = retriever.invoke("什么是 RAG?") +""" + +from typing import Optional, Dict, Any +from qdrant_client import QdrantClient +from qdrant_client.http.exceptions import UnexpectedResponse +from langchain_qdrant import QdrantVectorStore +from langchain_core.embeddings import Embeddings +from langchain_core.retrievers import BaseRetriever + +from rag_core import QDRANT_URL, QDRANT_API_KEY + +# 模块级常量 +DEFAULT_SEARCH_K = 20 +DEFAULT_SCORE_THRESHOLD = 0.3 + + +def create_qdrant_client( + url: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 30, +) -> QdrantClient: + """ + 创建并返回一个配置好的 Qdrant 客户端。 + + 优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。 + + Args: + url: Qdrant 服务地址,例如 "http://localhost:6333"。 + 默认从环境变量 QDRANT_URL 读取。 + api_key: API 密钥(若 Qdrant 启用了认证)。 + 默认从环境变量 QDRANT_API_KEY 读取。 + timeout: 请求超时时间(秒),默认 30 秒。 + + Returns: + 配置好的 QdrantClient 实例。 + + Raises: + ValueError: 如果 url 为空且环境变量也未设置。 + """ + effective_url = url or QDRANT_URL + if not effective_url: + raise ValueError( + "Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL" + ) + + effective_api_key = api_key or QDRANT_API_KEY + + client_kwargs = { + "url": effective_url, + "timeout": timeout, + } + if effective_api_key: + client_kwargs["api_key"] = effective_api_key + + return QdrantClient(**client_kwargs) + + +def create_base_retriever( + collection_name: str, + embeddings: Embeddings, + search_kwargs: Optional[Dict[str, Any]] = None, + client: Optional[QdrantClient] = None, +) -> BaseRetriever: 
+ """ + 创建基础向量检索器(仅稠密向量检索)。 + + 该检索器使用嵌入模型将查询转为向量,在 Qdrant 集合中执行 ANN 搜索, + 返回语义上最相似的文档块。 + + Args: + collection_name: Qdrant 集合名称(需预先创建并索引)。 + embeddings: LangChain 兼容的嵌入模型实例。 + search_kwargs: 搜索参数,可包含: + - k (int): 返回的文档数量,默认 20。 + - score_threshold (float): 相似度阈值,仅返回高于此分数的文档。 + - filter (dict): Qdrant 过滤条件。 + 若为 None,则使用默认值 {"k": 20}。 + client: 可选的 Qdrant 客户端实例。若未提供,将自动创建。 + + Returns: + BaseRetriever 实例,可直接调用 .invoke(query) 或 .ainvoke(query) 检索。 + + Raises: + ValueError: 如果集合不存在或嵌入模型无效。 + """ + # 合并默认搜索参数 + merged_search_kwargs = {"k": DEFAULT_SEARCH_K} + if search_kwargs: + merged_search_kwargs.update(search_kwargs) + + # 创建或复用 Qdrant 客户端 + if client is None: + client = create_qdrant_client() + + # 验证集合是否存在(可选,便于提前发现问题) + try: + client.get_collection(collection_name) + except UnexpectedResponse as e: + if e.status_code == 404: + raise ValueError( + f"Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档。" + ) + raise + + # 构建向量存储 + vector_store = QdrantVectorStore( + client=client, + collection_name=collection_name, + embedding=embeddings, + ) + + # 返回检索器 + return vector_store.as_retriever(search_kwargs=merged_search_kwargs) + + +def create_hybrid_retriever( + collection_name: str, + embeddings: Embeddings, + dense_k: int = 10, + sparse_k: int = 10, + score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD, + client: Optional[QdrantClient] = None, +) -> BaseRetriever: + """ + 创建混合检索器(稠密向量 + BM25 稀疏向量)。 + + 混合检索结合了语义相似度(Dense)和关键词匹配(Sparse), + 能够更好地处理专有名词、精确匹配等场景。 + + 注意:此功能要求 Qdrant 集合已配置稀疏向量字段并生成了 BM25 索引。 + 若集合未配置稀疏向量,将回退到纯稠密检索(不会报错,但检索效果降级)。 + + Args: + collection_name: Qdrant 集合名称。 + embeddings: 嵌入模型(用于稠密向量)。 + dense_k: 稠密向量检索返回数量,默认 10。 + sparse_k: 稀疏向量检索返回数量,默认 10。 + score_threshold: 相似度阈值,默认 0.3。 + client: 可选的 Qdrant 客户端实例。 + + Returns: + BaseRetriever 实例,配置了混合搜索参数。 + """ + total_k = dense_k + sparse_k + + search_kwargs = { + "k": total_k, + } + if score_threshold is not None: + search_kwargs["score_threshold"] = score_threshold + + # 复用基础检索器创建逻辑,只需调整搜索参数 + 
return create_base_retriever( + collection_name=collection_name, + embeddings=embeddings, + search_kwargs=search_kwargs, + client=client, + ) + + +# 可选:提供异步友好的辅助函数 +async def acreate_base_retriever( + collection_name: str, + embeddings: Embeddings, + search_kwargs: Optional[Dict[str, Any]] = None, + client: Optional[QdrantClient] = None, +) -> BaseRetriever: + """ + 异步创建基础向量检索器(与同步版本功能相同)。 + + 适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。 + """ + # 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可 + return create_base_retriever(collection_name, embeddings, search_kwargs, client) \ No newline at end of file diff --git a/backend/app/rag/test.py b/backend/app/rag/test.py new file mode 100644 index 0000000..5e5d19e --- /dev/null +++ b/backend/app/rag/test.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 +""" +RAG 系统使用示例(重构版) + +演示: +1. 使用 IndexBuilder 获取父子块检索器 +2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档) +3. 将流水线封装为 LangChain 工具,供 Agent 调用 +""" + +import asyncio +import sys +import os + +from dotenv import load_dotenv + +# 加载环境变量(Qdrant URL、PostgreSQL 连接等) +load_dotenv() + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from rag_indexer.index_builder import IndexBuilderConfig +from rag_indexer.splitters import SplitterType +from .pipeline import RAGPipeline +from .tools import create_rag_tool_sync +from pydantic import SecretStr +# 使用本地 LLM(通过 OpenAI 兼容接口) +from langchain_openai import ChatOpenAI +from rag_core.retriever_factory import create_parent_retriever + +load_dotenv() + +def create_llm(): + """创建本地 vLLM 服务 LLM""" + vllm_base_url = os.getenv( + "VLLM_BASE_URL", + "http://127.0.0.1:8081/v1" + ) + + return ChatOpenAI( + base_url=vllm_base_url, + api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), + model="gemma-4-E2B-it", + timeout=60.0, # 请求超时时间(秒) + max_retries=2, # 失败后自动重试次数 + streaming=True, # 确保开启流式输出 + ) + +async def demonstrate_full_pipeline(): + """ + 完整流水线演示: + - 从 IndexBuilder 获取 ParentDocumentRetriever + - 创建 
RAGPipeline + - 执行检索并打印结果 + """ + print("=" * 60) + print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)") + print("=" * 60) + + retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) + + if retriever is None: + print("错误:检索器未初始化,请确保索引已构建。") + return + + # 3. 创建 LLM 用于查询改写 + llm = create_llm() + + # 4. 创建 RAGPipeline(固定流程) + pipeline = RAGPipeline( + retriever=retriever, + llm=llm, + num_queries=3, # 生成 3 个查询变体 + rerank_top_n=5, # 最终返回 5 个父文档 + ) + + # 5. 执行检索 + query = "打虎英雄是谁?" + print(f"\n查询: {query}") + print("-" * 40) + + try: + documents = await pipeline.aretrieve(query) + print(f"返回 {len(documents)} 个父文档\n") + + # 打印结果预览 + for i, doc in enumerate(documents, 1): + content_preview = doc.page_content.replace("\n", " ")[:150] + source = doc.metadata.get("source", "未知来源") + print(f"{i}. 【来源:{source}】") + print(f" {content_preview}...\n") + + # 可选:格式化完整上下文 + # context = pipeline.format_context(documents) + # print(context) + + except Exception as e: + print(f"检索失败: {e}") + import traceback + traceback.print_exc() + +async def demonstrate_tool_creation(): + """ + 演示创建 RAG 工具(供 Agent 使用) + """ + print("\n" + "=" * 60) + print("演示:创建 RAG 工具(供 LangGraph Agent 调用)") + print("=" * 60) + + # 1. 获取检索器(同上) + config = IndexBuilderConfig( + collection_name="rag_documents", + splitter_type=SplitterType.PARENT_CHILD, + ) + retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) + + # 2. 创建 LLM + llm = create_llm() + + # 3. 创建工具 + rag_tool = create_rag_tool_sync( + retriever=retriever, + llm=llm, + num_queries=3, + rerank_top_n=5, + collection_name="rag_documents", + ) + + print(f"工具名称: {rag_tool.name}") + print(f"工具描述: {rag_tool.description[:100]}...") + + # 4. 模拟 Agent 调用工具 + query = "请告诉我 打虎英雄是谁?" + print(f"\n模拟调用: {query}") + print("-" * 40) + + result = await rag_tool.ainvoke({"query": query}) + print(result[:800] + "..." 
if len(result) > 800 else result) + +async def main(): + await demonstrate_full_pipeline() + await demonstrate_tool_creation() + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py new file mode 100644 index 0000000..33f79c2 --- /dev/null +++ b/backend/app/rag/tools.py @@ -0,0 +1,54 @@ +""" +RAG 工具模块 + +将检索功能封装为 LangChain Tool,供 Agent 调用。 +采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 +""" +from typing import Callable +from langchain_core.tools import tool +from langchain_core.language_models import BaseLanguageModel +from langchain_core.retrievers import BaseRetriever +from .pipeline import RAGPipeline + +def create_rag_tool_sync( + retriever: BaseRetriever, + llm: BaseLanguageModel, + num_queries: int = 3, + rerank_top_n: int = 5, + collection_name: str = "rag_documents", +) -> Callable: + """ + 创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。 + + 参数同 create_rag_tool。 + """ + pipeline = RAGPipeline( + retriever=retriever, + llm=llm, + num_queries=num_queries, + rerank_top_n=rerank_top_n, + ) + + @tool + def search_knowledge_base_sync(query: str) -> str: + """在知识库中搜索与查询相关的文档片段(同步版本)。 + + 功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。 + + Args: + query: 用户提出的问题或查询字符串 + + Returns: + 格式化后的相关文档内容。 + """ + try: + documents = pipeline.retrieve(query) # 内部调用异步方法并等待 + if not documents: + return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" + + context = pipeline.format_context(documents) + return context + except Exception as e: + return f"检索过程中发生错误: {str(e)}" + + return search_knowledge_base_sync \ No newline at end of file diff --git a/backend/app/test_backend.py b/backend/app/test_backend.py new file mode 100644 index 0000000..6af60d2 --- /dev/null +++ b/backend/app/test_backend.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +""" +完整后端测试 - 验证 Agent 所有功能 +包括:短期记忆、长期记忆、工具调用、流式对话、历史查询 +""" + +import asyncio +import os +from .config import DB_URI +import sys +import uuid +from dotenv import 
load_dotenv

# Put the project root on the import path (this file lives in backend/app/,
# so backend/ is the root).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

load_dotenv()

from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from ..agent import AIAgentService
from ..agent.history import ThreadHistoryService
from ..logger import info, warning, error

# PostgreSQL connection string comes from .config (DB_URI).


async def print_section(title):
    """Print a test-section banner."""
    print("\n" + "=" * 70)
    print(f" {title}")
    print("=" * 70)


async def test_short_term_memory(agent_service):
    """Short-term memory: keep talking on the same thread_id."""
    await print_section("测试 1: 短期记忆(Short-term Memory)")

    thread_id = str(uuid.uuid4())
    user_id = "test_user_memory"

    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")

    # Round 1: state some facts.
    print("\n[第一轮] 发送消息: '我叫张三,今年28岁'")
    result1 = await agent_service.process_message(
        "我叫张三,今年28岁", thread_id, "local", user_id
    )
    print(f"回复: {result1['reply'][:100]}...")

    # Round 2: ask the agent to recall them.
    print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
    result2 = await agent_service.process_message(
        "我叫什么名字?今年多大?", thread_id, "local", user_id
    )
    print(f"回复: {result2['reply']}")

    # The reply must surface at least one remembered fact.
    remembered = "张三" in result2["reply"] or "28" in result2["reply"]
    if remembered:
        print("\n✅ 短期记忆测试通过!")
        return True
    print("\n❌ 短期记忆测试失败!")
    return False


async def test_tool_calling(agent_service):
    """Tool calling (RAG search)."""
    await print_section("测试 2: 工具调用(Tool Calling)")

    thread_id = str(uuid.uuid4())
    user_id = "test_user_tools"

    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")

    # A question that should trigger the RAG tool.
    print("\n发送消息: '请告诉我,打虎英雄是谁?'")
    result = await agent_service.process_message(
        "请告诉我,打虎英雄是谁?", thread_id, "local", user_id
    )
    print(f"回复: {result['reply'][:200]}...")

    # Success heuristic: the reply mentions Water-Margin content.
    if any(word in result["reply"] for word in ("武松", "李忠", "水浒传")):
        print("\n✅ 工具调用测试通过!")
        return True
    else:
        print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
return None + +async def test_streaming(agent_service): + """测试流式对话""" + await print_section("测试 3: 流式对话(Streaming)") + + thread_id = str(uuid.uuid4()) + user_id = "test_user_stream" + + print(f"\n使用 thread_id: {thread_id[:8]}...") + print(f"使用 user_id: {user_id}") + + print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...") + print("流式输出: ", end="", flush=True) + + full_reply = "" + chunk_count = 0 + + try: + async for chunk in agent_service.process_message_stream( + "用100字介绍一下AI人工智能", thread_id, "local", user_id + ): + chunk_count += 1 + if chunk.get("type") == "llm_token": + token = chunk.get("token", "") + print(token, end="", flush=True) + full_reply += token + elif chunk.get("type") == "state_update": + pass # 状态更新不显示 + + print(f"\n\n共收到 {chunk_count} 个 chunk") + print(f"完整回复长度: {len(full_reply)} 字") + + if chunk_count > 0 and len(full_reply) > 10: + print("\n✅ 流式对话测试通过!") + return True + else: + print("\n❌ 流式对话测试失败!") + return False + + except Exception as e: + print(f"\n❌ 流式对话异常: {e}") + return False + +async def test_history_service(agent_service, history_service): + """测试历史查询服务""" + await print_section("测试 4: 历史查询服务(History Service)") + + user_id = "test_user_history" + + # 先创建几个对话 + print(f"\n为 user_id={user_id} 创建测试对话...") + + thread_ids = [] + for i in range(3): + thread_id = str(uuid.uuid4()) + thread_ids.append(thread_id) + + await agent_service.process_message( + f"这是第 {i+1} 个测试对话", thread_id, "local", user_id + ) + print(f" 创建线程 {i+1}: {thread_id[:8]}...") + + # 1. 测试获取用户线程列表 + print("\n[4.1] 测试获取用户线程列表...") + threads = await history_service.get_user_threads(user_id, limit=10) + print(f" 找到 {len(threads)} 个线程") + + if len(threads) >= 3: + print(" ✅ 线程列表查询通过") + else: + print(" ⚠️ 线程数量少于预期") + + # 2. 
测试获取单个线程的消息历史 + if thread_ids: + test_thread_id = thread_ids[0] + print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)") + messages = await history_service.get_thread_messages(test_thread_id) + print(f" 找到 {len(messages)} 条消息") + + if len(messages) >= 2: # 至少有一问一答 + print(" ✅ 消息历史查询通过") + else: + print(" ⚠️ 消息数量少于预期") + + # 3. 测试获取线程摘要 + print(f"\n[4.3] 测试获取线程摘要...") + summary = await history_service.get_thread_summary(test_thread_id) + print(f" 摘要: {summary.get('summary', '')[:50]}...") + print(f" 消息数: {summary.get('message_count', 0)}") + + if summary.get('message_count', 0) > 0: + print(" ✅ 线程摘要查询通过") + else: + print(" ⚠️ 摘要查询结果不确定") + + return len(threads) >= 3 + +async def test_long_term_memory(agent_service): + """测试长期记忆(mem0)""" + await print_section("测试 5: 长期记忆(Long-term Memory - mem0)") + + thread_id1 = str(uuid.uuid4()) + thread_id2 = str(uuid.uuid4()) # 不同的线程 + user_id = "test_user_longterm" + + print(f"\n使用 user_id: {user_id}") + print(f"线程 1: {thread_id1[:8]}...") + print(f"线程 2: {thread_id2[:8]}...") + + # 在第一个线程中保存信息 + print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'") + result1 = await agent_service.process_message( + "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id + ) + print(f"回复: {result1['reply'][:100]}...") + + # 等待一下,让 mem0 保存 + await asyncio.sleep(1) + + # 在第二个线程中询问(不同的 thread_id) + print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'") + result2 = await agent_service.process_message( + "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id + ) + print(f"回复: {result2['reply']}") + + # 验证长期记忆 + if "小白" in result2['reply'] or "猫" in result2['reply']: + print("\n✅ 长期记忆测试通过!") + return True + else: + print("\n⚠️ 长期记忆可能未启用,或需要手动验证") + return None + +async def main(): + """主测试函数""" + print("\n" + "=" * 70) + print(" 后端完整功能测试") + print("=" * 70) + + results = {} + + try: + # 创建数据库连接和服务 + print("\n正在初始化数据库连接...") + async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: + await checkpointer.setup() + print("✅ 数据库连接成功") + + # 创建服务实例 + 
print("\n正在初始化 Agent 服务...") + agent_service = AIAgentService(checkpointer) + await agent_service.initialize() + print("✅ Agent 服务初始化成功") + + history_service = ThreadHistoryService(checkpointer) + print("✅ 历史服务初始化成功") + + print(f"\n可用模型: {list(agent_service.graphs.keys())}") + + # 运行测试 + results["短期记忆"] = await test_short_term_memory(agent_service) + await asyncio.sleep(1) + + results["工具调用"] = await test_tool_calling(agent_service) + await asyncio.sleep(1) + + results["流式对话"] = await test_streaming(agent_service) + await asyncio.sleep(1) + + results["历史查询"] = await test_history_service(agent_service, history_service) + await asyncio.sleep(1) + + results["长期记忆"] = await test_long_term_memory(agent_service) + await asyncio.sleep(1) + + # 打印总结 + await print_section("测试总结") + print("\n测试结果:") + print("-" * 40) + + pass_count = 0 + fail_count = 0 + skip_count = 0 + + for test_name, result in results.items(): + if result is True: + status = "✅ 通过" + pass_count += 1 + elif result is False: + status = "❌ 失败" + fail_count += 1 + else: + status = "⚠️ 待验证" + skip_count += 1 + print(f" {test_name:12s}: {status}") + + print("-" * 40) + print(f"总计: {len(results)} 个测试") + print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}") + + if fail_count == 0: + print("\n🎉 所有核心测试通过!") + else: + print(f"\n⚠️ 有 {fail_count} 个测试失败") + + except Exception as e: + error(f"\n❌ 测试运行异常: {e}") + import traceback + traceback.print_exc() + return 1 + + return 0 if fail_count == 0 else 1 + +if __name__ == "__main__": + exit_code = asyncio.run(main()) + sys.exit(exit_code) diff --git a/backend/app/utils/__init__.py b/backend/app/utils/__init__.py new file mode 100644 index 0000000..d8bd0f1 --- /dev/null +++ b/backend/app/utils/__init__.py @@ -0,0 +1,7 @@ +""" +工具模块 +""" + +from .logging import log_state_change, print_llm_input + +__all__ = ["log_state_change", "print_llm_input"] diff --git a/backend/app/utils/logging.py b/backend/app/utils/logging.py new file mode 100644 index 0000000..ae513fd 
--- /dev/null +++ b/backend/app/utils/logging.py @@ -0,0 +1,61 @@ +""" +LangGraph 节点日志工具模块 +提供状态流转追踪和 LLM 输入输出打印功能 +""" + +from ..config import ENABLE_GRAPH_TRACE +from ..logger import debug, info + + +def log_state_change(node_name: str, state: dict, prefix: str = "进入"): + """ + 记录状态变化日志 + + Args: + node_name: 节点名称 + state: 当前状态 + prefix: 日志前缀("进入" 或 "离开") + """ + from app.logger import info + + messages = state.get("messages", []) + msg_count = len(messages) + last_msg = messages[-1] if messages else None + last_info = "" + if last_msg: + # 兼容 dict 和对象两种格式 + if isinstance(last_msg, dict): + content_preview = str(last_msg.get("content", ""))[:10].replace("\n", " ") + msg_type = last_msg.get("type", "unknown") + else: + content_preview = str(last_msg.content)[:10].replace("\n", " ") + msg_type = getattr(last_msg, 'type', 'unknown') + last_info = f"{msg_type.upper()}: {content_preview}" + info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}") + + +def print_llm_input(prompt_value): + """ + RunnableLambda 回调函数:打印格式化后发送给 LLM 的完整消息 + + Args: + prompt_value: ChatPromptValue 对象,包含格式化后的消息列表 + + Returns: + 原样返回 prompt_value,不影响链式调用 + """ + if not ENABLE_GRAPH_TRACE: + return prompt_value + + messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性 + + debug("\n" + "=" * 80) + debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:") + debug(f" 总消息数: {len(messages)}") + debug("-" * 80) + for i, msg in enumerate(messages): + content_preview = str(msg.content) # 完整输出 + debug(f" [{i}] {msg.type.upper():10s}: {content_preview}") + debug("\n" + "=" * 80 + "\n") + + return prompt_value diff --git a/backend/rag_core/__init__.py b/backend/rag_core/__init__.py new file mode 100644 index 0000000..318a066 --- /dev/null +++ b/backend/rag_core/__init__.py @@ -0,0 +1,21 @@ +""" +RAG Core - 公共 RAG 组件包 + +提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。 +""" + +from .embedders import LlamaCppEmbedder +from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY +from .store 
import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever


__all__ = [
    "LlamaCppEmbedder",
    "QdrantVectorStore",
    "QDRANT_URL",
    "QDRANT_API_KEY",
    "PostgresDocStore",
    "create_docstore",
    "create_parent_retriever",
]

# ===== backend/rag_core/client.py =====
import os
from typing import Optional

from qdrant_client import QdrantClient

from .config import QDRANT_URL, QDRANT_API_KEY


def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: int = 300,  # index builds need a generous timeout
) -> QdrantClient:
    """Build a Qdrant client, falling back to configured defaults.

    Raises:
        ValueError: When no URL is configured at all.
    """
    chosen_url = url or QDRANT_URL
    chosen_key = api_key or QDRANT_API_KEY

    if not chosen_url:
        raise ValueError("Qdrant URL 未配置")

    client_options = {
        "url": chosen_url,
        "timeout": timeout,
    }
    if chosen_key:
        client_options["api_key"] = chosen_key

    return QdrantClient(**client_options)

# ===== backend/rag_core/config.py =====
"""
RAG Core configuration: every environment-variable setting lives here so
nothing is scattered across modules.
"""

import os

# ---- Vector database ----
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")

# ---- Embedding service ----
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "")

# ---- Document store ----
DB_URI = os.getenv(
    "DB_URI",
    "postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
)
DOCSTORE_URI = os.getenv("DOCSTORE_URI", DB_URI)

# ---- Misc ----
# Add further rag_core-specific settings here as needed.

# ===== backend/rag_core/embedders.py =====
b/backend/rag_core/embedders.py new file mode 100644 index 0000000..eff11a0 --- /dev/null +++ b/backend/rag_core/embedders.py @@ -0,0 +1,83 @@ +""" +嵌入模型包装器,用于 llama.cpp 服务。 +""" + +import os +from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY +import httpx +from typing import List, Optional + +from langchain_core.embeddings import Embeddings + +class LlamaCppEmbedder: + """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" + + def __init__( + self, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + model: str = "Qwen3-Embedding-0.6B-Q8_0", + ): + self.base_url = base_url or LLAMACPP_EMBEDDING_URL + self.api_key = api_key or LLAMACPP_API_KEY + self.model = model + + def as_langchain_embeddings(self) -> Embeddings: + """创建 LangChain 兼容的嵌入实例。""" + return _LlamaCppLangchainAdapter(self) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """嵌入一批文档。""" + return self._call_embedding_api(texts) + + def embed_query(self, text: str) -> List[float]: + """嵌入单个查询。""" + return self._call_embedding_api([text])[0] + + def get_embedding_dimension(self) -> int: + """通过嵌入测试字符串获取嵌入维度。""" + test_embedding = self.embed_query("test") + return len(test_embedding) + + def _call_embedding_api(self, texts: List[str]) -> List[List[float]]: + """直接调用 llama.cpp 嵌入 API。""" + base = self.base_url.rstrip("/") + if not base.endswith("/v1"): + base = base + "/v1" + + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + payload = { + "input": texts, + "model": self.model, + } + + with httpx.Client(timeout=120) as client: + response = client.post( + f"{base}/embeddings", + headers=headers, + json=payload, + ) + response.raise_for_status() + data = response.json() + + if isinstance(data, list): + return [item["embedding"] for item in data] + elif isinstance(data, dict) and "data" in data: + return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] + else: + raise 
ValueError(f"未知的嵌入 API 响应格式: {data}") + +class _LlamaCppLangchainAdapter(Embeddings): + """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" + + def __init__(self, embedder: LlamaCppEmbedder): + self._embedder = embedder + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + return self._embedder.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + return self._embedder.embed_query(text) \ No newline at end of file diff --git a/backend/rag_core/retriever_factory.py b/backend/rag_core/retriever_factory.py new file mode 100644 index 0000000..25dab1c --- /dev/null +++ b/backend/rag_core/retriever_factory.py @@ -0,0 +1,59 @@ +# rag_core/retriever_factory.py +from langchain_core.embeddings import Embeddings +from langchain_classic.retrievers import ParentDocumentRetriever +from langchain_text_splitters import RecursiveCharacterTextSplitter +from typing import Optional +from langchain_core.embeddings import Embeddings +from langchain_core.stores import BaseStore +from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from langchain_classic.retrievers import ParentDocumentRetriever + +from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore + +def create_parent_retriever( + collection_name: str = "rag_documents", + embeddings: Optional[Embeddings] = None, + parent_splitter: Optional[TextSplitter] = None, + child_splitter: Optional[TextSplitter] = None, + docstore: Optional[BaseStore] = None, + search_k: int = 5, + # 若未传入切分器,则用以下参数创建默认切分器 + parent_chunk_size: int = 1000, + parent_chunk_overlap: int = 100, + child_chunk_size: int = 200, + child_chunk_overlap: int = 20, +) -> ParentDocumentRetriever: + # 嵌入模型 + if embeddings is None: + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() + + # 向量存储(只读) + vector_store = QdrantVectorStore( + collection_name=collection_name, + embeddings=embeddings, + ) + + # 切分器(若未提供则创建默认) + if parent_splitter is None: + 
parent_splitter = RecursiveCharacterTextSplitter( + chunk_size=parent_chunk_size, + chunk_overlap=parent_chunk_overlap, + ) + if child_splitter is None: + child_splitter = RecursiveCharacterTextSplitter( + chunk_size=child_chunk_size, + chunk_overlap=child_chunk_overlap, + ) + + # 文档存储 + if docstore is None: + docstore, _ = create_docstore() # 从环境变量读取连接 + + return ParentDocumentRetriever( + vectorstore=vector_store.get_langchain_vectorstore(), + docstore=docstore, + child_splitter=child_splitter, + parent_splitter=parent_splitter, + search_kwargs={"k": search_k}, + ) \ No newline at end of file diff --git a/backend/rag_core/store/__init__.py b/backend/rag_core/store/__init__.py new file mode 100644 index 0000000..359db76 --- /dev/null +++ b/backend/rag_core/store/__init__.py @@ -0,0 +1,31 @@ +""" +文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。 + +提供 PostgreSQL 存储后端: +- PostgresDocStore: PostgreSQL 数据库存储(生产环境) + +示例用法: + >>> from rag_core.store import create_docstore + + >>> # 创建 PostgreSQL 存储 + >>> store, conn = create_docstore( + ... connection_string="postgresql://user:pass@host:5432/db", + ... table_name="parent_docs" + ... 
) +""" + + +from .postgres import PostgresDocStore +from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI + +__version__ = "2.0.0" + +__all__ = [ + # 具体实现 + "PostgresDocStore", + + # 工厂函数 + "create_docstore", + "get_docstore_uri", + "DEFAULT_DB_URI", +] diff --git a/backend/rag_core/store/factory.py b/backend/rag_core/store/factory.py new file mode 100644 index 0000000..6c87ac9 --- /dev/null +++ b/backend/rag_core/store/factory.py @@ -0,0 +1,71 @@ +""" +文档存储工厂 - 创建不同类型的存储实例。 + +提供统一的接口来创建本地文件存储或 PostgreSQL 存储。 +""" + +import os +from ..config import DB_URI, DOCSTORE_URI +import logging +from typing import Optional, Tuple + +from langchain_core.stores import BaseStore +from .postgres import PostgresDocStore + +logger = logging.getLogger(__name__) + +# 默认连接字符串(从环境变量读取) +DEFAULT_DB_URI = DB_URI + + +def get_docstore_uri() -> str: + """获取 docstore 专用的数据库连接字符串(可与主库相同)""" + return DOCSTORE_URI + + +def create_docstore( + store_type: str = "postgres", + connection_string: Optional[str] = None, + table_name: str = "parent_documents", + pool_config: Optional[dict] = None, + max_concurrency: Optional[int] = None +) -> Tuple[BaseStore, Optional[str]]: + """ + 工厂函数,创建 PostgreSQL 文档存储。 + + Args: + store_type: 存储类型,目前仅支持 "postgres"(默认) + connection_string: PostgreSQL 连接字符串 + table_name: PostgreSQL 表名(默认:parent_documents) + pool_config: 连接池配置 + max_concurrency: 最大并发操作数,如果为 None 则不限制 + + Returns: + 元组 (存储实例, 连接字符串) + + Raises: + ValueError: 不支持的存储类型 + ImportError: 缺少必要的依赖 + + Example: + >>> # 创建 PostgreSQL 存储 + >>> store, conn = create_docstore( + ... connection_string="postgresql://user:pass@host:5432/db", + ... table_name="parent_docs", + ... max_concurrency=10 + ... 
) + """ + store_type = store_type.lower() + + if store_type == "postgres": + conn_str = connection_string or get_docstore_uri() + store = PostgresDocStore( + connection_string=conn_str, + table_name=table_name, + pool_config=pool_config, + max_concurrency=max_concurrency + ) + return store, conn_str + + else: + raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres") diff --git a/backend/rag_core/store/postgres.py b/backend/rag_core/store/postgres.py new file mode 100644 index 0000000..23b7153 --- /dev/null +++ b/backend/rag_core/store/postgres.py @@ -0,0 +1,246 @@ +""" +异步 PostgreSQL 存储实现 - 用于生产环境。 + +使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。 +""" + +import asyncio +import json +import logging +from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence + +from langchain_core.documents import Document +from langchain_core.stores import BaseStore + +import asyncpg + +logger = logging.getLogger(__name__) + +class PostgresDocStore(BaseStore[str, Any]): + """ + 异步 PostgreSQL 文档存储实现。 + + 使用 asyncpg 作为异步 PostgreSQL 客户端,支持: + - 真正的异步操作 + - 连接池管理 + - 自动表创建 + - 批量操作(amget/amset/amdelete) + - JSONB 数据存储 + - 并发控制 + + 适用于生产环境,提供高性能的异步数据持久化。 + + Attributes: + dsn: PostgreSQL 连接字符串 + table_name: 存储表名,默认为 "parent_documents" + _pool: asyncpg 连接池实例 + _semaphore: 控制并发数的信号量(可选) + """ + + def __init__( + self, + connection_string: str, + table_name: str = "parent_documents", + pool_config: Optional[Dict[str, Any]] = None, + max_concurrency: Optional[int] = None + ): + """ + 初始化异步 PostgreSQL 文档存储。 + + Args: + connection_string: PostgreSQL 连接 URL,格式: + "postgresql://user:password@host:port/database?sslmode=disable" + table_name: 存储表名,默认为 "parent_documents" + pool_config: 连接池配置字典,包含: + - min_size: 最小连接数(默认 2) + - max_size: 最大连接数(默认 10) + max_concurrency: 最大并发操作数,如果为 None 则不限制 + + Raises: + ImportError: 未安装 asyncpg 时抛出 + + Example: + >>> store = PostgresDocStore( + ... "postgresql://user:pass@localhost:5432/mydb", + ... table_name="parent_docs", + ... 
pool_config={"min_size": 5, "max_size": 20}, + ... max_concurrency=10 + ... ) + """ + + + self.dsn = connection_string + self.table_name = table_name + self._pool: Optional["asyncpg.Pool"] = None + self._pool_config = pool_config or {} + + # 并发控制信号量 + self._semaphore = None + if max_concurrency is not None and max_concurrency > 0: + self._semaphore = asyncio.Semaphore(max_concurrency) + + # 注意:连接池的异步初始化延迟到第一次使用时 + # 表结构创建也延迟到第一次操作时 + + async def _get_pool(self): + """获取或创建 asyncpg 连接池。""" + if self._pool is None: + import asyncpg + min_size = self._pool_config.get("min_size", 2) + max_size = self._pool_config.get("max_size", 10) + + try: + self._pool = await asyncpg.create_pool( + dsn=self.dsn, + min_size=min_size, + max_size=max_size + ) + logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}") + + # 初始化表结构 + await self._create_table() + except Exception as e: + raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}") + + return self._pool + + async def _create_table(self): + """创建存储表(如果不存在)。""" + pool = await self._get_pool() + async with pool.acquire() as conn: + async with conn.transaction(): + await conn.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() + ) + """) + logger.info(f"表 {self.table_name} 已就绪") + + async def _with_concurrency_control(self, coro): + """使用信号量控制并发执行。""" + if self._semaphore is None: + return await coro + async with self._semaphore: + return await coro + + # --- 同步方法(保持兼容性,但功能有限)--- + + def mget(self, keys: Sequence[str]) -> List[Optional[Any]]: + """不支持同步操作,请使用异步 amget 方法。""" + raise NotImplementedError("不支持同步操作,请使用异步 amget 方法。") + + def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None: + """不支持同步操作,请使用异步 amset 方法。""" + raise NotImplementedError("不支持同步操作,请使用异步 amset 方法。") + + def mdelete(self, keys: Sequence[str]) -> None: + """不支持同步操作,请使用异步 amdelete 方法。""" + raise NotImplementedError("不支持同步操作,请使用异步 amdelete 方法。") + + def 
yield_keys(self, *, prefix: str | None = None) -> Iterator[str]: + """不支持同步操作,请使用异步 ayield_keys 方法。""" + raise NotImplementedError("不支持同步操作,请使用异步 ayield_keys 方法。") + + # --- 异步方法(真正的实现)--- + + async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]: + """异步批量获取文档。""" + if not keys: + return [] + + async def _amget(): + pool = await self._get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)", + keys + ) + result_map = {} + for row in rows: + val = row['value'] + if isinstance(val, str): + val = json.loads(val) + if isinstance(val, dict) and 'page_content' in val: + result_map[row['key']] = Document(**val) + else: + result_map[row['key']] = val + return [result_map.get(key) for key in keys] + + return await self._with_concurrency_control(_amget()) + + async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None: + """异步批量设置文档。""" + if not key_value_pairs: + return + + async def _amset(): + pool = await self._get_pool() + async with pool.acquire() as conn: + async with conn.transaction(): + await conn.executemany( + f""" + INSERT INTO {self.table_name} (key, value) + VALUES ($1, $2) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """, + [ + (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False)) + for k, v in key_value_pairs + ] + ) + logger.debug(f"已异步批量设置 {len(key_value_pairs)} 个文档") + + await self._with_concurrency_control(_amset()) + + async def amdelete(self, keys: Sequence[str]) -> None: + """异步批量删除文档。""" + if not keys: + return + + async def _amdelete(): + pool = await self._get_pool() + async with pool.acquire() as conn: + async with conn.transaction(): + await conn.execute( + f"DELETE FROM {self.table_name} WHERE key = ANY($1)", + keys + ) + logger.debug(f"已异步批量删除 {len(keys)} 个文档") + + await self._with_concurrency_control(_amdelete()) + + async def ayield_keys(self, *, prefix: str | None = None) -> Iterator[str]: + 
"""异步迭代所有键。 + + 注意:这是一个异步生成器,需要使用 async for 迭代。 + """ + pool = await self._get_pool() + async with pool.acquire() as conn: + if prefix: + rows = await conn.fetch( + f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key", + f"{prefix}%" + ) + else: + rows = await conn.fetch( + f"SELECT key FROM {self.table_name} ORDER BY key" + ) + + for row in rows: + yield row['key'] + + async def aclose(self) -> None: + """异步关闭连接池,释放资源。""" + if self._pool: + await self._pool.close() + self._pool = None + logger.info("PostgreSQL 异步连接池已关闭") + + def close(self) -> None: + """同步关闭连接池(功能有限)。 + + 注意:在异步环境中,请使用 aclose 方法。 + """ + pass diff --git a/backend/rag_core/vector_store.py b/backend/rag_core/vector_store.py new file mode 100644 index 0000000..b2ecd20 --- /dev/null +++ b/backend/rag_core/vector_store.py @@ -0,0 +1,180 @@ +""" +Qdrant 向量数据库包装器。 +""" + +import logging +import os +from .config import QDRANT_URL, QDRANT_API_KEY +import time +from typing import List, Optional, Dict, Any + +from langchain_core.documents import Document +from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS +from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams +from httpx import RemoteProtocolError +from qdrant_client.http.exceptions import ResponseHandlingException +from .client import create_qdrant_client + +logger = logging.getLogger(__name__) + + + +class QdrantVectorStore: + """Qdrant 向量数据库操作包装器。""" + + def __init__( + self, + collection_name: str, + embeddings: Optional[Any] = None, + ): + self.collection_name = collection_name + self._client: Optional[QdrantClient] = None + self._connection_attempts = 0 + self._last_connection_time: Optional[float] = None + + if embeddings is None: + from rag_core.embedders import LlamaCppEmbedder + embedder = LlamaCppEmbedder() + self.embeddings = embedder.as_langchain_embeddings() + else: + self.embeddings = embeddings + + self.create_collection() + + self.vector_store = 
LangchainQdrantVS( + client=self.get_client(), + collection_name=self.collection_name, + embedding=self.embeddings, + ) + + def get_client(self) -> QdrantClient: + if self._client is None: + self._client = create_qdrant_client(timeout=300) + self._connection_attempts += 1 + self._last_connection_time = time.time() + logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts) + return self._client + + def refresh_client(self): + """关闭旧连接,创建新连接。""" + if self._client is not None: + try: + self._client.close() + logger.debug("Qdrant 旧连接已关闭") + except Exception as e: + logger.warning("关闭 Qdrant 连接时出现异常: %s", e) + finally: + self._client = None + self._last_connection_time = None + + def check_connection_health(self) -> bool: + """检查连接健康状态,如果连接已失效则自动重建。""" + if self._client is None: + logger.info("Qdrant 客户端未初始化,将创建新连接") + return False + + try: + client = self.get_client() + client.get_collections() + logger.debug("Qdrant 连接健康检查通过") + return True + except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: + logger.warning("Qdrant 连接健康检查失败: %s", e) + self.refresh_client() + return False + + def get_connection_stats(self) -> Dict[str, Any]: + """获取连接统计信息。""" + return { + "connection_attempts": self._connection_attempts, + "last_connection_time": self._last_connection_time, + "client_initialized": self._client is not None, + } + + def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): + """创建集合,设置合适的向量维度。""" + if vector_size is None: + from rag_core.embedders import LlamaCppEmbedder + embedder = LlamaCppEmbedder() + vector_size = embedder.get_embedding_dimension() + + max_retries = 3 + base_delay = 2 + for attempt in range(max_retries): + try: + client = self.get_client() + collections = client.get_collections().collections + exists = any(c.name == self.collection_name for c in collections) + + if exists and force_recreate: + client.delete_collection(self.collection_name) + exists = False + + if 
not exists: + client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size) + else: + logger.info("集合 '%s' 已存在", self.collection_name) + return + except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: + if attempt == max_retries - 1: + logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e) + raise + wait_time = base_delay * (2 ** attempt) + error_type = type(e).__name__ + logger.warning( + "创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s", + self.collection_name, error_type, wait_time, attempt + 1, max_retries, e + ) + self.refresh_client() + logger.debug("已刷新 Qdrant 客户端连接") + time.sleep(wait_time) + + def add_documents(self, documents: List[Document], batch_size: int = 100): + """将文档添加到向量数据库。""" + if not documents: + return [] + self.create_collection() + ids = self.vector_store.add_documents(documents, batch_size=batch_size) + logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids)) + return ids + + def similarity_search(self, query: str, k: int = 5) -> List[Document]: + return self.vector_store.similarity_search(query, k=k) + + def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]: + return self.vector_store.similarity_search_with_score(query, k=k) + + def delete_collection(self): + self.get_client().delete_collection(self.collection_name) + logger.info("集合 '%s' 已删除", self.collection_name) + + def get_collection_info(self) -> Dict[str, Any]: + info = self.get_client().get_collection(self.collection_name) + vectors_config = info.config.params.vectors + if isinstance(vectors_config, dict): + first_config = next(iter(vectors_config.values()), None) + vector_size = first_config.size if first_config else 0 + else: + vector_size = vectors_config.size if vectors_config else 0 + return { + "name": self.collection_name, 
+            "vectors_count": info.points_count or 0,
+            "status": info.status,
+            "vector_size": vector_size,
+        }
+
+    def as_langchain_vectorstore(self):
+        return self.vector_store
+
+    def get_langchain_vectorstore(self):
+        """返回 LangChain Qdrant 向量存储对象(别名)"""
+        return self.vector_store
+
+    def get_qdrant_client(self):
+        """返回原生 Qdrant 客户端(如需手动管理 collection)"""
+        return self.get_client()
\ No newline at end of file
diff --git a/backend/requirements.txt b/backend/requirements.txt
new file mode 100644
index 0000000..e648caa
--- /dev/null
+++ b/backend/requirements.txt
@@ -0,0 +1,46 @@
+# Core
+pydantic==2.12.5
+python-dotenv==1.2.2
+typing-extensions==4.15.0
+
+# LangChain
+langchain==1.2.15
+langchain-community==0.4.1
+langchain-core==1.2.28
+langchain-openai==1.1.12
+langchain-qdrant==1.1.0
+langgraph==1.1.6
+langgraph-checkpoint-postgres==3.0.5
+tiktoken>=0.12.0
+
+# Vector DB
+qdrant-client==1.17.1
+
+# Memory
+mem0ai==1.0.11
+
+# Backend
+fastapi==0.135.3
+uvicorn[standard]==0.44.0
+
+# Database
+asyncpg==0.31.0
+psycopg[binary]==3.3.3
+
+# HTTP
+httpx==0.28.1
+aiohttp==3.13.5
+
+# Utilities
+tenacity==9.1.4
+rich==15.0.0
+PyYAML==6.0.3
+numpy>=1.26.2
+
+# Document Processing
+unstructured==0.22.21
+pypdf==6.10.0
+beautifulsoup4==4.14.3
+lxml==6.1.0
+pandas==3.0.2  # 若需Excel保留,否则移除
+spacy==3.8.14  # unstructured 可能依赖
diff --git a/frontend/src/config.py b/frontend/src/config.py
index da4c995..7b34700 100644
--- a/frontend/src/config.py
+++ b/frontend/src/config.py
@@ -57,6 +57,9 @@
         api_url = os.getenv("API_URL", "http://127.0.0.1:8079")
         self.api_base = api_url.replace("/chat", "").rstrip("/")
 
+        # 日志配置
+        self.log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+        self.debug = os.getenv("DEBUG", "false").lower() == "true"
 
 # 全局配置实例(单例模式)
 config = FrontendConfig()
\ No newline at end of file
diff --git a/frontend/src/logger.py
"""
RAG Indexer configuration module.

Centralizes environment-variable driven settings so they are not scattered
across the indexer code base.
"""

import os

# Prefer the shared rag_core configuration when it is importable; fall back
# to reading the environment directly so the indexer can run standalone.
try:
    from rag_core.config import (
        QDRANT_URL,
        QDRANT_API_KEY,
        LLAMACPP_EMBEDDING_URL,
        LLAMACPP_API_KEY,
        DB_URI,
        DOCSTORE_URI,
    )
except ImportError:
    QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
    QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
    LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
    LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "")
    # SECURITY NOTE(review): the fallback DSN previously hardcoded a real
    # password in source. It can now be overridden via POSTGRES_PASSWORD
    # (same default kept for backward compatibility); rotate the credential
    # and move it out of source control.
    _db_password = os.getenv("POSTGRES_PASSWORD", "huang1998")
    DB_URI = os.getenv(
        "DB_URI",
        f"postgresql://postgres:{_db_password}@ai-postgres:5432/langgraph_db?sslmode=disable",
    )
    DOCSTORE_URI = os.getenv("DOCSTORE_URI", DB_URI)

# ========== Indexer-specific settings ==========
# Default location for on-disk index artifacts.
INDEX_STORAGE_PATH = os.getenv("INDEX_STORAGE_PATH", "./index_storage")