diff --git a/.gitea/workflows/deploy.yml b/.gitea/workflows/deploy.yml index 09dfc5e..655834c 100644 --- a/.gitea/workflows/deploy.yml +++ b/.gitea/workflows/deploy.yml @@ -67,10 +67,10 @@ jobs: - name: 健康检查 run: | echo "等待后端服务启动..." - sleep 15 + sleep 30 for i in {1..10}; do - # 修正端口为 8083(与 compose 暴露端口一致) - if curl -f http://172.17.0.1:8083/health > /dev/null 2>&1; then + # 修正端口为 8079 + if curl -f http://172.17.0.1:8079/health > /dev/null 2>&1; then echo "✅ 后端服务正常" exit 0 fi diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index 2d07ab9..0000000 --- a/app/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -AI Agent 应用模块 -""" - -from app.agent import AIAgentService -from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME - -__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"] diff --git a/app/agent/__init__.py b/app/agent/__init__.py deleted file mode 100644 index 055f494..0000000 --- a/app/agent/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Agent 子模块 -""" - -from app.agent.service import AIAgentService - -__all__ = ["AIAgentService"] diff --git a/app/agent/history.py b/app/agent/history.py deleted file mode 100644 index 09f7124..0000000 --- a/app/agent/history.py +++ /dev/null @@ -1,185 +0,0 @@ -""" -历史对话查询模块 -利用 LangGraph 的 checkpointer 获取对话历史和摘要 -""" - -from typing import List, Dict, Any -from app.logger import error # 保持兼容,或者替换为 logger - -class ThreadHistoryService: - """线程历史查询服务""" - - def __init__(self, checkpointer): - self.checkpointer = checkpointer - - async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]: - """ - 获取指定用户的所有线程摘要信息 - - Args: - user_id: 用户 ID - limit: 返回数量限制 - - Returns: - 线程列表,每个包含 thread_id, last_updated, summary, message_count - """ - try: - # 查询 checkpoints 表获取用户的线程列表 - async with self.checkpointer.conn.cursor() as cur: - # 在较新的 LangGraph 版本中,AsyncPostgresSaver 创建的 checkpoints 表 - # 没有 created_at 列,而是使用 checkpoint_id 作为时间排序依据。 - # 我们可以直接按 thread_id 去重,并用 checkpoint_id 排序。 - # 另外,用户的 metadata 存储在 metadata JSONB 列中。 - query = """ - SELECT - thread_id, - MAX(checkpoint_id) as last_updated - FROM checkpoints - WHERE metadata->>'user_id' = %s - GROUP BY thread_id - ORDER BY last_updated DESC - LIMIT %s - """ - await cur.execute(query, (user_id, limit)) - rows = await cur.fetchall() - - threads = [] - for row in rows: - thread_id = row['thread_id'] - - # 获取该线程的状态 - state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}}) - - if state and hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict): - messages = state.checkpoint.get("channel_values", {}).get("messages", []) - - if messages: - summary = self._extract_summary(messages) - message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]]) - - threads.append({ - "thread_id": thread_id, - # checkpoint_id 是一个类似于 uuid 的字符串,其中可能包含时间戳信息,也可以直接作为唯一标识 - "last_updated": row['last_updated'] if row['last_updated'] else "", - "summary": summary, - "message_count": message_count - }) - - return threads - - except Exception as e: - error(f"获取用户线程列表失败 (user_id={user_id}): {e}") - return [] - - async def get_thread_messages(self, thread_id: str) -> List[Dict[str, str]]: - """ - 获取指定线程的完整消息历史 - - Args: - thread_id: 线程 ID - - Returns: - 消息列表,格式 [{"role": "user/assistant", "content": "..."}] - """ - try: - state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}}) - - if state is None: - return [] - - messages = state.checkpoint.get("channel_values", {}).get("messages", []) if hasattr(state, 'checkpoint') and isinstance(state.checkpoint, dict) else [] - - if not messages: - return [] - - # 转换 LangChain 消息对象为字典 - result = [] - for msg in messages: - # 跳过 system 消息 - if hasattr(msg, 'type') and msg.type == "system": - continue - - if hasattr(msg, 'type'): - role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else msg.type - result.append({ - "role": role, - "content": msg.content - }) - elif isinstance(msg, dict): - role = msg.get("role", msg.get("type", "unknown")) - if role in ["human", "user"]: - role = "user" - elif role in ["ai", "assistant"]: - role = "assistant" - result.append({ - "role": role, - "content": msg.get("content", "") - }) - - return result - - except Exception as e: - error(f"获取线程消息历史失败: {e}") - return [] - - async def get_thread_summary(self, thread_id: str) -> Dict[str, Any]: - """ - 获取线程摘要(用于历史列表展示) - - Args: - thread_id: 线程 ID - - Returns: - 包含摘要信息的字典 - """ - try: - state = await self.checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}}) - - if state is None or not state.values: - return {"thread_id": thread_id, "summary": "空对话", "message_count": 0} - - messages = state.values.get("messages", []) - summary = self._extract_summary(messages) - message_count = len([m for m in messages if hasattr(m, 'type') and m.type in ["human", "ai"]]) - - # 获取最后更新时间 - last_updated = "" - if state.metadata and "created_at" in state.metadata: - last_updated = state.metadata["created_at"].isoformat() - - return { - "thread_id": thread_id, - "summary": summary, - "message_count": message_count, - "last_updated": last_updated - } - - except Exception as e: - error(f"获取线程摘要失败: {e}") - return {"thread_id": thread_id, "summary": "加载失败", "message_count": 0} - - def _extract_summary(self, messages: List) -> str: - """ - 从消息列表中提取摘要 - - 策略: - 1. 如果有 summarize 节点生成的 summary,优先使用 - 2. 否则使用第一条用户消息的前 50 字 - """ - # 查找是否有 summary 字段 - for msg in messages: - if hasattr(msg, 'additional_kwargs') and msg.additional_kwargs.get('summary'): - return msg.additional_kwargs['summary'] - elif isinstance(msg, dict) and msg.get('summary'): - return msg['summary'] - - # 使用第一条用户消息作为摘要 - for msg in messages: - if hasattr(msg, 'type') and msg.type == "human": - content = msg.content - return content[:50] + "..." if len(content) > 50 else content - elif isinstance(msg, dict) and msg.get("role") in ["user", "human"]: - content = msg.get("content", "") - return content[:50] + "..." if len(content) > 50 else content - - return "空对话" \ No newline at end of file diff --git a/app/agent/llm_factory.py b/app/agent/llm_factory.py deleted file mode 100644 index 9a1a22a..0000000 --- a/app/agent/llm_factory.py +++ /dev/null @@ -1,56 +0,0 @@ -# app/llm_factory.py -import os -from langchain_community.chat_models import ChatZhipuAI -from langchain_openai import ChatOpenAI -from pydantic import SecretStr - -class LLMFactory: - @staticmethod - def create_zhipu(): - api_key = os.getenv("ZHIPUAI_API_KEY") - if not api_key: - raise ValueError("ZHIPUAI_API_KEY not set") - return ChatZhipuAI( - model="glm-4.7-flash", - api_key=api_key, - temperature=0.1, - max_tokens=4096, - timeout=120.0, - max_retries=3, - streaming=True, - ) - - @staticmethod - def create_deepseek(): - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - raise ValueError("DEEPSEEK_API_KEY not set") - return ChatOpenAI( - base_url="https://api.deepseek.com", - api_key=SecretStr(api_key), - model="deepseek-reasoner", - temperature=0.1, - max_tokens=4096, - timeout=60.0, - max_retries=2, - streaming=True, - ) - - @staticmethod - def create_local(): - base_url = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1") - return ChatOpenAI( - base_url=base_url, - api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), - model="gemma-4-E4B-it", - timeout=60.0, - max_retries=2, - streaming=True, - ) - - # 模型创建器映射 - CREATORS = { - "local": create_local, - "deepseek": create_deepseek, - "zhipu": create_zhipu, - } \ No newline at end of file diff --git a/app/agent/prompts.py b/app/agent/prompts.py deleted file mode 100644 index 8b05050..0000000 --- a/app/agent/prompts.py +++ /dev/null @@ -1,37 +0,0 @@ -# app/prompts.py -from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder - -def create_system_prompt(tools: list = None) -> ChatPromptTemplate: - """ - 创建系统提示模板,可选择动态注入工具描述。 - """ - tools_section = "" - if tools: - tool_descs = [] - for tool in tools: - # 提取工具名称和描述的第一行 - name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool') - desc = (tool.description or "").split('\n')[0] - tool_descs.append(f"- {name}: {desc}") - tools_section = "\n".join(tool_descs) - - system_template = ( - "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n" - "【用户背景信息】\n" - "以下是对当前用户的已知信息和长期记忆,你必须优先采纳并在回答中体现:\n" - "{memory_context}\n" - "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n" - "【可用工具与使用规则】\n" - f"{tools_section}\n" - "工具调用时请直接返回所需参数,无需额外说明。\n\n" - "【回答要求(必须遵守)】\n" - "1. 回答必须简洁、直接。\n" - "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。\n" - "3. 优先利用已知用户信息进行个性化回复。\n" - "4. 若无信息可依,礼貌询问或提供通用帮助。" - ) - - return ChatPromptTemplate.from_messages([ - ("system", system_template), - MessagesPlaceholder(variable_name="messages") - ]) \ No newline at end of file diff --git a/app/agent/rag_initializer.py b/app/agent/rag_initializer.py deleted file mode 100644 index f391b8f..0000000 --- a/app/agent/rag_initializer.py +++ /dev/null @@ -1,23 +0,0 @@ -# app/rag_initializer.py -from app.rag.tools import create_rag_tool_sync -from rag_core import create_parent_retriever -from app.logger import info, warning - -async def init_rag_tool(local_llm_creator): - """初始化 RAG 工具,失败返回 None""" - try: - info("🔄 正在初始化 RAG 检索系统...") - retriever = create_parent_retriever( - collection_name="rag_documents", - search_k=5, - ) - rewrite_llm = local_llm_creator() - rag_tool = create_rag_tool_sync( - retriever, rewrite_llm, - num_queries=3, rerank_top_n=5 - ) - info("✅ RAG 检索工具初始化成功") - return rag_tool - except Exception as e: - warning(f"⚠️ RAG 检索工具初始化失败: {e}") - return None \ No newline at end of file diff --git a/app/agent/service.py b/app/agent/service.py deleted file mode 100644 index 1d6cc14..0000000 --- a/app/agent/service.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -AI Agent 服务类 - 支持多模型动态切换 -接收外部传入的 checkpointer,不负责管理连接生命周期 -""" - -import json -from dotenv import load_dotenv - -# 本地模块 -from app.graph.graph_builder import GraphBuilder, GraphContext -from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME -from app.agent.llm_factory import LLMFactory -from app.agent.rag_initializer import init_rag_tool -from app.logger import info, warning -load_dotenv() - -class AIAgentService: - def __init__(self, checkpointer): - self.checkpointer = checkpointer - self.graphs = {} - self.tools = AVAILABLE_TOOLS.copy() - self.tools_by_name = TOOLS_BY_NAME.copy() - - async def initialize(self): - # 1. 初始化 RAG 工具(如果需要) - rag_tool = await init_rag_tool(LLMFactory.create_local) - if rag_tool: - self.tools.append(rag_tool) - self.tools_by_name[rag_tool.name] = rag_tool - - # 2. 构建各模型的 Graph - for name, creator in LLMFactory.CREATORS.items(): - try: - info(f"🔄 初始化模型 '{name}'...") - llm = creator() - builder = GraphBuilder(llm, self.tools, self.tools_by_name).build() - graph = builder.compile(checkpointer=self.checkpointer) - self.graphs[name] = graph - info(f"✅ 模型 '{name}' 初始化成功") - except Exception as e: - warning(f"⚠️ 模型 '{name}' 初始化失败: {e}") - - if not self.graphs: - raise RuntimeError("没有可用的模型") - return self - - async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict: - """处理用户消息,返回包含回复、token统计和耗时的字典""" - if model not in self.graphs: - # 回退到第一个可用模型 - available = list(self.graphs.keys()) - if not available: - raise RuntimeError("没有可用的模型") - model = available[0] - warning(f"模型 '{model}' 不可用,已回退到 '{model}'") - - graph = self.graphs[model] - config = { - "configurable": {"thread_id": thread_id}, - "metadata": {"user_id": user_id} - } - input_state = {"messages": [{"role": "user", "content": message}]} - context = GraphContext(user_id=user_id) - - result = await graph.ainvoke(input_state, config=config, context=context) - - reply = result["messages"][-1].content - token_usage = result.get("last_token_usage", {}) - elapsed_time = result.get("last_elapsed_time", 0.0) - - return { - "reply": reply, - "token_usage": token_usage, - "elapsed_time": elapsed_time - } - - def _serialize_value(self, value): - """递归将 LangChain 对象转换为可 JSON 序列化的格式""" - if hasattr(value, 'content'): - msg_type = getattr(value, 'type', 'message') - return { - "role": msg_type, - "content": getattr(value, 'content', ''), - "additional_kwargs": getattr(value, 'additional_kwargs', {}), - "tool_calls": getattr(value, 'tool_calls', []) - } - elif isinstance(value, dict): - return {k: self._serialize_value(v) for k, v in value.items()} - elif isinstance(value, (list, tuple)): - return [self._serialize_value(item) for item in value] - else: - try: - json.dumps(value) - return value - except (TypeError, ValueError): - return str(value) - - async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"): - """流式处理消息,返回异步生成器""" - graph = self.graphs.get(model_name) - if not graph: - raise ValueError(f"模型 '{model_name}' 未找到或未初始化") - - config = { - "configurable": {"thread_id": thread_id}, - "metadata": {"user_id": user_id} - } - input_state = {"messages": [{"role": "user", "content": message}]} - context = GraphContext(user_id=user_id) - - async for chunk in graph.astream( - input_state, - config=config, - context=context, - stream_mode=["messages", "updates", "custom"], - version="v2", - subgraphs=True - ): - chunk_type = chunk["type"] - processed_event = {} - - if chunk_type == "messages": - message_chunk, metadata = chunk["data"] - node_name = metadata.get("langgraph_node", "unknown") - token_content = getattr(message_chunk, 'content', str(message_chunk)) - reasoning_token = "" - if hasattr(message_chunk, 'additional_kwargs'): - reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") - - processed_event = { - "type": "llm_token", - "node": node_name, - "token": token_content, - "reasoning_token": reasoning_token, - "metadata": metadata - } - elif chunk_type == "updates": - updates_data = chunk["data"] - serialized_data = self._serialize_value(updates_data) - processed_event = { - "type": "state_update", - "data": serialized_data - } - if "messages" in serialized_data: - processed_event["messages"] = serialized_data["messages"] - elif chunk_type == "custom": - serialized_data = self._serialize_value(chunk["data"]) - processed_event = { - "type": "custom", - "data": serialized_data - } - else: - continue - - if processed_event: - yield processed_event \ No newline at end of file diff --git a/app/backend.py b/app/backend.py deleted file mode 100644 index c5b857a..0000000 --- a/app/backend.py +++ /dev/null @@ -1,222 +0,0 @@ -""" -FastAPI 后端 - 支持动态模型切换,使用 PostgreSQL 持久化记忆 -采用依赖注入模式,优雅管理资源生命周期 -""" - -import os -import uuid -import json -from contextlib import asynccontextmanager - -from dotenv import load_dotenv -from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Depends, Request, Query -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse -from pydantic import BaseModel -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from app.agent import AIAgentService -from app.agent.history import ThreadHistoryService -from app.logger import info, error - -# 加载 .env 文件 -load_dotenv() - -# PostgreSQL 连接字符串配置 -# 优先级:环境变量 DB_URI > Docker 内部服务名 > 本地开发地址 -DB_URI = os.getenv( - "DB_URI", - "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable" -) - -@asynccontextmanager -async def lifespan(app: FastAPI): - """应用生命周期管理:创建并注入全局服务""" - # 1. 创建数据库连接池并初始化表(仅 checkpointer) - async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: - await checkpointer.setup() - - # 2. 构建 AI Agent 服务 - agent_service = AIAgentService(checkpointer) - await agent_service.initialize() - - # 3. 创建历史查询服务 - history_service = ThreadHistoryService(checkpointer) - - # 4. 将服务实例存入 app.state - app.state.agent_service = agent_service - app.state.history_service = history_service - - # 应用运行中... - yield - - # 5. 关闭时自动清理数据库连接(async with 负责) - info("🛑 应用关闭,数据库连接池已释放") - -app = FastAPI(lifespan=lifespan) - -# CORS 中间件(允许前端跨域) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# ========== 健康检查端点 ========== -@app.get("/health") -async def health_check(): - """健康检查端点,用于 Docker 和 CI/CD 监控""" - return {"status": "ok", "service": "ai-agent-backend"} - -# ========== Pydantic 模型 ========== -class ChatRequest(BaseModel): - message: str - thread_id: str | None = None - model: str = "zhipu" - user_id: str = "default_user" - -class ChatResponse(BaseModel): - reply: str - thread_id: str - model_used: str - input_tokens: int = 0 - output_tokens: int = 0 - total_tokens: int = 0 - elapsed_time: float = 0.0 - -# ========== 依赖注入函数 ========== -def get_agent_service(request: Request) -> AIAgentService: - """从 app.state 中获取全局 AIAgentService 实例""" - return request.app.state.agent_service - -def get_history_service(request: Request) -> ThreadHistoryService: - """从 app.state 中获取全局 ThreadHistoryService 实例""" - return request.app.state.history_service - -# ========== HTTP 端点 ========== -@app.post("/chat", response_model=ChatResponse) -async def chat_endpoint( - request: ChatRequest, - agent_service: AIAgentService = Depends(get_agent_service) -): - """同步对话接口,支持模型选择""" - if not request.message: - raise HTTPException(status_code=400, detail="message required") - - thread_id = request.thread_id or str(uuid.uuid4()) - result = await agent_service.process_message( - request.message, thread_id, request.model, request.user_id - ) - - # 提取 token 统计信息 - token_usage = result.get("token_usage", {}) - input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0)) - output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0)) - elapsed_time = result.get("elapsed_time", 0.0) - - actual_model = request.model if request.model in agent_service.graphs else next(iter(agent_service.graphs.keys())) - - return ChatResponse( - reply=result["reply"], - thread_id=thread_id, - model_used=actual_model, - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - elapsed_time=elapsed_time - ) - -# ========== 历史查询接口 ========== -@app.get("/threads") -async def list_threads( - user_id: str = Query("default_user", description="用户 ID"), - limit: int = Query(50, ge=1, le=200, description="返回数量限制"), - history_service: ThreadHistoryService = Depends(get_history_service) -): - """获取当前用户的对话历史列表""" - threads = await history_service.get_user_threads(user_id, limit) - return {"threads": threads} - -@app.get("/thread/{thread_id}/messages") -async def get_thread_messages( - thread_id: str, - user_id: str = Query("default_user", description="用户 ID"), - history_service: ThreadHistoryService = Depends(get_history_service) -): - """获取指定线程的完整消息历史""" - messages = await history_service.get_thread_messages(thread_id) - return {"messages": messages} - -@app.get("/thread/{thread_id}/summary") -async def get_thread_summary( - thread_id: str, - user_id: str = Query("default_user", description="用户 ID"), - history_service: ThreadHistoryService = Depends(get_history_service) -): - """获取指定线程的摘要信息""" - summary = await history_service.get_thread_summary(thread_id) - return summary - -# ========== 流式对话接口 ========== -@app.post("/chat/stream") -async def chat_stream_endpoint( - request: ChatRequest, - agent_service: AIAgentService = Depends(get_agent_service) -): - """流式对话接口(SSE)""" - if not request.message: - raise HTTPException(status_code=400, detail="message required") - - thread_id = request.thread_id or str(uuid.uuid4()) - - async def event_generator(): - try: - async for chunk in agent_service.process_message_stream( - request.message, thread_id, request.model, request.user_id - ): - yield f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - except Exception as e: - error(f"流式响应异常: {e}") - yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", # 禁用 Nginx 缓冲 - } - ) - -# ========== WebSocket 端点(可选) ========== -@app.websocket("/ws") -async def websocket_endpoint( - websocket: WebSocket, - agent_service: AIAgentService = Depends(get_agent_service) -): - await websocket.accept() - try: - while True: - data = await websocket.receive_json() - message = data.get("message") - thread_id = data.get("thread_id", str(uuid.uuid4())) - model = data.get("model", "zhipu") - user_id = data.get("user_id", "default_user") - if not message: - await websocket.send_json({"error": "missing message"}) - continue - reply = await agent_service.process_message(message, thread_id, model, user_id) - actual_model = model if model in agent_service.graphs else next(iter(agent_service.graphs.keys())) - await websocket.send_json({"reply": reply, "thread_id": thread_id, "model_used": actual_model}) - except WebSocketDisconnect: - pass - -if __name__ == "__main__": - import uvicorn - # 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突) - port = int(os.getenv("BACKEND_PORT", "8079")) - uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/app/config.py b/app/config.py deleted file mode 100644 index 77b2c7b..0000000 --- a/app/config.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -环境变量集中管理模块 -所有配置项统一定义,避免散落在各个文件中 -""" - -import os - - -# ========== Graph 执行追踪配置 ========== -# 是否启用 Graph 流转追踪(通过环境变量控制) -ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true" - -# ========== 记忆提取配置 ========== -# 记忆提取间隔:每 N 轮对话生成一次摘要 -MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10")) - -# ========== Mem0 记忆层配置 ========== -# Qdrant 向量数据库地址 -QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") -QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories") -QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key") - -# ========== llm 配置 ========== -# LLM 模型配置 -VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1") -LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key") - -# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化) -LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1") -LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key") \ No newline at end of file diff --git a/app/graph/__init__.py b/app/graph/__init__.py deleted file mode 100644 index 3b9caf2..0000000 --- a/app/graph/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Graph 子模块 -""" - -from app.graph.graph_builder import GraphBuilder -from app.graph.state import MessagesState, GraphContext - -__all__ = ["GraphBuilder", "MessagesState", "GraphContext"] diff --git a/app/graph/graph_builder.py b/app/graph/graph_builder.py deleted file mode 100644 index c7f5d99..0000000 --- a/app/graph/graph_builder.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -LangGraph 状态图构建模块 - 精简版,仅负责组装图 -所有节点逻辑已拆分到独立模块 -""" - -from langchain_core.language_models import BaseLLM -from langgraph.graph import StateGraph, START, END -from app.graph.state import MessagesState, GraphContext -from app.nodes import ( - should_continue, - create_llm_call_node, - create_tool_call_node, - create_retrieve_memory_node, - create_summarize_node, - finalize_node, -) -from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client -from app.memory import Mem0Client - - -class GraphBuilder: - """LangGraph 状态图构建器 - 仅负责组装图""" - - def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict): - """ - 初始化构建器 - - Args: - llm: 大语言模型实例 - tools: 工具列表 - tools_by_name: 名称到工具函数的映射 - """ - self.llm = llm - self.tools = tools - self.tools_by_name = tools_by_name - - # ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化) - self.mem0_client = Mem0Client(llm) - - def build(self) -> StateGraph: - """ - 构建未编译的状态图 - - Returns: - StateGraph 实例 - """ - # 注入全局客户端 - set_mem0_client(self.mem0_client) - - builder = StateGraph(MessagesState, context_schema=GraphContext) - - # ⭐ 通过工厂函数创建节点(依赖注入) - retrieve_memory_node = create_retrieve_memory_node(self.mem0_client) - llm_call_node = create_llm_call_node(self.llm, self.tools) - tool_call_node = create_tool_call_node(self.tools_by_name) - summarize_node = create_summarize_node(self.mem0_client) - - # 添加节点 - builder.add_node("retrieve_memory", retrieve_memory_node) - builder.add_node("memory_trigger", memory_trigger_node) - builder.add_node("llm_call", llm_call_node) - builder.add_node("tool_node", tool_call_node) - builder.add_node("summarize", summarize_node) - builder.add_node("finalize", finalize_node) - - # 添加边 - builder.add_edge(START, "retrieve_memory") - builder.add_edge("retrieve_memory", "memory_trigger") - builder.add_edge("memory_trigger", "llm_call") - builder.add_conditional_edges( - "llm_call", - should_continue, - { - "tool_node": "tool_node", - "summarize": "summarize", - "finalize": "finalize" - } - ) - builder.add_edge("tool_node", "llm_call") - builder.add_edge("summarize", "finalize") - builder.add_edge("finalize", END) - - return builder \ No newline at end of file diff --git a/app/graph/graph_tools.py b/app/graph/graph_tools.py deleted file mode 100644 index 1cc1e17..0000000 --- a/app/graph/graph_tools.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -工具定义模块 - 纯函数工具,无依赖 AIAgent 类 -""" - -# 标准库 -from pathlib import Path - -# 第三方库 -import pandas as pd -import pypdf -import requests -from bs4 import BeautifulSoup -from langchain_core.tools import tool - -def _file_allow_check(filename: str) -> Path: - """检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。""" - allowed_dir = Path("./user_docs").resolve() - allowed_dir.mkdir(exist_ok=True) - - file_path = (allowed_dir / filename).resolve() - if not str(file_path).startswith(str(allowed_dir)): - raise ValueError("错误:非法文件路径。") - - if not file_path.exists(): - raise FileNotFoundError(f"错误:文件 '{filename}' 不存在。") - - return file_path - -@tool -def get_current_temperature(location: str) -> str: - """获取指定地点的当前温度。""" - return f'当前{location}的温度为25℃' - -@tool -def read_local_file(filename: str) -> str: - """读取用户指定名称的本地文本文件内容并返回摘要。""" - try: - file_path = _file_allow_check(filename) - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - return f"文件 '{filename}' 的内容开头:\n{content[:1000]}..." - except Exception as e: - return f"读取文件时出错:{str(e)}" - -@tool -def read_pdf_summary(filename: str) -> str: - """读取PDF文件并返回内容文本摘要。""" - try: - file_path = _file_allow_check(filename) - text = "" - with open(file_path, 'rb') as f: - reader = pypdf.PdfReader(f) - for page in reader.pages[:3]: - text += page.extract_text() - return f"PDF文件 '{filename}' 的前几页内容:\n{text[:2000]}..." - except Exception as e: - return f"读取PDF出错:{e}" - -@tool -def read_excel_as_markdown(filename: str) -> str: - """读取Excel文件,并将其主要数据转换为Markdown表格格式。""" - try: - file_path = _file_allow_check(filename) - df = pd.read_excel(file_path) - markdown_table = df.head(10).to_markdown(index=False) - return f"Excel文件 '{filename}' 的数据预览(前10行):\n{markdown_table}" - except Exception as e: - return f"读取Excel出错:{e}" - -@tool -def fetch_webpage_content(url: str) -> str: - """抓取给定URL的网页正文内容,并返回清晰的纯文本。""" - try: - response = requests.get(url, timeout=10) - response.raise_for_status() - soup = BeautifulSoup(response.text, 'html.parser') - for script in soup(["script", "style"]): - script.decompose() - text = soup.get_text() - lines = (line.strip() for line in text.splitlines()) - chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) - text = '\n'.join(chunk for chunk in chunks if chunk) - return f"成功抓取网页 {url},正文内容开头:\n{text[:1500]}..." - except Exception as e: - return f"抓取网页时出错:{str(e)}" - -# 工具列表和映射(全局常量) -AVAILABLE_TOOLS = [ - get_current_temperature, - read_local_file, - fetch_webpage_content, - read_pdf_summary, - read_excel_as_markdown -] -TOOLS_BY_NAME = {tool.name: tool for tool in AVAILABLE_TOOLS} diff --git a/app/graph/retrieve_memory.py b/app/graph/retrieve_memory.py deleted file mode 100644 index ef419d3..0000000 --- a/app/graph/retrieve_memory.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -记忆检索节点模块 -负责从 Mem0 检索相关长期记忆 -""" - -from typing import Any, Dict - -# 本地模块 -from app.graph.state import MessagesState -from app.memory.mem0_client import Mem0Client -from app.utils.logging import log_state_change -from app.logger import debug - -def create_retrieve_memory_node(mem0_client: Mem0Client): - """ - 工厂函数:创建记忆检索节点 - - Args: - mem0_client: Mem0 客户端实例 - - Returns: - 异步节点函数 - """ - - from langchain_core.runnables.config import RunnableConfig - - async def retrieve_memory(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: - """ - 记忆检索节点 - 使用 Mem0 - - Args: - state: 当前对话状态 - config: 运行时配置 - - Returns: - 包含 memory_context 的状态更新 - """ - log_state_change("retrieve_memory", state, "进入") - - # 从 metadata 中获取 user_id - user_id = config.get("metadata", {}).get("user_id", "default_user") - - # 兼容 dict 和对象两种消息格式 - last_msg = state["messages"][-1] - if isinstance(last_msg, dict): - query = str(last_msg.get("content", "")) - else: - query = str(last_msg.content) - memory_text_parts = [] - - # 确保 Mem0 已初始化(懒加载) - if not mem0_client._initialized: - await mem0_client.initialize() - - if mem0_client.mem0: - try: - # 异步调用 Mem0 语义检索 - facts = await mem0_client.search_memories(query, user_id=user_id, limit=5) - - if facts: - memory_text_parts.append(f"【相关长期记忆】\n" + "\n".join(f"- {f}" for f in facts)) - else: - debug("🔍 [记忆检索] 未找到相关记忆") - except Exception as e: - from app.logger import warning - warning(f"⚠️ Mem0 检索失败: {e}") - else: - from app.logger import warning - warning("⚠️ Mem0 未初始化,跳过记忆检索") - - memory_context = "\n\n".join(memory_text_parts) if memory_text_parts else "暂无用户信息" - result = {"memory_context": memory_context} - log_state_change("retrieve_memory", {**state, **result}, "离开") - return result - - return retrieve_memory diff --git a/app/graph/state.py b/app/graph/state.py deleted file mode 100644 index 2fd214e..0000000 --- a/app/graph/state.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -LangGraph 状态定义模块 -包含 MessagesState 和 GraphContext -""" - -import operator -from typing import Annotated -from typing_extensions import TypedDict -from dataclasses import dataclass -from langchain_core.messages import AnyMessage - -class MessagesState(TypedDict): - """对话状态类型定义""" - messages: Annotated[list[AnyMessage], operator.add] - llm_calls: int - memory_context: str - last_token_usage: dict # 本次调用的 token 使用详情 - last_elapsed_time: float # 本次调用耗时(秒) - turns_since_last_summary: int # 距离上次生成摘要的轮数 - -@dataclass -class GraphContext: - """图执行上下文""" - user_id: str - # 可扩展更多上下文信息 diff --git a/app/logger.py b/app/logger.py deleted file mode 100644 index a4cf4c7..0000000 --- a/app/logger.py +++ /dev/null @@ -1,55 +0,0 @@ -""" -统一的日志模块 - 基于环境变量控制日志级别 -类似 C# 的条件编译效果,开发时打印详细调试信息,生产环境只输出关键信息 -""" - -import os -import logging -from typing import Any -from dotenv import load_dotenv - -# 先加载环境变量 -load_dotenv() - -# 从环境变量读取日志级别,默认 INFO -LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() - -# 根据环境变量控制是否显示详细调试信息 -DEBUG_MODE = os.getenv("DEBUG", "false").lower() == "true" - -# 创建统一的日志器 -logger = logging.getLogger("ai_agent") -logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO)) - -# 避免重复添加 handler -if not logger.handlers: - handler = logging.StreamHandler() - # 重要:handler 也需要设置级别,否则可能继承根 logger 的级别 - handler.setLevel(getattr(logging, LOG_LEVEL, logging.INFO)) - formatter = logging.Formatter( - fmt="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", - datefmt="%Y-%m-%d %H:%M:%S" - ) - handler.setFormatter(formatter) - logger.addHandler(handler) - - -def debug(msg: Any, *args, **kwargs): - """调试日志,仅在 DEBUG 环境变量为 true 时打印""" - if DEBUG_MODE: - logger.debug(msg, *args, **kwargs) - - -def info(msg: Any, *args, **kwargs): - """信息日志""" - logger.info(msg, *args, **kwargs) - - -def warning(msg: Any, *args, **kwargs): - """警告日志""" - logger.warning(msg, *args, **kwargs) - - -def error(msg: Any, *args, **kwargs): - """错误日志""" - logger.error(msg, *args, **kwargs) diff --git a/app/memory/__init__.py b/app/memory/__init__.py deleted file mode 100644 index 29117a6..0000000 --- a/app/memory/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Mem0 记忆层模块 -""" - -from app.memory.mem0_client import Mem0Client - -__all__ = ["Mem0Client"] diff --git a/app/memory/mem0_client.py b/app/memory/mem0_client.py deleted file mode 100644 index b2cf89e..0000000 --- a/app/memory/mem0_client.py +++ /dev/null @@ -1,146 +0,0 @@ -from app.config import LLM_API_KEY -from app.config import VLLM_BASE_URL -import time -""" -Mem0 记忆层客户端封装模块 -负责 Mem0 的初始化、检索和存储 -""" - -import asyncio -from typing import Optional, List, Dict -from mem0 import AsyncMemory - -from app.config import ( - QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY, - VLLM_BASE_URL, LLM_API_KEY, - LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY -) -from app.logger import info, warning, error - -class Mem0Client: - """Mem0 异步客户端封装类""" - - def __init__(self, llm_instance): - """ - 初始化 Mem0 客户端 - - Args: - llm_instance: LangChain LLM 实例(用于事实提取) - """ - self.llm = llm_instance - self.mem0: Optional[AsyncMemory] = None - self._initialized = False - - async def initialize(self): - """异步初始化 Mem0 客户端,并进行实际连接测试""" - if self._initialized: - return - - try: - # Mem0 配置 - config = { - "vector_store": { - "provider": "qdrant", - "config": { - "url": QDRANT_URL, # 直接使用完整 URL - "api_key": QDRANT_API_KEY, - "collection_name": QDRANT_COLLECTION_NAME, - "embedding_model_dims": 1024, - } - }, - "llm": { - "provider": "openai", - "config": { - "model": "LLM_MODEL", - "api_key": LLM_API_KEY, - "openai_base_url": VLLM_BASE_URL, - "temperature": 0.1, - "max_tokens": 2000, - } - }, - "embedder": { - "provider": "openai", - "config": { - "model": "Qwen3-Embedding-0.6B-Q8_0", - "api_key": LLAMACPP_API_KEY, - "openai_base_url": LLAMACPP_EMBEDDING_URL, - }, - }, - "version": "v1.1" - } - - self.mem0 = AsyncMemory.from_config(config) - info("✅ Mem0 配置加载成功,开始连接测试...") - - # 实际连接测试:调用一次 search 确保 Qdrant 和 Embedding 都可达 - await asyncio.wait_for( - self.mem0.search("ping", user_id="test", limit=1), - timeout=60.0 - ) - info("✅ Mem0 实际连接测试成功,初始化完成") - self._initialized = True - - except asyncio.TimeoutError: - error("❌ Mem0 连接测试超时 (10s),请检查 Qdrant 或 Embedding 服务响应") - self.mem0 = None - self._initialized = False - except Exception as e: - error(f"❌ Mem0 初始化或连接测试失败: {e}") - import traceback - error(f"详细错误信息:\n{traceback.format_exc()}") - self.mem0 = None - self._initialized = False - - async def search_memories(self, query: str, user_id: str, limit: int = 5) -> List[str]: - """ - 检索相关记忆 - - Args: - query: 查询文本 - user_id: 用户 ID - limit: 返回结果数量限制 - - Returns: - List[str]: 记忆事实列表 - """ - if not self.mem0: - warning("⚠️ Mem0 未初始化,跳过记忆检索") - return [] - - try: - memories = await asyncio.wait_for( - self.mem0.search(query, user_id=user_id, limit=limit), - timeout=30.0 - ) - - if memories and "results" in memories: - facts = [m["memory"] for m in memories["results"] if m.get("memory")] - if facts: - info(f"🔍 [记忆检索] Mem0 返回 {len(facts)} 条记忆") - return facts - - info("🔍 [记忆检索] 未找到相关记忆") - return [] - - except asyncio.TimeoutError: - warning("⚠️ Mem0 检索超时 (30s),跳过本次记忆检索") - return [] - except Exception as e: - warning(f"⚠️ Mem0 检索失败: {e}") - return [] - - async def add_memories(self, messages, user_id): - if not self.mem0: - return False - try: - start = time.time() - info(f"📝 开始 Mem0 add,消息数: {len(messages)}") - await asyncio.wait_for( - self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}), - timeout=60.0 - ) - info(f"✅ Mem0 add 完成,耗时: {time.time() - start:.2f}s") - return True - except asyncio.TimeoutError: - error(f"❌ Mem0 记忆添加超时 (60s),已等待 {time.time() - start:.2f}s") - return False \ No newline at end of file diff --git a/app/nodes/__init__.py b/app/nodes/__init__.py deleted file mode 100644 index d9eb644..0000000 --- a/app/nodes/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -节点模块 - 导出所有 LangGraph 节点函数 -""" - -from app.nodes.router import should_continue -from app.nodes.llm_call import create_llm_call_node -from app.nodes.tool_call import create_tool_call_node -from app.graph.retrieve_memory import create_retrieve_memory_node -from app.nodes.summarize import create_summarize_node -from app.nodes.finalize import finalize_node - -__all__ = [ - "should_continue", - "create_llm_call_node", - "create_tool_call_node", - "create_retrieve_memory_node", - "create_summarize_node", - "finalize_node", -] diff --git a/app/nodes/finalize.py b/app/nodes/finalize.py deleted file mode 100644 index 87bd746..0000000 --- a/app/nodes/finalize.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -完成事件节点模块 -负责发送完成事件,包含token使用情况和耗时信息 -""" - -from typing import Any, Dict -from langgraph.config import get_stream_writer - -# 本地模块 -from app.graph.state import MessagesState -from app.utils.logging import log_state_change -from app.logger import info, error - -from langchain_core.runnables.config import RunnableConfig - -async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: - """ - 完成事件节点 - 发送完成事件,包含token使用情况和耗时信息 - - Args: - state: 当前对话状态 - config: 运行时配置 - - Returns: - 空字典(完成节点,无状态更新) - """ - log_state_change("finalize", state, "进入") - - try: - # 获取流式写入器并发送完成事件 - writer = get_stream_writer() - writer({ - "type": "custom", - "data": { - "type": "done", - "token_usage": state.get("last_token_usage", {}), - "elapsed_time": state.get("last_elapsed_time", 0.0) - } - }) - info("🏁 [完成事件] 已发送完成事件,包含token使用情况和耗时信息") - except Exception as e: - error(f"❌ [完成事件] 发送完成事件时发生异常: {e}") - - log_state_change("finalize", state, "离开") - return {} \ No newline at end of file diff --git a/app/nodes/llm_call.py b/app/nodes/llm_call.py deleted file mode 100644 index f61cd51..0000000 --- a/app/nodes/llm_call.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -LLM 调用节点模块 -负责调用大语言模型并处理响应 -""" - -import time -from typing import Any, Dict -from langchain_core.language_models import BaseLLM -from langchain_core.messages import AIMessage - -# 本地模块 -from app.graph.state import MessagesState -from app.agent.prompts import create_system_prompt -from app.utils.logging import log_state_change -from app.logger import debug, info, error - -def create_llm_call_node(llm: BaseLLM, tools: list): - """ - 工厂函数:创建 LLM 调用节点 - - Args: - llm: LangChain LLM 实例 - tools: 工具列表 - - Returns: - 异步节点函数 - """ - # 构建调用链 - prompt = create_system_prompt(tools) - llm_with_tools = llm.bind_tools(tools) - - # 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历 - chain = prompt | llm_with_tools - - from langchain_core.runnables.config import RunnableConfig - - async def call_llm(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: - """ - LLM 调用节点(异步方法) - - Args: - state: 当前对话状态 - config: LangChain/LangGraph 自动注入的配置,包含 callbacks 等信息 - - Returns: - 更新后的状态字典 - """ - log_state_change("llm_call", state, "进入") - - memory_context = state.get("memory_context", "暂无用户信息") - start_time = time.time() - - try: - # 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。 - # LangGraph 会自动监听这期间产生的所有 token。 - chunks = [] - async for chunk in chain.astream( - { - "messages": state["messages"], - "memory_context": memory_context - }, - config=config - ): - chunks.append(chunk) - - # 将所有 chunk 合并成最终的 AIMessage - if chunks: - response = chunks[0] - for chunk in chunks[1:]: - response = response + chunk - else: - response = AIMessage(content="") - - elapsed_time = time.time() - start_time - - # 提取 token 用量(兼容不同 LLM 提供商的元数据格式) - token_usage = {} - input_tokens = 0 - output_tokens = 0 - - # 尝试从 response_metadata 中提取 - if hasattr(response, 'response_metadata') and response.response_metadata: - meta = response.response_metadata - if 'token_usage' in meta: - token_usage = meta['token_usage'] - elif 'usage' in meta: - token_usage = meta['usage'] - - # 尝试从 additional_kwargs 中提取 - if not token_usage and hasattr(response, 'additional_kwargs'): - add_kwargs = response.additional_kwargs - if 'llm_output' in add_kwargs and 'token_usage' in add_kwargs['llm_output']: - token_usage = add_kwargs['llm_output']['token_usage'] - - # 提取具体的 token 数值 - if token_usage: - input_tokens = token_usage.get('prompt_tokens', token_usage.get('input_tokens', 0)) - output_tokens = token_usage.get('completion_tokens', token_usage.get('output_tokens', 0)) - - # 打印 LLM 的完整输出 - debug("\n" + "="*80) - debug("📥 [LLM输出] 大模型返回的完整响应:") - debug(f" 消息类型: {response.type.upper()}") - debug(f" 内容长度: {len(str(response.content))} 字符") - debug("-"*80) - debug(f"{response.content}") - - # 打印响应统计信息 - info(f"⏱️ [LLM统计] 调用耗时: {elapsed_time:.2f}秒") - info(f"📊 [LLM统计] Token用量: 输入={input_tokens}, 输出={output_tokens}, 总计={input_tokens + output_tokens}") - if token_usage: - debug(f"📋 [LLM统计] 详细用量: {token_usage}") - debug("="*80 + "\n") - - result = { - "messages": [response], - "llm_calls": state.get('llm_calls', 0) + 1, - "last_token_usage": token_usage, - "last_elapsed_time": elapsed_time, - "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 递增计数器 - } - - log_state_change("llm_call", {**state, **result}, "离开") - return result - - except Exception as e: - elapsed_time = time.time() - start_time - error(f"\n❌ [LLM错误] 调用失败 (耗时: {elapsed_time:.2f}秒)") - error(f" 错误类型: {type(e).__name__}") - error(f" 错误信息: {str(e)}") - import traceback - traceback.print_exc() - debug("="*80 + "\n") - - # 返回一个友好的错误消息 - error_response = AIMessage( - content="抱歉,模型暂时无法响应,可能是网络超时或服务繁忙,请稍后再试。" - ) - error_result = { - "messages": [error_response], - "llm_calls": state.get('llm_calls', 0), - "last_token_usage": {}, - "last_elapsed_time": elapsed_time, - "turns_since_last_summary": state.get('turns_since_last_summary', 0) + 1 # 即使出错也递增计数器 - } - - log_state_change("llm_call", state, "离开(异常)") - return error_result - - return call_llm diff --git a/app/nodes/memory_trigger.py b/app/nodes/memory_trigger.py deleted file mode 100644 index 6f02879..0000000 --- a/app/nodes/memory_trigger.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Any, Dict -from langchain_core.runnables.config import RunnableConfig -from app.graph.state import MessagesState -from app.memory.mem0_client import Mem0Client -from app.logger import info - -# 全局变量,在 GraphBuilder 中注入 -_mem0_client: Mem0Client = None - -def set_mem0_client(client: Mem0Client): - global _mem0_client - _mem0_client = client - -async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: - """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" - if _mem0_client is None: - return {} - - messages = state.get("messages", []) - if not messages: - return {} - - last_msg = messages[-1] - content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg) - - # 触发词(可自行扩展) - trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"] - if any(word in content for word in trigger_words): - user_id = config.get("metadata", {}).get("user_id", "default_user") - # 确保 Mem0 已初始化 - if not _mem0_client._initialized: - await _mem0_client.initialize() - # 将用户消息作为事实来源提交给 Mem0 - info(f"📌 检测到记忆指令,已主动触发 Mem0 存储") - mem0_messages = [{"role": "user", "content": content}] - await _mem0_client.add_memories(mem0_messages, user_id=user_id) - - return {} # 不修改状态 \ No newline at end of file diff --git a/app/nodes/router.py b/app/nodes/router.py deleted file mode 100644 index cabc275..0000000 --- a/app/nodes/router.py +++ /dev/null @@ -1,48 +0,0 @@ -""" -路由决策节点 -根据当前状态决定下一步走向 -""" - -from typing import Literal -from langchain_core.messages import AIMessage - -# 本地模块 -from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL -from app.graph.state import MessagesState -from app.logger import info - - -def should_continue(state: MessagesState) -> Literal['tool_node', 'summarize', 'finalize']: - """ - 决定下一步:工具调用、生成摘要还是结束 - - Args: - state: 当前对话状态 - - Returns: - 下一个节点名称 - """ - last_message = state["messages"][-1] - - # 1. 如果需要调用工具,优先进入工具节点 - if isinstance(last_message, AIMessage) and last_message.tool_calls: - if ENABLE_GRAPH_TRACE: - info(f"🔀 [路由决策] 检测到 {len(last_message.tool_calls)} 个工具调用 → 转向 'tool_node'") - return 'tool_node' - - # 2. 如果是 AI 的最终回复,判断是否达到摘要生成阈值 - if isinstance(last_message, AIMessage): - turns = state.get("turns_since_last_summary", 0) - if turns >= MEMORY_SUMMARIZE_INTERVAL: - if ENABLE_GRAPH_TRACE: - info(f"🔀 [路由决策] 收到 AI 最终回复,已达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 转向 'summarize'") - return 'summarize' - else: - if ENABLE_GRAPH_TRACE: - info(f"🔀 [路由决策] 收到 AI 最终回复,未达摘要阈值({turns}/{MEMORY_SUMMARIZE_INTERVAL}) → 结束流程") - return 'finalize' - - # 3. 其他情况(如只有用户消息)直接结束 - if ENABLE_GRAPH_TRACE: - info(f"🔀 [路由决策] 非 AI 消息(如纯用户消息) → 结束流程") - return 'finalize' diff --git a/app/nodes/summarize.py b/app/nodes/summarize.py deleted file mode 100644 index 5c3dd6c..0000000 --- a/app/nodes/summarize.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -记忆存储节点模块 -负责将对话历史提交给 Mem0 进行事实提取和存储 -""" - -from typing import Any, Dict - -# 本地模块 -from app.graph.state import MessagesState -from app.memory.mem0_client import Mem0Client -from app.utils.logging import log_state_change -from app.logger import debug, info, error, warning - -def create_summarize_node(mem0_client: Mem0Client): - """ - 工厂函数:创建记忆存储节点 - - Args: - mem0_client: Mem0 客户端实例 - - Returns: - 异步节点函数 - """ - - from langchain_core.runnables.config import RunnableConfig - - async def summarize_conversation(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: - """ - 记忆存储节点 - 使用 Mem0 - - Args: - state: 当前对话状态 - config: 运行时配置 - - Returns: - 重置计数器的状态更新 - """ - log_state_change("summarize", state, "进入") - - messages = state["messages"] - if len(messages) < 4: - debug("📝 [记忆添加] 对话过短,跳过") - return {"turns_since_last_summary": 0} - - # 从 metadata 中获取 user_id - user_id = config.get("metadata", {}).get("user_id", "default_user") - - # 确保 Mem0 已初始化(懒加载) - if not mem0_client._initialized: - await mem0_client.initialize() - - # 将整个对话历史转换为 Mem0 需要的消息格式 - mem0_messages = [] - for msg in messages: - # 兼容 dict 和对象两种格式 - if isinstance(msg, dict): - msg_type = msg.get("type", "") - msg_content = msg.get("content", "") - else: - msg_type = getattr(msg, 'type', '') - msg_content = getattr(msg, 'content', '') - - if msg_type == "human": - mem0_messages.append({"role": "user", "content": msg_content}) - elif msg_type == "ai": - mem0_messages.append({"role": "assistant", "content": msg_content}) - elif msg_type == "tool": - mem0_messages.append({"role": "system", "content": f"[工具返回] {msg_content}"}) - - if mem0_client.mem0: - try: - # 异步调用 Mem0 自动提取并存储事实 - success = await mem0_client.add_memories( - mem0_messages, - user_id=user_id - ) - if success: - info(f"📝 [记忆添加] 已提交给 Mem0 进行事实提取") - except Exception as e: - error(f"❌ Mem0 记忆添加失败: {e}") - else: - warning("⚠️ Mem0 未初始化,跳过记忆添加") - - log_state_change("summarize", state, "离开") - return {"turns_since_last_summary": 0} - - return summarize_conversation \ No newline at end of file diff --git a/app/nodes/tool_call.py b/app/nodes/tool_call.py deleted file mode 100644 index 5aa5bdf..0000000 --- a/app/nodes/tool_call.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -工具执行节点模块 -负责执行 AI 调用的工具函数 -""" - -import asyncio -from typing import Any, Dict -from langchain_core.messages import AIMessage, ToolMessage -from langgraph.config import get_stream_writer - -# 本地模块 -from app.graph.state import MessagesState -from app.utils.logging import log_state_change -from app.logger import debug, info - -def create_tool_call_node(tools_by_name: Dict[str, Any]): - """ - 工厂函数:创建工具执行节点 - - Args: - tools_by_name: 名称到工具函数的映射字典 - - Returns: - 异步节点函数 - """ - - from langchain_core.runnables.config import RunnableConfig - - async def call_tools(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: - """ - 工具执行节点(异步方法) - - Args: - state: 当前对话状态 - config: 运行时配置 - - Returns: - 包含 ToolMessage 的状态更新 - """ - log_state_change("tool_node", state, "进入") - - last_message = state['messages'][-1] - if not isinstance(last_message, AIMessage) or not last_message.tool_calls: - log_state_change("tool_node", state, "离开(无工具调用)") - return {"messages": []} - - results = [] - loop = asyncio.get_event_loop() - - info(f"🛠️ [工具调用] 准备执行 {len(last_message.tool_calls)} 个工具") - - for tool_call in last_message.tool_calls: - tool_name = tool_call["name"] - tool_args = tool_call["args"] - tool_id = tool_call["id"] - tool_func = tools_by_name.get(tool_name) - - debug(f" ├─ 调用工具: {tool_name} 参数: {tool_args}") - - if tool_func is None: - err_msg = f"Tool {tool_name} not found" - debug(f" └─ ❌ {err_msg}") - results.append(ToolMessage(content=err_msg, tool_call_id=tool_id)) - continue - - # 获取流式写入器并发送工具调用开始事件 - writer = get_stream_writer() - writer({"type": "custom", "data": {"type": "tool_start", "tool": tool_name}}) - - try: - # 修复闭包问题:将变量作为默认参数传入 lambda - # 如果工具支持异步 (ainvoke),优先使用异步调用 - if hasattr(tool_func, 'ainvoke'): - observation = await tool_func.ainvoke(tool_args) - else: - observation = await loop.run_in_executor( - None, - lambda args=tool_args: tool_func.invoke(args) # 默认参数捕获当前值 - ) - - # 字符打印 - result_preview = str(observation).replace("\n", " ") - debug(f" └─ ✅ 结果: {result_preview}") - results.append(ToolMessage(content=str(observation), tool_call_id=tool_id)) - - # 发送工具调用完成事件 - writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": True}}) - except Exception as e: - debug(f" └─ ❌ 异常: {e}") - results.append(ToolMessage(content=f"Error: {e}", tool_call_id=tool_id)) - - # 发送工具调用失败事件 - writer({"type": "custom", "data": {"type": "tool_end", "tool": tool_name, "success": False, "error": str(e)}}) - - info(f"🛠️ [工具调用] 执行完成,返回 {len(results)} 条 ToolMessage") - - result = {"messages": results} - log_state_change("tool_node", {**state, **result}, "离开") - return result - - return call_tools \ No newline at end of file diff --git a/app/rag/README.md b/app/rag/README.md deleted file mode 100644 index a91d8f6..0000000 --- a/app/rag/README.md +++ /dev/null @@ -1,391 +0,0 @@ -# 在线 RAG 检索与生成系统 (Online RAG Retriever) - -该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。 - -## 🎯 核心架构 - -### 技术栈 - -| 组件 | 技术选型 | 版本 | 说明 | -|:-----|:---------|:-----|:-----| -| **基础检索** | `Qdrant` | 1.17+ | HNSW 稠密向量检索 | -| **混合检索** | `Qdrant` + `BM25` | 内置 | 稠密 + 稀疏向量融合 | -| **查询改写** | `LangChain` | 内置 | `MultiQueryGenerator` 多路改写 | -| **RRF 融合** | 自实现 | - | `reciprocal_rank_fusion` 倒数排名融合 | -| **重排序** | `llama.cpp` | 本地服务 | OpenAI 兼容 Rerank API | -| **编排框架** | `asyncio` | Python 3.10+ | 异步并行检索 | - -### 检索流水线 - -``` -┌─────────────────────────────────────────────────────────────┐ -│ 用户提问 │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ MultiQueryGenerator │ -│ 多路查询改写 (num_queries=3) │ -│ "如何申请项目资金?" → ["项目资金申请流程", "经费申请步骤"] │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ 并行检索 (asyncio.gather) │ -│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ -│ │ 查询1 检索 │ │ 查询2 检索 │ │ 查询3 检索 │ │ -│ │ (k=20) │ │ (k=20) │ │ (k=20) │ │ -│ └──────────────┘ └──────────────┘ └──────────────┘ │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ reciprocal_rank_fusion (RRF) │ -│ RRF_score(d) = Σ 1/(k + rank_q(d)) (k=60) │ -│ 融合多路检索结果,去重排序 │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ LLaMaCPPReranker │ -│ 远程重排序 (bge-reranker-v2-m3) │ -│ 返回 Top-N (top_n=5) 最相关文档 │ -└──────────────────────┬──────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ 返回增强上下文 │ -│ format_context() → 格式化输出 │ -└─────────────────────────────────────────────────────────────┘ -``` - -### 技术特性 - -- ✅ **多路查询改写**:通过 LLM 将单一问题改写为多个不同角度的查询 -- ✅ **RRF 融合算法**:Reciprocal Rank Fusion,无需评分归一化的融合算法 -- ✅ **远程重排序**:使用 llama.cpp 服务的 OpenAI 兼容 Rerank API -- ✅ **混合检索支持**:稠密向量 + BM25 稀疏向量混合检索 -- ✅ **异步并行检索**:多路查询并行执行,提升检索速度 -- ✅ **优雅降级**:重排序器不可用时自动降级到基础融合结果 - -## 📂 架构与文件结构 - -``` -app/rag/ -├── __init__.py -├── retriever.py # Qdrant 基础检索与混合检索 -├── reranker.py # llama.cpp 远程重排序器 -├── query_transform.py # 多路查询改写生成器 -├── fusion.py # RRF 倒数排名融合算法 -├── pipeline.py # RAG 流水线编排 -└── tools.py # LangChain Tool 封装 -``` - -## 🎯 演进路线与算法详解 (Roadmap) - -### Level 1: 基础向量搜索 (Basic Similarity Search) - -- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。 -- **优缺点**: 速度极快。但只能捕捉"语义相似",如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生"幻觉"匹配)。 -- **实现指南**: - - 使用 `rag_indexer.embedders.LlamaCppEmbedder` 作为嵌入模型 - - 使用 `app/rag/retriever.py` 中的 `create_base_retriever` 创建基础检索器 - - 配置 `search_kwargs={"k": 20}` 进行初步召回 - -```python -from app.rag.retriever import create_base_retriever - -retriever = create_base_retriever( - collection_name="rag_documents", - embeddings=embeddings, - search_kwargs={"k": 20} -) -docs = retriever.invoke("什么是 RAG?") -``` - -### Level 2: 混合检索与重排序 (Hybrid Search + Reranker) - -混合检索旨在结合向量的"语义泛化"与关键词的"精准匹配",随后利用重排序模型过滤噪声。 - -**1. 基础召回 (混合检索)** - -- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。 -- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_hybrid_retriever` 函数,配置 `dense_k=10` 和 `sparse_k=10`,总召回 20 条结果。 - -```python -from app.rag.retriever import create_hybrid_retriever - -retriever = create_hybrid_retriever( - collection_name="rag_documents", - embeddings=embeddings, - dense_k=10, - sparse_k=10, - score_threshold=0.3 -) -``` - -**2. 二次精排 (Cross-Encoder)** - -- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将"用户问题 + 检索到的单例文档"拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。 -- **实现指南**: - - 使用 `app/rag/reranker.py` 中的 `LLaMaCPPReranker` 类,加载 `bge-reranker-v2-m3` 模型 - - 设置 `top_n=5` 保留最相关的 5 条结果 - -```python -from app.rag.reranker import LLaMaCPPReranker - -reranker = LLaMaCPPReranker( - base_url="http://127.0.0.1:8083", - api_key="your-api-key", - top_n=5 -) -sorted_docs = reranker.compress_documents(documents, query) -``` - -### Level 3: RAG-Fusion (多路改写与倒数排名融合) - -RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。 - -**1. 多路查询改写** - -- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。 -- **实现指南**: 使用 `app/rag/query_transform.py` 中的 `MultiQueryGenerator` 类,配置 `num_queries=3` 生成 3 个不同角度的查询。 - -```python -from app.rag.query_transform import MultiQueryGenerator - -generator = MultiQueryGenerator(llm=llm, num_queries=3) -queries = await generator.agenerate("如何申请项目资金?") -# 返回:["如何申请项目资金?", "项目资金申请流程是什么?", "申请项目经费需要哪些步骤?"] -``` - -**2. 倒数排名融合 (RRF)** - -- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 `RRF_score(d) = Σ 1/(k + rank_q(d))`,有效避免某一极端检索结果主导全局。 -- **实现指南**: 使用 `app/rag/fusion.py` 中的 `reciprocal_rank_fusion` 函数,配置 `k=60` 实现倒数排名融合。 - -```python -from app.rag.fusion import reciprocal_rank_fusion - -# 多个查询的检索结果 -doc_lists = [result1, result2, result3] -fused_docs = reciprocal_rank_fusion(doc_lists, k=60) -``` - -### Level 4: Agentic RAG / Self-RAG (智能体与自我反思) - -- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:"这是闲聊?还是需要查知识库?"。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。 -- **实现指南**: 使用 `app/rag/tools.py` 中的 `search_knowledge_base` 工具,将其绑定到 LangGraph 状态机中。 - -- **示意图**: - -``` -┌──────────┐ ┌──────────────┐ ┌──────────┐ ┌──────── -│ User │────>│ LangGraph │────>│ RAG_Tool │────>│ Qdrant │ -│ │ │ Agent │ │ │ │ │ -│ "公司报 │ │ 思考: 这是 │ │ ToolCall │ │ RAG- │ -│ 销流程?"│ │ 内部规章问题 │ │ search_ │ │ Fusion │ -│ │ │ 需要查资料 │ │ knowledge│ │ & 混合 │ -│ │<────│ 资料充分, │<────│ 返回最相 │<────│ 检索 │ -│ "根据知 │ │ 开始撰写回答 │ │ 关5条规定 │ │ Cross- │ -│ 识库规定 │ │ │ │ │ │ Encoder│ -│ ..." │ │ │ │ │ │ 重排 │ -└────────── └────────────── └──────────┘ └────────┘ -``` - -### Level 5: GraphRAG 集成 (基于图和关系的 RAG) - -- **核心原理**: 结合知识图谱的结构化关系和向量检索的语义相似度,解决跨文档复杂关系推理问题。 -- **实现指南**: - - 使用 `langchain_community.graphs` 模块构建知识图谱 - - 配置本地大模型(如 `Gemma-4-E4B`)用于实体关系抽取 - - 实现混合检索逻辑,结合向量相似度和图路径分析 - -```python -from langchain_community.graphs import Neo4jGraph -from langchain_experimental.graph_transformers import LLMGraphTransformer - -# 实体关系抽取 -transformer = LLMGraphTransformer(llm=local_llm) -graph_documents = transformer.convert_to_graph_documents(documents) - -# 存储到图数据库 -graph = Neo4jGraph(url="bolt://localhost:7687") -graph.add_graph_documents(graph_documents) -``` - -## 🔧 核心组件详解 - -### 1. 检索器 (retriever.py) - -提供基于 Qdrant 的向量检索能力。 - -**基础检索器**: -```python -from app.rag.retriever import create_base_retriever - -retriever = create_base_retriever( - collection_name="rag_documents", - embeddings=embeddings, - search_kwargs={"k": 20} -) -``` - -**混合检索器**: -```python -from app.rag.retriever import create_hybrid_retriever - -retriever = create_hybrid_retriever( - collection_name="rag_documents", - embeddings=embeddings, - dense_k=10, - sparse_k=10, - score_threshold=0.3 -) -``` - -### 2. 多路查询改写 (query_transform.py) - -通过 LLM 将用户问题改写为多个不同版本,扩大搜索面。 - -```python -from app.rag.query_transform import MultiQueryGenerator - -generator = MultiQueryGenerator(llm=llm, num_queries=3) -queries = await generator.agenerate("如何申请项目资金?") -``` - -### 3. RRF 融合算法 (fusion.py) - -Reciprocal Rank Fusion 算法,公式:`RRF_score(d) = Σ 1/(k + rank_q(d))` - -```python -from app.rag.fusion import reciprocal_rank_fusion - -# 多个查询的检索结果 -doc_lists = [result1, result2, result3] -fused_docs = reciprocal_rank_fusion(doc_lists, k=60) -``` - -### 4. 重排序器 (reranker.py) - -使用 llama.cpp 服务的 OpenAI 兼容 Rerank API 对检索结果重排序。 - -```python -from app.rag.reranker import LLaMaCPPReranker - -reranker = LLaMaCPPReranker( - base_url="http://127.0.0.1:8083", - api_key="your-api-key", - top_n=5 -) -sorted_docs = reranker.compress_documents(documents, query) -``` - -### 5. RAG 流水线 (pipeline.py) - -组合上述组件的完整检索流水线。 - -```python -from app.rag.pipeline import RAGPipeline - -pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=3, - rerank_top_n=5, -) - -# 异步检索 -docs = await pipeline.aretrieve("如何申请项目资金?") - -# 格式化上下文 -context = pipeline.format_context(docs) -``` - -## 🔄 与 Agent 系统集成 - -### 封装为 LangChain Tool - -```python -from langchain_core.tools import tool -from app.rag.pipeline import RAGPipeline - -@tool -def search_knowledge_base(query: str) -> str: - """搜索知识库获取相关信息""" - docs = pipeline.retrieve(query) - return pipeline.format_context(docs) -``` - -### 绑定到 LangGraph - -```python -from app.graph.graph_builder import GraphBuilder - -# 将 RAG 工具添加到工具列表 -tools = AVAILABLE_TOOLS + [search_knowledge_base] - -# 构建图 -builder = GraphBuilder(llm, tools, tools_by_name) -graph = builder.build().compile(checkpointer=checkpointer) -``` - -## ⚙️ 环境配置 - -| 变量名 | 说明 | 默认值 | -|:-------|:-----|:-------| -| `QDRANT_URL` | Qdrant 向量数据库地址 | `http://127.0.0.1:6333` | -| `QDRANT_API_KEY` | Qdrant API 密钥 | - | -| `LLAMACPP_RERANKER_URL` | llama.cpp 重排序服务地址 | `http://127.0.0.1:8083` | -| `LLAMACPP_API_KEY` | llama.cpp API 密钥 | - | - -## 🚀 快速开始 - -```python -# 1. 初始化嵌入模型 -from rag_core.embedders import LlamaCppEmbedder -embedder = LlamaCppEmbedder() -embeddings = embedder.as_langchain_embeddings() - -# 2. 创建检索器 -from app.rag.retriever import create_base_retriever -retriever = create_base_retriever( - collection_name="rag_documents", - embeddings=embeddings, - search_kwargs={"k": 20} -) - -# 3. 创建 RAG 流水线 -from app.rag.pipeline import RAGPipeline -pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=3, - rerank_top_n=5, -) - -# 4. 执行检索 -docs = pipeline.retrieve("如何申请项目资金?") - -# 5. 格式化上下文 -context = pipeline.format_context(docs) -print(context) -``` - -## 📊 检索策略对比 - -| 策略 | 优点 | 缺点 | 适用场景 | -|:-----|:-----|:-----|:---------| -| **基础向量检索** | 速度快,语义理解好 | 专有名词匹配差 | 通用问答 | -| **混合检索** | 语义 + 关键词匹配 | 需要配置稀疏向量 | 专业术语查询 | -| **多路改写 + RRF** | 搜索面广,结果稳定 | 延迟略高 | 复杂问题 | -| **重排序** | 精度高 | 依赖额外模型 | 最终精排 | - -## 🤝 与 rag_indexer 集成 - -- **向量存储**:共享 Qdrant 集合,确保嵌入模型一致 -- **文档存储**:使用 PostgreSQL 存储父块,通过 UUID 映射 -- **集合名称**:默认使用 `rag_documents` 集合 - -详见 [rag_indexer/README.md](../../rag_indexer/README.md) diff --git a/app/rag/__init__.py b/app/rag/__init__.py deleted file mode 100644 index dca5fed..0000000 --- a/app/rag/__init__.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -RAG 检索与生成模块 - -提供在线检索与生成功能,包括: -- 基础向量检索(稠密向量 / 混合检索) -- 重排序(Cross-Encoder) -- 多路查询改写(Multi-Query) -- RRF 融合(Reciprocal Rank Fusion) -- 完整的 RAG 流水线 -- Agent 工具封装 - -固定流水线: - 用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 - -示例用法: - >>> from app.rag.rag import RAGPipeline, create_rag_tool - >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig - >>> from langchain_openai import ChatOpenAI - >>> - >>> # 获取基础检索器(如父子块检索器) - >>> config = IndexBuilderConfig(collection_name="my_docs") - >>> builder = IndexBuilder(config) - >>> retriever = builder.retriever - >>> - >>> # 创建 LLM 和流水线 - >>> llm = ChatOpenAI(model="gpt-3.5-turbo") - >>> pipeline = RAGPipeline(retriever=retriever, llm=llm) - >>> - >>> # 检索 - >>> docs = await pipeline.aretrieve("什么是 RAG?") - >>> context = pipeline.format_context(docs) - >>> - >>> # 创建 Agent 工具 - >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm) -""" - -from app.rag.retriever import ( - create_base_retriever, - create_hybrid_retriever, - create_qdrant_client, -) -from app.rag.reranker import LLaMaCPPReranker -from app.rag.query_transform import MultiQueryGenerator -from app.rag.fusion import reciprocal_rank_fusion -from app.rag.pipeline import RAGPipeline -from app.rag.tools import create_rag_tool_sync - - -__all__ = [ - # 检索器工厂函数 - "create_base_retriever", - "create_hybrid_retriever", - "create_qdrant_client", - - # 重排序器 - "LLaMaCPPReranker", - - # 查询改写生成器 - "MultiQueryGenerator", - - # 融合算法 - "reciprocal_rank_fusion", - - # 主流水线 - "RAGPipeline", - - # 工具创建(供 Agent 使用) - "create_rag_tool_sync", -] \ No newline at end of file diff --git a/app/rag/fusion.py b/app/rag/fusion.py deleted file mode 100644 index ddf8f42..0000000 --- a/app/rag/fusion.py +++ /dev/null @@ -1,36 +0,0 @@ -# rag/fusion.py - -from typing import List, Dict -from langchain_core.documents import Document - -def reciprocal_rank_fusion( - doc_lists: List[List[Document]], - k: int = 60 -) -> List[Document]: - """ - 对多个检索结果列表进行 RRF 融合。 - - Args: - doc_lists: 多个检索结果列表,每个列表来自一个查询 - k: RRF 常数,通常设为 60 - - Returns: - 融合后按 RRF 得分降序排列的文档列表 - """ - # 使用文档内容作为唯一标识(如果内容相同但 metadata 不同,视为同一文档) - # 更好的做法是用 docstore 的 ID,这里简化处理:用内容 hash - doc_to_score: Dict[str, float] = {} - doc_map: Dict[str, Document] = {} - - for docs in doc_lists: - for rank, doc in enumerate(docs, start=1): - # 生成唯一标识符(内容+来源组合,避免不同文件相同内容混淆) - doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}" - if doc_id not in doc_map: - doc_map[doc_id] = doc - score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank) - doc_to_score[doc_id] = score - - # 按得分排序 - sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True) - return [doc_map[doc_id] for doc_id in sorted_ids] \ No newline at end of file diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py deleted file mode 100644 index 5adab4a..0000000 --- a/app/rag/pipeline.py +++ /dev/null @@ -1,90 +0,0 @@ -# rag/pipeline.py - -import asyncio -import os -from typing import List -from langchain_core.documents import Document -from langchain_core.language_models import BaseLanguageModel - -from app.rag.reranker import LLaMaCPPReranker -from app.rag.query_transform import MultiQueryGenerator -from app.rag.fusion import reciprocal_rank_fusion - -class RAGPipeline: - """ - 固定流程的 RAG 检索流水线: - 多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档 - """ - - def __init__( - self, - retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例) - llm: BaseLanguageModel, - num_queries: int = 3, - rerank_top_n: int = 5, - ): - """ - Args: - retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法 - llm: 用于生成多路查询的语言模型 - num_queries: 生成的查询变体数量 - rerank_top_n: 最终返回的文档数量 - rerank_model: 重排序模型名称 - """ - self.retriever = retriever - self.llm = llm - self.num_queries = num_queries - self.rerank_top_n = rerank_top_n - - # 初始化组件 - self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) - self.reranker = LLaMaCPPReranker( - base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"), - api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"), - top_n=rerank_top_n, - ) - - async def aretrieve(self, query: str) -> List[Document]: - """ - 异步执行完整检索流程 - """ - # Step 1: 生成多路查询 - queries = await self.query_generator.agenerate(query) - # 包含原始查询,确保至少有一条 - if query not in queries: - queries.insert(0, query) - else: - # 如果原始查询已在列表中,将其移至首位 - queries.remove(query) - queries.insert(0, query) - - # Step 2: 并行检索(每个查询获取文档列表) - tasks = [self.retriever.ainvoke(q) for q in queries] - doc_lists = await asyncio.gather(*tasks) - - # Step 3: RRF 融合 - fused_docs = reciprocal_rank_fusion(doc_lists) - - # Step 4: 重排序 - try: - final_docs = self.reranker.compress_documents(fused_docs, query) - except Exception: - # 若重排序器不可用,直接返回融合后的前 N 条 - final_docs = fused_docs[:self.rerank_top_n] - - return final_docs - - def retrieve(self, query: str) -> List[Document]: - """同步检索入口(内部调用异步方法)""" - return asyncio.run(self.aretrieve(query)) - - def format_context(self, documents: List[Document]) -> str: - """将文档列表格式化为上下文字符串""" - if not documents: - return "" - - parts = [] - for i, doc in enumerate(documents, 1): - source = doc.metadata.get("source", "未知来源") - parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n") - return "\n".join(parts) \ No newline at end of file diff --git a/app/rag/query_transform.py b/app/rag/query_transform.py deleted file mode 100644 index 38f9fd1..0000000 --- a/app/rag/query_transform.py +++ /dev/null @@ -1,43 +0,0 @@ -# rag/query_transform.py - -from typing import List -from langchain_core.language_models import BaseLanguageModel -from langchain_core.prompts import PromptTemplate - -MULTI_QUERY_PROMPT = PromptTemplate.from_template( - """你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。 -这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。 - -原始问题: {question} - -请生成 {num_queries} 个不同版本的查询,每个版本一行。 -确保每个版本都是独立、完整的查询语句。 - -生成 {num_queries} 个查询:""" -) - -class MultiQueryGenerator: - """多路查询生成器(不依赖 LangChain 的 MultiQueryRetriever)""" - - def __init__(self, llm: BaseLanguageModel, num_queries: int = 3): - self.llm = llm - self.num_queries = num_queries - self.prompt = MULTI_QUERY_PROMPT - - def generate(self, query: str) -> List[str]: - """同步生成多个查询变体""" - prompt_str = self.prompt.format(num_queries=self.num_queries, question=query) - response = self.llm.invoke(prompt_str) - # 处理响应内容,按行分割并去除空行和首尾空白 - lines = response.content.strip().split('\n') - queries = [line.strip() for line in lines if line.strip()] - # 确保至少返回原始查询 - return queries[:self.num_queries] if queries else [query] - - async def agenerate(self, query: str) -> List[str]: - """异步生成多个查询变体""" - prompt_str = self.prompt.format(num_queries=self.num_queries, question=query) - response = await self.llm.ainvoke(prompt_str) - lines = response.content.strip().split('\n') - queries = [line.strip() for line in lines if line.strip()] - return queries[:self.num_queries] if queries else [query] \ No newline at end of file diff --git a/app/rag/reranker.py b/app/rag/reranker.py deleted file mode 100644 index 925e283..0000000 --- a/app/rag/reranker.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -重排序器模块 (适配版) -使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder -""" -import requests -from typing import List -from langchain_core.documents import Document - -class LLaMaCPPReranker: - """使用远程 llama.cpp 服务对检索结果重排序。""" - - def __init__(self, - base_url: str, - api_key: str, - top_n: int = 5, - timeout: int = 60): - """ - 初始化远程重排序器 - - Args: - base_url: llama.cpp 服务的地址和端口,默认为环境变量 LLAMACPP_RERANKER_URL 或 "http://127.0.0.1:8083"。 - top_n: 返回前 N 个结果。 - api_key: API 密钥,默认为环境变量 LLAMACPP_API_KEY 或 "huang1998"。 - timeout: 请求超时时间(秒)。 - """ - self.base_url = base_url - self.api_key = api_key - self.top_n = top_n - self.timeout = timeout - self.endpoint = f"{self.base_url}/rerank" - - def compress_documents( - self, documents: List[Document], query: str - ) -> List[Document]: - """ - 对文档进行重排序 - - Args: - documents: 待排序的文档列表 - query: 查询字符串 - - Returns: - 排序后的文档列表 - """ - if not documents: - return [] - - # 准备请求体 - # 根据 llama.cpp 的 OpenAI 兼容性,文档是一个字符串列表 - payload = { - "model": "bge-reranker-v2-m3", - "query": query, - "documents": [doc.page_content for doc in documents], - "top_n": self.top_n - } - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" - } - - try: - response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout) - response.raise_for_status() # 检查请求是否成功 - results = response.json() - - # 解析返回结果 - # 返回格式: {"results": [{"index": 0, "document": "...", "relevance_score": 0.8}, ...]} - # 按相关性得分降序排列 - sorted_indices = [item["index"] for item in results["results"]] - sorted_docs = [documents[idx] for idx in sorted_indices] - return sorted_docs - - except Exception as e: - print(f"警告: 远程重排序过程出错,将使用原始排序。错误: {e}") - return documents[:self.top_n] \ No newline at end of file diff --git a/app/rag/retriever.py b/app/rag/retriever.py deleted file mode 100644 index 483c8b9..0000000 --- a/app/rag/retriever.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Qdrant 向量检索器模块 - -提供基于 Qdrant 的基础向量检索和混合检索(Dense + Sparse)功能。 - -核心原理: -- 基础检索:将查询文本转换为向量,在 Qdrant 中进行近似最近邻(ANN)搜索, - 使用余弦相似度返回最相似的 k 个文档。 -- 混合检索:结合稠密向量检索(语义相似)和 BM25 稀疏向量检索(关键词匹配), - 通过加权或分数融合提高召回精度。 - -使用示例: - >>> from rag_core import LlamaCppEmbedder - >>> embedder = LlamaCppEmbedder() - >>> embeddings = embedder.as_langchain_embeddings() - >>> - >>> # 创建基础检索器 - >>> retriever = create_base_retriever( - ... collection_name="my_docs", - ... embeddings=embeddings, - ... search_kwargs={"k": 10} - ... ) - >>> - >>> # 执行检索 - >>> docs = retriever.invoke("什么是 RAG?") -""" - -from typing import Optional, Dict, Any -from qdrant_client import QdrantClient -from qdrant_client.http.exceptions import UnexpectedResponse -from langchain_qdrant import QdrantVectorStore -from langchain_core.embeddings import Embeddings -from langchain_core.retrievers import BaseRetriever - -from rag_core import QDRANT_URL, QDRANT_API_KEY - -# 模块级常量 -DEFAULT_SEARCH_K = 20 -DEFAULT_SCORE_THRESHOLD = 0.3 - - -def create_qdrant_client( - url: Optional[str] = None, - api_key: Optional[str] = None, - timeout: int = 30, -) -> QdrantClient: - """ - 创建并返回一个配置好的 Qdrant 客户端。 - - 优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。 - - Args: - url: Qdrant 服务地址,例如 "http://localhost:6333"。 - 默认从环境变量 QDRANT_URL 读取。 - api_key: API 密钥(若 Qdrant 启用了认证)。 - 默认从环境变量 QDRANT_API_KEY 读取。 - timeout: 请求超时时间(秒),默认 30 秒。 - - Returns: - 配置好的 QdrantClient 实例。 - - Raises: - ValueError: 如果 url 为空且环境变量也未设置。 - """ - effective_url = url or QDRANT_URL - if not effective_url: - raise ValueError( - "Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL" - ) - - effective_api_key = api_key or QDRANT_API_KEY - - client_kwargs = { - "url": effective_url, - "timeout": timeout, - } - if effective_api_key: - client_kwargs["api_key"] = effective_api_key - - return QdrantClient(**client_kwargs) - - -def create_base_retriever( - collection_name: str, - embeddings: Embeddings, - search_kwargs: Optional[Dict[str, Any]] = None, - client: Optional[QdrantClient] = None, -) -> BaseRetriever: - """ - 创建基础向量检索器(仅稠密向量检索)。 - - 该检索器使用嵌入模型将查询转为向量,在 Qdrant 集合中执行 ANN 搜索, - 返回语义上最相似的文档块。 - - Args: - collection_name: Qdrant 集合名称(需预先创建并索引)。 - embeddings: LangChain 兼容的嵌入模型实例。 - search_kwargs: 搜索参数,可包含: - - k (int): 返回的文档数量,默认 20。 - - score_threshold (float): 相似度阈值,仅返回高于此分数的文档。 - - filter (dict): Qdrant 过滤条件。 - 若为 None,则使用默认值 {"k": 20}。 - client: 可选的 Qdrant 客户端实例。若未提供,将自动创建。 - - Returns: - BaseRetriever 实例,可直接调用 .invoke(query) 或 .ainvoke(query) 检索。 - - Raises: - ValueError: 如果集合不存在或嵌入模型无效。 - """ - # 合并默认搜索参数 - merged_search_kwargs = {"k": DEFAULT_SEARCH_K} - if search_kwargs: - merged_search_kwargs.update(search_kwargs) - - # 创建或复用 Qdrant 客户端 - if client is None: - client = create_qdrant_client() - - # 验证集合是否存在(可选,便于提前发现问题) - try: - client.get_collection(collection_name) - except UnexpectedResponse as e: - if e.status_code == 404: - raise ValueError( - f"Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档。" - ) - raise - - # 构建向量存储 - vector_store = QdrantVectorStore( - client=client, - collection_name=collection_name, - embedding=embeddings, - ) - - # 返回检索器 - return vector_store.as_retriever(search_kwargs=merged_search_kwargs) - - -def create_hybrid_retriever( - collection_name: str, - embeddings: Embeddings, - dense_k: int = 10, - sparse_k: int = 10, - score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD, - client: Optional[QdrantClient] = None, -) -> BaseRetriever: - """ - 创建混合检索器(稠密向量 + BM25 稀疏向量)。 - - 混合检索结合了语义相似度(Dense)和关键词匹配(Sparse), - 能够更好地处理专有名词、精确匹配等场景。 - - 注意:此功能要求 Qdrant 集合已配置稀疏向量字段并生成了 BM25 索引。 - 若集合未配置稀疏向量,将回退到纯稠密检索(不会报错,但检索效果降级)。 - - Args: - collection_name: Qdrant 集合名称。 - embeddings: 嵌入模型(用于稠密向量)。 - dense_k: 稠密向量检索返回数量,默认 10。 - sparse_k: 稀疏向量检索返回数量,默认 10。 - score_threshold: 相似度阈值,默认 0.3。 - client: 可选的 Qdrant 客户端实例。 - - Returns: - BaseRetriever 实例,配置了混合搜索参数。 - """ - total_k = dense_k + sparse_k - - search_kwargs = { - "k": total_k, - } - if score_threshold is not None: - search_kwargs["score_threshold"] = score_threshold - - # 复用基础检索器创建逻辑,只需调整搜索参数 - return create_base_retriever( - collection_name=collection_name, - embeddings=embeddings, - search_kwargs=search_kwargs, - client=client, - ) - - -# 可选:提供异步友好的辅助函数 -async def acreate_base_retriever( - collection_name: str, - embeddings: Embeddings, - search_kwargs: Optional[Dict[str, Any]] = None, - client: Optional[QdrantClient] = None, -) -> BaseRetriever: - """ - 异步创建基础向量检索器(与同步版本功能相同)。 - - 适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。 - """ - # 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可 - return create_base_retriever(collection_name, embeddings, search_kwargs, client) \ No newline at end of file diff --git a/app/rag/test.py b/app/rag/test.py deleted file mode 100644 index ff9817a..0000000 --- a/app/rag/test.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -""" -RAG 系统使用示例(重构版) - -演示: -1. 使用 IndexBuilder 获取父子块检索器 -2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档) -3. 将流水线封装为 LangChain 工具,供 Agent 调用 -""" - -import asyncio -import sys -import os - -from dotenv import load_dotenv - -# 加载环境变量(Qdrant URL、PostgreSQL 连接等) -load_dotenv() - -# 添加项目根目录到路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) - -from rag_indexer.index_builder import IndexBuilderConfig -from rag_indexer.splitters import SplitterType -from app.rag.pipeline import RAGPipeline -from app.rag.tools import create_rag_tool_sync -from pydantic import SecretStr -# 使用本地 LLM(通过 OpenAI 兼容接口) -from langchain_openai import ChatOpenAI -from rag_core.retriever_factory import create_parent_retriever - -load_dotenv() - -def create_llm(): - """创建本地 vLLM 服务 LLM""" - vllm_base_url = os.getenv( - "VLLM_BASE_URL", - "http://127.0.0.1:8081/v1" - ) - - return ChatOpenAI( - base_url=vllm_base_url, - api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), - model="gemma-4-E2B-it", - timeout=60.0, # 请求超时时间(秒) - max_retries=2, # 失败后自动重试次数 - streaming=True, # 确保开启流式输出 - ) - -async def demonstrate_full_pipeline(): - """ - 完整流水线演示: - - 从 IndexBuilder 获取 ParentDocumentRetriever - - 创建 RAGPipeline - - 执行检索并打印结果 - """ - print("=" * 60) - print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)") - print("=" * 60) - - retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) - - if retriever is None: - print("错误:检索器未初始化,请确保索引已构建。") - return - - # 3. 创建 LLM 用于查询改写 - llm = create_llm() - - # 4. 创建 RAGPipeline(固定流程) - pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=3, # 生成 3 个查询变体 - rerank_top_n=5, # 最终返回 5 个父文档 - ) - - # 5. 执行检索 - query = "打虎英雄是谁?" - print(f"\n查询: {query}") - print("-" * 40) - - try: - documents = await pipeline.aretrieve(query) - print(f"返回 {len(documents)} 个父文档\n") - - # 打印结果预览 - for i, doc in enumerate(documents, 1): - content_preview = doc.page_content.replace("\n", " ")[:150] - source = doc.metadata.get("source", "未知来源") - print(f"{i}. 【来源:{source}】") - print(f" {content_preview}...\n") - - # 可选:格式化完整上下文 - # context = pipeline.format_context(documents) - # print(context) - - except Exception as e: - print(f"检索失败: {e}") - import traceback - traceback.print_exc() - -async def demonstrate_tool_creation(): - """ - 演示创建 RAG 工具(供 Agent 使用) - """ - print("\n" + "=" * 60) - print("演示:创建 RAG 工具(供 LangGraph Agent 调用)") - print("=" * 60) - - # 1. 获取检索器(同上) - config = IndexBuilderConfig( - collection_name="rag_documents", - splitter_type=SplitterType.PARENT_CHILD, - ) - retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) - - # 2. 创建 LLM - llm = create_llm() - - # 3. 创建工具 - rag_tool = create_rag_tool_sync( - retriever=retriever, - llm=llm, - num_queries=3, - rerank_top_n=5, - collection_name="rag_documents", - ) - - print(f"工具名称: {rag_tool.name}") - print(f"工具描述: {rag_tool.description[:100]}...") - - # 4. 模拟 Agent 调用工具 - query = "请告诉我 打虎英雄是谁?" - print(f"\n模拟调用: {query}") - print("-" * 40) - - result = await rag_tool.ainvoke({"query": query}) - print(result[:800] + "..." if len(result) > 800 else result) - -async def main(): - await demonstrate_full_pipeline() - await demonstrate_tool_creation() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/app/rag/tools.py b/app/rag/tools.py deleted file mode 100644 index 2343709..0000000 --- a/app/rag/tools.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -RAG 工具模块 - -将检索功能封装为 LangChain Tool,供 Agent 调用。 -采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 -""" -from typing import Callable -from langchain_core.tools import tool -from langchain_core.language_models import BaseLanguageModel -from langchain_core.retrievers import BaseRetriever -from app.rag.pipeline import RAGPipeline - -def create_rag_tool_sync( - retriever: BaseRetriever, - llm: BaseLanguageModel, - num_queries: int = 3, - rerank_top_n: int = 5, - collection_name: str = "rag_documents", -) -> Callable: - """ - 创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。 - - 参数同 create_rag_tool。 - """ - pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=num_queries, - rerank_top_n=rerank_top_n, - ) - - @tool - def search_knowledge_base_sync(query: str) -> str: - """在知识库中搜索与查询相关的文档片段(同步版本)。 - - 功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。 - - Args: - query: 用户提出的问题或查询字符串 - - Returns: - 格式化后的相关文档内容。 - """ - try: - documents = pipeline.retrieve(query) # 内部调用异步方法并等待 - if not documents: - return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" - - context = pipeline.format_context(documents) - return context - except Exception as e: - return f"检索过程中发生错误: {str(e)}" - - return search_knowledge_base_sync \ No newline at end of file diff --git a/app/test_backend.py b/app/test_backend.py deleted file mode 100644 index 63e0afb..0000000 --- a/app/test_backend.py +++ /dev/null @@ -1,307 +0,0 @@ -#!/usr/bin/env python3 -""" -完整后端测试 - 验证 Agent 所有功能 -包括:短期记忆、长期记忆、工具调用、流式对话、历史查询 -""" - -import asyncio -import os -import sys -import uuid -from dotenv import load_dotenv - -# 添加项目根目录到 Python 路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -load_dotenv() - -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from app.agent import AIAgentService -from app.agent.history import ThreadHistoryService -from app.logger import info, warning, error - -# PostgreSQL 连接字符串 -DB_URI = os.getenv( - "DB_URI", - "postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable" -) - -async def print_section(title): - """打印测试区块标题""" - print("\n" + "=" * 70) - print(f" {title}") - print("=" * 70) - -async def test_short_term_memory(agent_service): - """测试短期记忆(同一 thread_id 继续对话)""" - await print_section("测试 1: 短期记忆(Short-term Memory)") - - thread_id = str(uuid.uuid4()) - user_id = "test_user_memory" - - print(f"\n使用 thread_id: {thread_id[:8]}...") - print(f"使用 user_id: {user_id}") - - # 第一轮对话 - print("\n[第一轮] 发送消息: '我叫张三,今年28岁'") - result1 = await agent_service.process_message( - "我叫张三,今年28岁", thread_id, "local", user_id - ) - print(f"回复: {result1['reply'][:100]}...") - - # 第二轮对话 - 测试记忆 - print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'") - result2 = await agent_service.process_message( - "我叫什么名字?今年多大?", thread_id, "local", user_id - ) - print(f"回复: {result2['reply']}") - - # 验证记忆是否存在 - if "张三" in result2['reply'] or "28" in result2['reply']: - print("\n✅ 短期记忆测试通过!") - return True - else: - print("\n❌ 短期记忆测试失败!") - return False - -async def test_tool_calling(agent_service): - """测试工具调用(RAG 搜索)""" - await print_section("测试 2: 工具调用(Tool Calling)") - - thread_id = str(uuid.uuid4()) - user_id = "test_user_tools" - - print(f"\n使用 thread_id: {thread_id[:8]}...") - print(f"使用 user_id: {user_id}") - - # 发送需要 RAG 搜索的问题 - print("\n发送消息: '请告诉我,打虎英雄是谁?'") - result = await agent_service.process_message( - "请告诉我,打虎英雄是谁?", thread_id, "local", user_id - ) - print(f"回复: {result['reply'][:200]}...") - - # 检查是否调用了 RAG 工具(回复中会有水浒传相关内容) - if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']: - print("\n✅ 工具调用测试通过!") - return True - else: - print("\n⚠️ 工具调用测试结果不确定,需要手动验证") - return None - -async def test_streaming(agent_service): - """测试流式对话""" - await print_section("测试 3: 流式对话(Streaming)") - - thread_id = str(uuid.uuid4()) - user_id = "test_user_stream" - - print(f"\n使用 thread_id: {thread_id[:8]}...") - print(f"使用 user_id: {user_id}") - - print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...") - print("流式输出: ", end="", flush=True) - - full_reply = "" - chunk_count = 0 - - try: - async for chunk in agent_service.process_message_stream( - "用100字介绍一下AI人工智能", thread_id, "local", user_id - ): - chunk_count += 1 - if chunk.get("type") == "llm_token": - token = chunk.get("token", "") - print(token, end="", flush=True) - full_reply += token - elif chunk.get("type") == "state_update": - pass # 状态更新不显示 - - print(f"\n\n共收到 {chunk_count} 个 chunk") - print(f"完整回复长度: {len(full_reply)} 字") - - if chunk_count > 0 and len(full_reply) > 10: - print("\n✅ 流式对话测试通过!") - return True - else: - print("\n❌ 流式对话测试失败!") - return False - - except Exception as e: - print(f"\n❌ 流式对话异常: {e}") - return False - -async def test_history_service(agent_service, history_service): - """测试历史查询服务""" - await print_section("测试 4: 历史查询服务(History Service)") - - user_id = "test_user_history" - - # 先创建几个对话 - print(f"\n为 user_id={user_id} 创建测试对话...") - - thread_ids = [] - for i in range(3): - thread_id = str(uuid.uuid4()) - thread_ids.append(thread_id) - - await agent_service.process_message( - f"这是第 {i+1} 个测试对话", thread_id, "local", user_id - ) - print(f" 创建线程 {i+1}: {thread_id[:8]}...") - - # 1. 测试获取用户线程列表 - print("\n[4.1] 测试获取用户线程列表...") - threads = await history_service.get_user_threads(user_id, limit=10) - print(f" 找到 {len(threads)} 个线程") - - if len(threads) >= 3: - print(" ✅ 线程列表查询通过") - else: - print(" ⚠️ 线程数量少于预期") - - # 2. 测试获取单个线程的消息历史 - if thread_ids: - test_thread_id = thread_ids[0] - print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)") - messages = await history_service.get_thread_messages(test_thread_id) - print(f" 找到 {len(messages)} 条消息") - - if len(messages) >= 2: # 至少有一问一答 - print(" ✅ 消息历史查询通过") - else: - print(" ⚠️ 消息数量少于预期") - - # 3. 测试获取线程摘要 - print(f"\n[4.3] 测试获取线程摘要...") - summary = await history_service.get_thread_summary(test_thread_id) - print(f" 摘要: {summary.get('summary', '')[:50]}...") - print(f" 消息数: {summary.get('message_count', 0)}") - - if summary.get('message_count', 0) > 0: - print(" ✅ 线程摘要查询通过") - else: - print(" ⚠️ 摘要查询结果不确定") - - return len(threads) >= 3 - -async def test_long_term_memory(agent_service): - """测试长期记忆(mem0)""" - await print_section("测试 5: 长期记忆(Long-term Memory - mem0)") - - thread_id1 = str(uuid.uuid4()) - thread_id2 = str(uuid.uuid4()) # 不同的线程 - user_id = "test_user_longterm" - - print(f"\n使用 user_id: {user_id}") - print(f"线程 1: {thread_id1[:8]}...") - print(f"线程 2: {thread_id2[:8]}...") - - # 在第一个线程中保存信息 - print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'") - result1 = await agent_service.process_message( - "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id - ) - print(f"回复: {result1['reply'][:100]}...") - - # 等待一下,让 mem0 保存 - await asyncio.sleep(1) - - # 在第二个线程中询问(不同的 thread_id) - print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'") - result2 = await agent_service.process_message( - "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id - ) - print(f"回复: {result2['reply']}") - - # 验证长期记忆 - if "小白" in result2['reply'] or "猫" in result2['reply']: - print("\n✅ 长期记忆测试通过!") - return True - else: - print("\n⚠️ 长期记忆可能未启用,或需要手动验证") - return None - -async def main(): - """主测试函数""" - print("\n" + "=" * 70) - print(" 后端完整功能测试") - print("=" * 70) - - results = {} - - try: - # 创建数据库连接和服务 - print("\n正在初始化数据库连接...") - async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: - await checkpointer.setup() - print("✅ 数据库连接成功") - - # 创建服务实例 - print("\n正在初始化 Agent 服务...") - agent_service = AIAgentService(checkpointer) - await agent_service.initialize() - print("✅ Agent 服务初始化成功") - - history_service = ThreadHistoryService(checkpointer) - print("✅ 历史服务初始化成功") - - print(f"\n可用模型: {list(agent_service.graphs.keys())}") - - # 运行测试 - results["短期记忆"] = await test_short_term_memory(agent_service) - await asyncio.sleep(1) - - results["工具调用"] = await test_tool_calling(agent_service) - await asyncio.sleep(1) - - results["流式对话"] = await test_streaming(agent_service) - await asyncio.sleep(1) - - results["历史查询"] = await test_history_service(agent_service, history_service) - await asyncio.sleep(1) - - results["长期记忆"] = await test_long_term_memory(agent_service) - await asyncio.sleep(1) - - # 打印总结 - await print_section("测试总结") - print("\n测试结果:") - print("-" * 40) - - pass_count = 0 - fail_count = 0 - skip_count = 0 - - for test_name, result in results.items(): - if result is True: - status = "✅ 通过" - pass_count += 1 - elif result is False: - status = "❌ 失败" - fail_count += 1 - else: - status = "⚠️ 待验证" - skip_count += 1 - print(f" {test_name:12s}: {status}") - - print("-" * 40) - print(f"总计: {len(results)} 个测试") - print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}") - - if fail_count == 0: - print("\n🎉 所有核心测试通过!") - else: - print(f"\n⚠️ 有 {fail_count} 个测试失败") - - except Exception as e: - error(f"\n❌ 测试运行异常: {e}") - import traceback - traceback.print_exc() - return 1 - - return 0 if fail_count == 0 else 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) diff --git a/app/utils/__init__.py b/app/utils/__init__.py deleted file mode 100644 index faa27e5..0000000 --- a/app/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -工具模块 -""" - -from app.utils.logging import log_state_change, print_llm_input - -__all__ = ["log_state_change", "print_llm_input"] diff --git a/app/utils/logging.py b/app/utils/logging.py deleted file mode 100644 index 8228366..0000000 --- a/app/utils/logging.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -LangGraph 节点日志工具模块 -提供状态流转追踪和 LLM 输入输出打印功能 -""" - -from app.config import ENABLE_GRAPH_TRACE -from app.logger import debug, info - - -def log_state_change(node_name: str, state: dict, prefix: str = "进入"): - """ - 记录状态变化日志 - - Args: - node_name: 节点名称 - state: 当前状态 - prefix: 日志前缀("进入" 或 "离开") - """ - from app.logger import info - - messages = state.get("messages", []) - msg_count = len(messages) - last_msg = messages[-1] if messages else None - last_info = "" - if last_msg: - # 兼容 dict 和对象两种格式 - if isinstance(last_msg, dict): - content_preview = str(last_msg.get("content", ""))[:10].replace("\n", " ") - msg_type = last_msg.get("type", "unknown") - else: - content_preview = str(last_msg.content)[:10].replace("\n", " ") - msg_type = getattr(last_msg, 'type', 'unknown') - last_info = f"{msg_type.upper()}: {content_preview}" - info(f"🔄 [{node_name}] {prefix} | 消息数:{msg_count} | 最后一条:{last_info}") - - -def print_llm_input(prompt_value): - """ - RunnableLambda 回调函数:打印格式化后发送给 LLM 的完整消息 - - Args: - prompt_value: ChatPromptValue 对象,包含格式化后的消息列表 - - Returns: - 原样返回 prompt_value,不影响链式调用 - """ - if not ENABLE_GRAPH_TRACE: - return prompt_value - - messages = prompt_value.messages # ChatPromptValue 提供 .messages 属性 - - debug("\n" + "=" * 80) - debug("📤 [LLM输入] 格式化后发送给大模型的完整消息:") - debug(f" 总消息数: {len(messages)}") - debug("-" * 80) - for i, msg in enumerate(messages): - content_preview = str(msg.content) # 完整输出 - debug(f" [{i}] {msg.type.upper():10s}: {content_preview}") - debug("\n" + "=" * 80 + "\n") - - return prompt_value diff --git a/docker/Dockerfile.frontend b/docker/Dockerfile.frontend deleted file mode 100644 index 4d3f112..0000000 --- a/docker/Dockerfile.frontend +++ /dev/null @@ -1,15 +0,0 @@ -FROM python:3.11-slim - -WORKDIR /app - -COPY requirement.txt . -RUN pip install --no-cache-dir -r requirement.txt - -COPY frontend/ ./frontend/ -COPY app/ ./app/ - -ENV PYTHONPATH=/app - -EXPOSE 8501 - -CMD ["streamlit", "run", "frontend/frontend_main.py", "--server.port", "8501", "--server.address", "0.0.0.0", "--server.baseUrlPath", "/ai"] diff --git a/docker/Dockerfile.backend b/docker/backend/Dockerfile similarity index 75% rename from docker/Dockerfile.backend rename to docker/backend/Dockerfile index c4135be..0caf150 100644 --- a/docker/Dockerfile.backend +++ b/docker/backend/Dockerfile @@ -10,6 +10,7 @@ ENV PYTHONPATH=/app # llama.cpp 服务配置(本地部署标准端口) ENV VLLM_BASE_URL=http://host.docker.internal:18000/v1 ENV LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1 +ENV LLAMACPP_RERENT_URL=http://host.docker.internal:18002/v1 # Mem0 记忆层配置 ENV QDRANT_COLLECTION_NAME=mem0_user_memories @@ -40,30 +41,21 @@ RUN pip install --no-cache-dir /tmp/models/*.whl && \ RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple # 复制 requirement 并安装(增加超时时间) -COPY requirement.txt . -RUN pip install --no-cache-dir --default-timeout=300 -r requirement.txt +COPY backend/requirements.txt . +RUN pip install --no-cache-dir --default-timeout=300 -r requirements.txt # ============================================================================= -# 预下载 spaCy 语言模型(避免容器启动时重复下载) +# 复制项目代码 # ============================================================================= -RUN pip install --no-cache-dir spacy && \ - python -m spacy download en_core_web_sm && \ - python -m spacy download zh_core_web_sm - -# ============================================================================= -# 复制项目代码 (只复制必需的文件夹,避免依赖被忽略的目录) -# ============================================================================= -COPY rag_core/ ./rag_core/ -COPY app/ ./app/ -COPY frontend/ ./frontend/ -COPY scripts/ ./scripts/ +COPY backend/ ./ # ============================================================================= # 暴露端口 # ============================================================================= -EXPOSE 8083 +EXPOSE 8079 + # ============================================================================= # 启动命令 # ============================================================================= -CMD ["python", "app/backend.py"] \ No newline at end of file +CMD ["python", "app/backend.py"] diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 74980b7..87a61ed 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -5,7 +5,7 @@ services: backend: build: context: .. # 构建上下文为项目根目录 - dockerfile: docker/Dockerfile.backend + dockerfile: docker/backend/Dockerfile container_name: ai-backend environment: # ⭐ 敏感密钥:通过 .env 注入 @@ -26,7 +26,7 @@ services: - QDRANT_URL=http://115.190.121.151:6333 # 前端通信地址(Docker 内部网络) - - API_URL=http://backend:8083/chat + - API_URL=http://backend:8079/chat volumes: - ../data/user_docs:/app/data/user_docs # 挂载文档目录 - ../logs:/app/logs @@ -35,16 +35,16 @@ services: # ⭐ 移除对 postgres 和 qdrant 的依赖 restart: unless-stopped ports: - - "8083:8083" + - "8079:8079" frontend: build: context: .. - dockerfile: docker/Dockerfile.frontend + dockerfile: docker/frontend/Dockerfile container_name: ai-frontend environment: # Docker 内部网络使用服务名 'backend' 解析后端服务 - - API_URL=http://backend:8083/chat + - API_URL=http://backend:8079/chat ports: - "8501:8501" networks: @@ -56,8 +56,3 @@ services: networks: ai-network: driver: bridge - -# ⭐ PostgreSQL 和 Qdrant 已迁移到远程服务器,不再需要本地卷 -# volumes: -# pg_data: -# qdrant_storage: diff --git a/docker/frontend/Dockerfile b/docker/frontend/Dockerfile new file mode 100644 index 0000000..0b4adc4 --- /dev/null +++ b/docker/frontend/Dockerfile @@ -0,0 +1,21 @@ +FROM python:3.11-slim + +WORKDIR /app + +ENV PYTHONPATH=/app + +# 设置 pip 国内镜像源 +RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +# 复制前端依赖并安装 +COPY frontend/requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制前端代码 +COPY frontend/src/ ./frontend/ + +# 暴露端口 +EXPOSE 8501 + +# 启动命令 +CMD ["streamlit", "run", "frontend/frontend_main.py", "--server.port", "8501", "--server.address", "0.0.0.0", "--server.baseUrlPath", "/ai"] diff --git a/frontend/requirements.txt b/frontend/requirements.txt new file mode 100644 index 0000000..81b7f72 --- /dev/null +++ b/frontend/requirements.txt @@ -0,0 +1,4 @@ +# Frontend - Lightweight dependencies only +streamlit==1.39.0 +requests==2.32.3 +python-dotenv==1.2.2 diff --git a/frontend/__init__.py b/frontend/src/__init__.py similarity index 75% rename from frontend/__init__.py rename to frontend/src/__init__.py index 52f6743..29e32df 100644 --- a/frontend/__init__.py +++ b/frontend/src/__init__.py @@ -3,7 +3,7 @@ AI Agent 前端模块 采用分层架构设计,包含配置、状态、API客户端和UI组件 """ -from frontend.logger import debug, info, warning, error +from .logger import debug, info, warning, error __version__ = "2.0.0" __all__ = ["debug", "info", "warning", "error"] \ No newline at end of file diff --git a/frontend/api_client.py b/frontend/src/api_client.py similarity index 98% rename from frontend/api_client.py rename to frontend/src/api_client.py index 1a55c72..ace6808 100644 --- a/frontend/api_client.py +++ b/frontend/src/api_client.py @@ -7,9 +7,9 @@ import json from typing import List, Dict, Any, Generator import requests -# 使用绝对导入 -from frontend.config import config -from frontend.logger import error, warning +# 使用相对导入 +from .config import config +from .logger import error, warning class APIClient: diff --git a/frontend/components/__init__.py b/frontend/src/components/__init__.py similarity index 100% rename from frontend/components/__init__.py rename to frontend/src/components/__init__.py diff --git a/frontend/components/chat_area.py b/frontend/src/components/chat_area.py similarity index 98% rename from frontend/components/chat_area.py rename to frontend/src/components/chat_area.py index 4571c01..7ffe449 100644 --- a/frontend/components/chat_area.py +++ b/frontend/src/components/chat_area.py @@ -6,10 +6,10 @@ import re import streamlit as st -# 使用绝对导入 -from frontend.state import AppState -from frontend.api_client import api_client -from frontend.config import config +# 使用相对导入 +from ..state import AppState +from ..api_client import api_client +from ..config import config def render_chat_area(): @@ -325,7 +325,7 @@ def _handle_ai_response(): # 消息发送完毕后,静默刷新历史记录列表 # (因为可能生成了新对话,或者旧对话摘要已更新) - from frontend.components.sidebar import _refresh_threads + from .sidebar import _refresh_threads _refresh_threads() # 强制重绘页面,使侧边栏立即显示最新记录 diff --git a/frontend/components/info_panel.py b/frontend/src/components/info_panel.py similarity index 94% rename from frontend/components/info_panel.py rename to frontend/src/components/info_panel.py index 0061551..e163db8 100644 --- a/frontend/components/info_panel.py +++ b/frontend/src/components/info_panel.py @@ -5,8 +5,8 @@ import streamlit as st -# 使用绝对导入 -from frontend.state import AppState +# 使用相对导入 +from ..state import AppState def render_info_panel(): diff --git a/frontend/components/sidebar.py b/frontend/src/components/sidebar.py similarity index 97% rename from frontend/components/sidebar.py rename to frontend/src/components/sidebar.py index 174d62c..40eacbd 100644 --- a/frontend/components/sidebar.py +++ b/frontend/src/components/sidebar.py @@ -6,9 +6,9 @@ import streamlit as st from datetime import datetime -# 使用绝对导入 -from frontend.state import AppState -from frontend.api_client import api_client +# 使用相对导入 +from ..state import AppState +from ..api_client import api_client def render_sidebar(): """渲染左侧栏""" diff --git a/frontend/config.py b/frontend/src/config.py similarity index 100% rename from frontend/config.py rename to frontend/src/config.py diff --git a/frontend/frontend_main.py b/frontend/src/frontend_main.py similarity index 90% rename from frontend/frontend_main.py rename to frontend/src/frontend_main.py index 1a38f89..508b42c 100644 --- a/frontend/frontend_main.py +++ b/frontend/src/frontend_main.py @@ -7,16 +7,17 @@ import sys import os # 添加项目根目录到 Python 路径,支持绝对导入 -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# 现在的结构: frontend/src/frontend_main.py,所以要获取 frontend/ 目录作为根 +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import streamlit as st -# 使用绝对导入 -from frontend.config import config -from frontend.state import AppState -from frontend.components.sidebar import render_sidebar -from frontend.components.chat_area import render_chat_area -from frontend.components.info_panel import render_info_panel +# 使用相对导入 +from .config import config +from .state import AppState +from .components.sidebar import render_sidebar +from .components.chat_area import render_chat_area +from .components.info_panel import render_info_panel # ============================================================================= diff --git a/frontend/logger.py b/frontend/src/logger.py similarity index 100% rename from frontend/logger.py rename to frontend/src/logger.py diff --git a/frontend/state.py b/frontend/src/state.py similarity index 99% rename from frontend/state.py rename to frontend/src/state.py index 10e16b1..e1d32bb 100644 --- a/frontend/state.py +++ b/frontend/src/state.py @@ -7,7 +7,7 @@ import uuid from typing import List, Dict, Any import streamlit as st -from frontend.config import config +from .config import config class AppState: diff --git a/frontend/utils.py b/frontend/src/utils.py similarity index 100% rename from frontend/utils.py rename to frontend/src/utils.py diff --git a/rag_core/__init__.py b/rag_core/__init__.py deleted file mode 100644 index a19afb2..0000000 --- a/rag_core/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -RAG Core - 公共 RAG 组件包 - -提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。 -""" - -from rag_core.embedders import LlamaCppEmbedder -from rag_core.vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY -from rag_core.store import PostgresDocStore, create_docstore -from rag_core.retriever_factory import create_parent_retriever - - -__all__ = [ - "LlamaCppEmbedder", - "QdrantVectorStore", - "QDRANT_URL", - "QDRANT_API_KEY", - "PostgresDocStore", - "create_docstore", - "create_parent_retriever", -] diff --git a/rag_core/client.py b/rag_core/client.py deleted file mode 100644 index 7615ea7..0000000 --- a/rag_core/client.py +++ /dev/null @@ -1,28 +0,0 @@ -# rag_core/client.py -import os -from typing import Optional -from qdrant_client import QdrantClient - - -QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") -QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") - -def create_qdrant_client( - url: Optional[str] = None, - api_key: Optional[str] = None, - timeout: int = 300, # 索引构建需要较长超时 -) -> QdrantClient: - effective_url = url or QDRANT_URL - effective_api_key = api_key or QDRANT_API_KEY - - if not effective_url: - raise ValueError("Qdrant URL 未配置") - - client_kwargs = { - "url": effective_url, - "timeout": timeout, - } - if effective_api_key: - client_kwargs["api_key"] = effective_api_key - - return QdrantClient(**client_kwargs) \ No newline at end of file diff --git a/rag_core/embedders.py b/rag_core/embedders.py deleted file mode 100644 index 66ffa6e..0000000 --- a/rag_core/embedders.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -嵌入模型包装器,用于 llama.cpp 服务。 -""" - -import os -import httpx -from typing import List, Optional - -from langchain_core.embeddings import Embeddings - -class LlamaCppEmbedder: - """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" - - def __init__( - self, - base_url: Optional[str] = None, - api_key: Optional[str] = None, - model: str = "Qwen3-Embedding-0.6B-Q8_0", - ): - self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082") - self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "") - self.model = model - - def as_langchain_embeddings(self) -> Embeddings: - """创建 LangChain 兼容的嵌入实例。""" - return _LlamaCppLangchainAdapter(self) - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """嵌入一批文档。""" - return self._call_embedding_api(texts) - - def embed_query(self, text: str) -> List[float]: - """嵌入单个查询。""" - return self._call_embedding_api([text])[0] - - def get_embedding_dimension(self) -> int: - """通过嵌入测试字符串获取嵌入维度。""" - test_embedding = self.embed_query("test") - return len(test_embedding) - - def _call_embedding_api(self, texts: List[str]) -> List[List[float]]: - """直接调用 llama.cpp 嵌入 API。""" - base = self.base_url.rstrip("/") - if not base.endswith("/v1"): - base = base + "/v1" - - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - payload = { - "input": texts, - "model": self.model, - } - - with httpx.Client(timeout=120) as client: - response = client.post( - f"{base}/embeddings", - headers=headers, - json=payload, - ) - response.raise_for_status() - data = response.json() - - if isinstance(data, list): - return [item["embedding"] for item in data] - elif isinstance(data, dict) and "data" in data: - return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] - else: - raise ValueError(f"未知的嵌入 API 响应格式: {data}") - -class _LlamaCppLangchainAdapter(Embeddings): - """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" - - def __init__(self, embedder: LlamaCppEmbedder): - self._embedder = embedder - - def embed_documents(self, texts: List[str]) -> List[List[float]]: - return self._embedder.embed_documents(texts) - - def embed_query(self, text: str) -> List[float]: - return self._embedder.embed_query(text) \ No newline at end of file diff --git a/rag_core/retriever_factory.py b/rag_core/retriever_factory.py deleted file mode 100644 index 25dab1c..0000000 --- a/rag_core/retriever_factory.py +++ /dev/null @@ -1,59 +0,0 @@ -# rag_core/retriever_factory.py -from langchain_core.embeddings import Embeddings -from langchain_classic.retrievers import ParentDocumentRetriever -from langchain_text_splitters import RecursiveCharacterTextSplitter -from typing import Optional -from langchain_core.embeddings import Embeddings -from langchain_core.stores import BaseStore -from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter -from langchain_classic.retrievers import ParentDocumentRetriever - -from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore - -def create_parent_retriever( - collection_name: str = "rag_documents", - embeddings: Optional[Embeddings] = None, - parent_splitter: Optional[TextSplitter] = None, - child_splitter: Optional[TextSplitter] = None, - docstore: Optional[BaseStore] = None, - search_k: int = 5, - # 若未传入切分器,则用以下参数创建默认切分器 - parent_chunk_size: int = 1000, - parent_chunk_overlap: int = 100, - child_chunk_size: int = 200, - child_chunk_overlap: int = 20, -) -> ParentDocumentRetriever: - # 嵌入模型 - if embeddings is None: - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() - - # 向量存储(只读) - vector_store = QdrantVectorStore( - collection_name=collection_name, - embeddings=embeddings, - ) - - # 切分器(若未提供则创建默认) - if parent_splitter is None: - parent_splitter = RecursiveCharacterTextSplitter( - chunk_size=parent_chunk_size, - chunk_overlap=parent_chunk_overlap, - ) - if child_splitter is None: - child_splitter = RecursiveCharacterTextSplitter( - chunk_size=child_chunk_size, - chunk_overlap=child_chunk_overlap, - ) - - # 文档存储 - if docstore is None: - docstore, _ = create_docstore() # 从环境变量读取连接 - - return ParentDocumentRetriever( - vectorstore=vector_store.get_langchain_vectorstore(), - docstore=docstore, - child_splitter=child_splitter, - parent_splitter=parent_splitter, - search_kwargs={"k": search_k}, - ) \ No newline at end of file diff --git a/rag_core/store/__init__.py b/rag_core/store/__init__.py deleted file mode 100644 index b4aab75..0000000 --- a/rag_core/store/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。 - -提供 PostgreSQL 存储后端: -- PostgresDocStore: PostgreSQL 数据库存储(生产环境) - -示例用法: - >>> from rag_core.store import create_docstore - - >>> # 创建 PostgreSQL 存储 - >>> store, conn = create_docstore( - ... connection_string="postgresql://user:pass@host:5432/db", - ... table_name="parent_docs" - ... ) -""" - - -from rag_core.store.postgres import PostgresDocStore -from rag_core.store.factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI - -__version__ = "2.0.0" - -__all__ = [ - # 具体实现 - "PostgresDocStore", - - # 工厂函数 - "create_docstore", - "get_docstore_uri", - "DEFAULT_DB_URI", -] diff --git a/rag_core/store/factory.py b/rag_core/store/factory.py deleted file mode 100644 index 43b465e..0000000 --- a/rag_core/store/factory.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -文档存储工厂 - 创建不同类型的存储实例。 - -提供统一的接口来创建本地文件存储或 PostgreSQL 存储。 -""" - -import os -import logging -from typing import Optional, Tuple - -from langchain_core.stores import BaseStore -from rag_core.store.postgres import PostgresDocStore -from dotenv import load_dotenv -load_dotenv() - -logger = logging.getLogger(__name__) - -# 默认连接字符串(从环境变量读取) -DEFAULT_DB_URI = os.getenv( - "DB_URI", - "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable" -) - - -def get_docstore_uri() -> str: - """获取 docstore 专用的数据库连接字符串(可与主库相同)""" - return os.getenv("DOCSTORE_URI", DEFAULT_DB_URI) - - -def create_docstore( - store_type: str = "postgres", - connection_string: Optional[str] = None, - table_name: str = "parent_documents", - pool_config: Optional[dict] = None, - max_concurrency: Optional[int] = None -) -> Tuple[BaseStore, Optional[str]]: - """ - 工厂函数,创建 PostgreSQL 文档存储。 - - Args: - store_type: 存储类型,目前仅支持 "postgres"(默认) - connection_string: PostgreSQL 连接字符串 - table_name: PostgreSQL 表名(默认:parent_documents) - pool_config: 连接池配置 - max_concurrency: 最大并发操作数,如果为 None 则不限制 - - Returns: - 元组 (存储实例, 连接字符串) - - Raises: - ValueError: 不支持的存储类型 - ImportError: 缺少必要的依赖 - - Example: - >>> # 创建 PostgreSQL 存储 - >>> store, conn = create_docstore( - ... connection_string="postgresql://user:pass@host:5432/db", - ... table_name="parent_docs", - ... max_concurrency=10 - ... ) - """ - store_type = store_type.lower() - - if store_type == "postgres": - conn_str = connection_string or get_docstore_uri() - store = PostgresDocStore( - connection_string=conn_str, - table_name=table_name, - pool_config=pool_config, - max_concurrency=max_concurrency - ) - return store, conn_str - - else: - raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres") diff --git a/rag_core/store/postgres.py b/rag_core/store/postgres.py deleted file mode 100644 index 23b7153..0000000 --- a/rag_core/store/postgres.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -异步 PostgreSQL 存储实现 - 用于生产环境。 - -使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。 -""" - -import asyncio -import json -import logging -from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence - -from langchain_core.documents import Document -from langchain_core.stores import BaseStore - -import asyncpg - -logger = logging.getLogger(__name__) - -class PostgresDocStore(BaseStore[str, Any]): - """ - 异步 PostgreSQL 文档存储实现。 - - 使用 asyncpg 作为异步 PostgreSQL 客户端,支持: - - 真正的异步操作 - - 连接池管理 - - 自动表创建 - - 批量操作(amget/amset/amdelete) - - JSONB 数据存储 - - 并发控制 - - 适用于生产环境,提供高性能的异步数据持久化。 - - Attributes: - dsn: PostgreSQL 连接字符串 - table_name: 存储表名,默认为 "parent_documents" - _pool: asyncpg 连接池实例 - _semaphore: 控制并发数的信号量(可选) - """ - - def __init__( - self, - connection_string: str, - table_name: str = "parent_documents", - pool_config: Optional[Dict[str, Any]] = None, - max_concurrency: Optional[int] = None - ): - """ - 初始化异步 PostgreSQL 文档存储。 - - Args: - connection_string: PostgreSQL 连接 URL,格式: - "postgresql://user:password@host:port/database?sslmode=disable" - table_name: 存储表名,默认为 "parent_documents" - pool_config: 连接池配置字典,包含: - - min_size: 最小连接数(默认 2) - - max_size: 最大连接数(默认 10) - max_concurrency: 最大并发操作数,如果为 None 则不限制 - - Raises: - ImportError: 未安装 asyncpg 时抛出 - - Example: - >>> store = PostgresDocStore( - ... "postgresql://user:pass@localhost:5432/mydb", - ... table_name="parent_docs", - ... pool_config={"min_size": 5, "max_size": 20}, - ... max_concurrency=10 - ... ) - """ - - - self.dsn = connection_string - self.table_name = table_name - self._pool: Optional["asyncpg.Pool"] = None - self._pool_config = pool_config or {} - - # 并发控制信号量 - self._semaphore = None - if max_concurrency is not None and max_concurrency > 0: - self._semaphore = asyncio.Semaphore(max_concurrency) - - # 注意:连接池的异步初始化延迟到第一次使用时 - # 表结构创建也延迟到第一次操作时 - - async def _get_pool(self): - """获取或创建 asyncpg 连接池。""" - if self._pool is None: - import asyncpg - min_size = self._pool_config.get("min_size", 2) - max_size = self._pool_config.get("max_size", 10) - - try: - self._pool = await asyncpg.create_pool( - dsn=self.dsn, - min_size=min_size, - max_size=max_size - ) - logger.info(f"PostgreSQL 异步连接池已创建: {self.table_name}") - - # 初始化表结构 - await self._create_table() - except Exception as e: - raise RuntimeError(f"PostgreSQL 异步连接池创建失败: {e}") - - return self._pool - - async def _create_table(self): - """创建存储表(如果不存在)。""" - pool = await self._get_pool() - async with pool.acquire() as conn: - async with conn.transaction(): - await conn.execute(f""" - CREATE TABLE IF NOT EXISTS {self.table_name} ( - key TEXT PRIMARY KEY, - value JSONB NOT NULL, - created_at TIMESTAMPTZ DEFAULT NOW() - ) - """) - logger.info(f"表 {self.table_name} 已就绪") - - async def _with_concurrency_control(self, coro): - """使用信号量控制并发执行。""" - if self._semaphore is None: - return await coro - async with self._semaphore: - return await coro - - # --- 同步方法(保持兼容性,但功能有限)--- - - def mget(self, keys: Sequence[str]) -> List[Optional[Any]]: - """不支持同步操作,请使用异步 amget 方法。""" - raise NotImplementedError("不支持同步操作,请使用异步 amget 方法。") - - def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None: - """不支持同步操作,请使用异步 amset 方法。""" - raise NotImplementedError("不支持同步操作,请使用异步 amset 方法。") - - def mdelete(self, keys: Sequence[str]) -> None: - """不支持同步操作,请使用异步 amdelete 方法。""" - raise NotImplementedError("不支持同步操作,请使用异步 amdelete 方法。") - - def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]: - """不支持同步操作,请使用异步 ayield_keys 方法。""" - raise NotImplementedError("不支持同步操作,请使用异步 ayield_keys 方法。") - - # --- 异步方法(真正的实现)--- - - async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]: - """异步批量获取文档。""" - if not keys: - return [] - - async def _amget(): - pool = await self._get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)", - keys - ) - result_map = {} - for row in rows: - val = row['value'] - if isinstance(val, str): - val = json.loads(val) - if isinstance(val, dict) and 'page_content' in val: - result_map[row['key']] = Document(**val) - else: - result_map[row['key']] = val - return [result_map.get(key) for key in keys] - - return await self._with_concurrency_control(_amget()) - - async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None: - """异步批量设置文档。""" - if not key_value_pairs: - return - - async def _amset(): - pool = await self._get_pool() - async with pool.acquire() as conn: - async with conn.transaction(): - await conn.executemany( - f""" - INSERT INTO {self.table_name} (key, value) - VALUES ($1, $2) - ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value - """, - [ - (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False)) - for k, v in key_value_pairs - ] - ) - logger.debug(f"已异步批量设置 {len(key_value_pairs)} 个文档") - - await self._with_concurrency_control(_amset()) - - async def amdelete(self, keys: Sequence[str]) -> None: - """异步批量删除文档。""" - if not keys: - return - - async def _amdelete(): - pool = await self._get_pool() - async with pool.acquire() as conn: - async with conn.transaction(): - await conn.execute( - f"DELETE FROM {self.table_name} WHERE key = ANY($1)", - keys - ) - logger.debug(f"已异步批量删除 {len(keys)} 个文档") - - await self._with_concurrency_control(_amdelete()) - - async def ayield_keys(self, *, prefix: str | None = None) -> Iterator[str]: - """异步迭代所有键。 - - 注意:这是一个异步生成器,需要使用 async for 迭代。 - """ - pool = await self._get_pool() - async with pool.acquire() as conn: - if prefix: - rows = await conn.fetch( - f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key", - f"{prefix}%" - ) - else: - rows = await conn.fetch( - f"SELECT key FROM {self.table_name} ORDER BY key" - ) - - for row in rows: - yield row['key'] - - async def aclose(self) -> None: - """异步关闭连接池,释放资源。""" - if self._pool: - await self._pool.close() - self._pool = None - logger.info("PostgreSQL 异步连接池已关闭") - - def close(self) -> None: - """同步关闭连接池(功能有限)。 - - 注意:在异步环境中,请使用 aclose 方法。 - """ - pass diff --git a/rag_core/vector_store.py b/rag_core/vector_store.py deleted file mode 100644 index b92e113..0000000 --- a/rag_core/vector_store.py +++ /dev/null @@ -1,181 +0,0 @@ -""" -Qdrant 向量数据库包装器。 -""" - -import logging -import os -import time -from typing import List, Optional, Dict, Any - -from langchain_core.documents import Document -from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS -from qdrant_client import QdrantClient -from qdrant_client.http.models import Distance, VectorParams -from httpx import RemoteProtocolError -from qdrant_client.http.exceptions import ResponseHandlingException -from rag_core.client import create_qdrant_client - -logger = logging.getLogger(__name__) - -QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") -QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") - - -class QdrantVectorStore: - """Qdrant 向量数据库操作包装器。""" - - def __init__( - self, - collection_name: str, - embeddings: Optional[Any] = None, - ): - self.collection_name = collection_name - self._client: Optional[QdrantClient] = None - self._connection_attempts = 0 - self._last_connection_time: Optional[float] = None - - if embeddings is None: - from rag_core.embedders import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - self.embeddings = embedder.as_langchain_embeddings() - else: - self.embeddings = embeddings - - self.create_collection() - - self.vector_store = LangchainQdrantVS( - client=self.get_client(), - collection_name=self.collection_name, - embedding=self.embeddings, - ) - - def get_client(self) -> QdrantClient: - if self._client is None: - self._client = create_qdrant_client(timeout=300) - self._connection_attempts += 1 - self._last_connection_time = time.time() - logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts) - return self._client - - def refresh_client(self): - """关闭旧连接,创建新连接。""" - if self._client is not None: - try: - self._client.close() - logger.debug("Qdrant 旧连接已关闭") - except Exception as e: - logger.warning("关闭 Qdrant 连接时出现异常: %s", e) - finally: - self._client = None - self._last_connection_time = None - - def check_connection_health(self) -> bool: - """检查连接健康状态,如果连接已失效则自动重建。""" - if self._client is None: - logger.info("Qdrant 客户端未初始化,将创建新连接") - return False - - try: - client = self.get_client() - client.get_collections() - logger.debug("Qdrant 连接健康检查通过") - return True - except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: - logger.warning("Qdrant 连接健康检查失败: %s", e) - self.refresh_client() - return False - - def get_connection_stats(self) -> Dict[str, Any]: - """获取连接统计信息。""" - return { - "connection_attempts": self._connection_attempts, - "last_connection_time": self._last_connection_time, - "client_initialized": self._client is not None, - } - - def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): - """创建集合,设置合适的向量维度。""" - if vector_size is None: - from rag_core.embedders import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - vector_size = embedder.get_embedding_dimension() - - max_retries = 3 - base_delay = 2 - for attempt in range(max_retries): - try: - client = self.get_client() - collections = client.get_collections().collections - exists = any(c.name == self.collection_name for c in collections) - - if exists and force_recreate: - client.delete_collection(self.collection_name) - exists = False - - if not exists: - client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), - ) - logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size) - else: - logger.info("集合 '%s' 已存在", self.collection_name) - return - except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: - if attempt == max_retries - 1: - logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e) - raise - wait_time = base_delay * (2 ** attempt) - error_type = type(e).__name__ - logger.warning( - "创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s", - self.collection_name, error_type, wait_time, attempt + 1, max_retries, e - ) - self.refresh_client() - logger.debug("已刷新 Qdrant 客户端连接") - time.sleep(wait_time) - - def add_documents(self, documents: List[Document], batch_size: int = 100): - """将文档添加到向量数据库。""" - if not documents: - return [] - self.create_collection() - ids = self.vector_store.add_documents(documents, batch_size=batch_size) - logger.info("已向 '%s' 添加 %d 个文档", self.collection_name, len(ids)) - return ids - - def similarity_search(self, query: str, k: int = 5) -> List[Document]: - return self.vector_store.similarity_search(query, k=k) - - def similarity_search_with_score(self, query: str, k: int = 5) -> List[tuple[Document, float]]: - return self.vector_store.similarity_search_with_score(query, k=k) - - def delete_collection(self): - self.get_client().delete_collection(self.collection_name) - logger.info("集合 '%s' 已删除", self.collection_name) - - def get_collection_info(self) -> Dict[str, Any]: - info = self.get_client().get_collection(self.collection_name) - vectors_config = info.config.params.vectors - if isinstance(vectors_config, dict): - first_config = next(iter(vectors_config.values()), None) - vector_size = first_config.size if first_config else 0 - else: - vector_size = vectors_config.size if vectors_config else 0 - return { - "name": self.collection_name, - "vectors_count": info.points_count or 0, - "status": info.status, - "vector_size": vector_size, - } - - def as_langchain_vectorstore(self): - return self.vector_store - - def get_langchain_vectorstore(self): - """返回 LangChain Qdrant 向量存储对象(别名)""" - return self.vector_store - - def get_qdrant_client(self): - """返回原生 Qdrant 客户端(如需手动管理 collection)""" - return self.get_client() \ No newline at end of file diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py index 21ca58c..2a0117f 100644 --- a/rag_indexer/__init__.py +++ b/rag_indexer/__init__.py @@ -23,9 +23,9 @@ Offline RAG Indexer module. >>> await builder.build_from_file("document.pdf") """ -from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig -from rag_indexer.loaders import DocumentLoader -from rag_indexer.splitters import SplitterType, get_splitter +from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig +from .loaders import DocumentLoader +from .splitters import SplitterType, get_splitter # 从 rag_core 重新导出常用组件 from rag_core import ( diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py index 1ecc15d..e63d3e9 100755 --- a/rag_indexer/cli.py +++ b/rag_indexer/cli.py @@ -7,8 +7,12 @@ import logging import sys from pathlib import Path -from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig -from rag_indexer.splitters import SplitterType +# 添加项目根目录和 backend 目录到 Python 路径 +sys.path.insert(0, str(Path(__file__).parent.parent)) +sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) + +from .index_builder import IndexBuilder, IndexBuilderConfig +from .splitters import SplitterType logging.basicConfig( level=logging.INFO, diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index 582f8d5..1be4633 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -6,10 +6,14 @@ import asyncio import logging -from dataclasses import dataclass, field +import sys from pathlib import Path +from dataclasses import dataclass, field from typing import List, Union, Optional, Any, Dict +# 添加 backend 目录到路径以导入 rag_core +sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) + from httpx import RemoteProtocolError from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -17,8 +21,8 @@ from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from qdrant_client.http.exceptions import ResponseHandlingException -from rag_indexer.loaders import DocumentLoader -from rag_indexer.splitters import SplitterType, get_splitter +from .loaders import DocumentLoader +from .splitters import SplitterType, get_splitter from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever logger = logging.getLogger(__name__) diff --git a/rag_indexer/requirements.txt b/rag_indexer/requirements.txt new file mode 100644 index 0000000..b7e65d4 --- /dev/null +++ b/rag_indexer/requirements.txt @@ -0,0 +1,33 @@ +# RAG Indexer - 本地索引工具依赖 +# 依赖 rag_core (从 ../backend/rag_core 导入) + +# Core +pydantic==2.12.5 +python-dotenv==1.2.2 +typing-extensions==4.15.0 + +# LangChain (用于文档处理) +langchain==1.2.15 +langchain-community==0.4.1 +langchain-core==1.2.28 +tiktoken>=0.12.0 + +# Vector DB +qdrant-client==1.17.1 + +# HTTP +httpx==0.28.1 + +# Utilities +tenacity==9.1.4 +rich==15.0.0 +PyYAML==6.0.3 +numpy>=1.26.2 + +# Document Processing +unstructured==0.22.21 +pypdf==6.10.0 +beautifulsoup4==4.14.3 +lxml==6.1.0 +pandas==3.0.2 +spacy==3.8.14 diff --git a/rag_indexer/test/test_inspect_vectors.py b/rag_indexer/test/test_inspect_vectors.py index 5e296c0..3d671a1 100644 --- a/rag_indexer/test/test_inspect_vectors.py +++ b/rag_indexer/test/test_inspect_vectors.py @@ -5,12 +5,10 @@ import sys import numpy as np from dotenv import load_dotenv from qdrant_client import QdrantClient - +from backend.rag_core import LlamaCppEmbedder +load_dotenv() sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) -from rag_core import LlamaCppEmbedder - -load_dotenv() QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") diff --git a/rag_indexer/test/test_refactored.py b/rag_indexer/test/test_refactored.py index f52cc9a..4649bd8 100644 --- a/rag_indexer/test/test_refactored.py +++ b/rag_indexer/test/test_refactored.py @@ -10,8 +10,8 @@ import sys # 添加项目根目录到 Python 路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) -from rag_indexer.index_builder import IndexBuilder -from rag_indexer.splitters import SplitterType +from ..index_builder import IndexBuilder +from ..splitters import SplitterType async def test_index_builder(): """测试索引构建功能""" diff --git a/scripts/start.sh b/scripts/start.sh index 1d308d3..3092c46 100755 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -1,7 +1,7 @@ #!/bin/bash # ============================================================================= # AI Agent 启动与管理脚本 -# 用法: ./start.sh [check|backend|frontend|both|docker-up|docker-down] +# 用法: ./backend/scripts/start.sh both [check|backend|frontend|both|docker-up|docker-down] # ============================================================================= set -e @@ -308,7 +308,7 @@ start_frontend() { set +a export PYTHONPATH="$PROJECT_DIR" - streamlit run frontend/frontend_main.py & + streamlit run frontend/src/frontend_main.py & FRONTEND_PID=$! echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}" echo -e "${GREEN}✓ 访问地址:${NC}"