From 4e981e9dcf52e90cb4c19969a5eede1afe247b91 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 20 Apr 2026 14:05:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=87=E4=BB=B6=E5=8F=98=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.docker | 3 + app/agent.py | 351 ------------------------ app/agent/agent.py | 166 +++++++++++ app/{ => agent}/history.py | 0 app/agent/llm_factory.py | 56 ++++ app/{ => agent}/prompts.py | 37 ++- app/agent/rag_initializer.py | 23 ++ app/backend.py | 4 +- app/graph/graph_builder.py | 79 ++++++ app/{ => graph}/graph_tools.py | 0 app/{nodes => graph}/retrieve_memory.py | 2 +- app/{ => graph}/state.py | 0 app/graph_builder.py | 2 +- app/nodes/__init__.py | 2 +- app/nodes/finalize.py | 2 +- app/nodes/llm_call.py | 4 +- app/nodes/router.py | 2 +- app/nodes/summarize.py | 2 +- app/nodes/tool_call.py | 2 +- app/rag/__init__.py | 4 +- app/rag/pipeline.py | 10 +- app/rag/reranker.py | 17 +- app/rag/tools.py | 62 ----- frontend/config.py | 5 +- rag_core/client.py | 13 +- rag_core/vector_store.py | 95 +++++-- rag_indexer/index_builder.py | 19 +- scripts/start.sh | 2 +- 28 files changed, 474 insertions(+), 490 deletions(-) delete mode 100644 app/agent.py create mode 100644 app/agent/agent.py rename app/{ => agent}/history.py (100%) create mode 100644 app/agent/llm_factory.py rename app/{ => agent}/prompts.py (62%) create mode 100644 app/agent/rag_initializer.py create mode 100644 app/graph/graph_builder.py rename app/{ => graph}/graph_tools.py (100%) rename app/{nodes => graph}/retrieve_memory.py (97%) rename app/{ => graph}/state.py (100%) diff --git a/.env.docker b/.env.docker index 49f53dc..d08f87b 100644 --- a/.env.docker +++ b/.env.docker @@ -37,6 +37,9 @@ VLLM_BASE_URL=http://host.docker.internal:18000/v1 # Embedding 服务 (embeddinggemma-300M GGUF) - 端口 8082 LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1 +# Reranker 服务 (bge-reranker-v2-m3) - 端口 8083 +LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1 + # ----------------------------------------------------------------------------- # Mem0 记忆层配置 # ----------------------------------------------------------------------------- diff --git a/app/agent.py b/app/agent.py deleted file mode 100644 index 29b91b5..0000000 --- a/app/agent.py +++ /dev/null @@ -1,351 +0,0 @@ -""" -AI Agent 服务类 - 支持多模型动态切换 -接收外部传入的 checkpointer,不负责管理连接生命周期 -""" - -import os -import json -from dotenv import load_dotenv -try: - from langchain_community.chat_models import ChatZhipuAI - HAS_ZHIPUAI = True -except ImportError: - HAS_ZHIPUAI = False - ChatZhipuAI = None - -try: - from langchain_openai import ChatOpenAI, OpenAIEmbeddings - HAS_OPENAI = True -except ImportError: - HAS_OPENAI = False - ChatOpenAI = None - OpenAIEmbeddings = None - -from pydantic import SecretStr -try: - from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - HAS_POSTGRES_CHECKPOINT = True -except ImportError: - HAS_POSTGRES_CHECKPOINT = False - AsyncPostgresSaver = None - -# 本地模块 -from app.graph_builder import GraphBuilder, GraphContext -from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME -try: - from app.rag import RAGPipeline - from app.rag.tools import RAGTool - HAS_RAG = True -except ImportError as e: - HAS_RAG = False - RAGPipeline = None - RAGTool = None -from app.logger import debug, info, warning, error - - -load_dotenv() - - -class AIAgentService: - """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer""" - - def __init__(self, checkpointer: AsyncPostgresSaver): - """ - 初始化服务 - Args: - checkpointer: 已经初始化的 AsyncPostgresSaver 实例 - """ - self.checkpointer = checkpointer - self.graphs = {} # 存储不同模型对应的 graph 实例 - self.rag = None # RAG 检索实例 - self.rag_tool = None # RAG 工具实例 - - def _create_zhipu_llm(self): - """创建智谱在线 LLM""" - if not HAS_ZHIPUAI: - raise ImportError("智谱AI支持不可用,请安装langchain-community包") - api_key = os.getenv("ZHIPUAI_API_KEY") - if not api_key: - raise ValueError("ZHIPUAI_API_KEY not set in environment") - return ChatZhipuAI( - model="glm-4.7-flash", - api_key=api_key, - temperature=0.1, - max_tokens=4096, - timeout=120.0, # 增加请求超时时间(秒),原为60秒 - max_retries=3, # 增加重试次数,原为2次 - streaming=True, # 确保开启流式输出 - ) - - def _create_deepseek_llm(self): - """创建 DeepSeek LLM(使用 OpenAI 兼容 API)""" - if not HAS_OPENAI: - raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") - api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - raise ValueError("DEEPSEEK_API_KEY not set in environment") - return ChatOpenAI( - base_url="https://api.deepseek.com", - api_key=SecretStr(api_key), - model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式 - temperature=0.1, - max_tokens=4096, - timeout=60.0, # 请求超时时间(秒) - max_retries=2, # 失败后自动重试次数 - streaming=True, # 确保开启流式输出 - ) - - def _create_local_llm(self): - """创建本地 vLLM 服务 LLM""" - if not HAS_OPENAI: - raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") - # vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发 - vllm_base_url = os.getenv( - "VLLM_BASE_URL", - "http://127.0.0.1:8081/v1" - ) - - return ChatOpenAI( - base_url=vllm_base_url, - api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), - model="gemma-4-E2B-it", - timeout=60.0, # 请求超时时间(秒) - max_retries=2, # 失败后自动重试次数 - streaming=True, # 确保开启流式输出 - ) - - def _create_embeddings(self): - """创建嵌入模型""" - if not HAS_OPENAI: - raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") - embedding_url = os.getenv( - "LLAMACPP_EMBEDDING_URL", - "http://127.0.0.1:8082/v1" - ) - return OpenAIEmbeddings( - openai_api_base=embedding_url, - openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"), - model="text-embedding-ada-002", # 模型名称不重要,兼容即可 - ) - - async def initialize(self): - """预编译所有模型的 graph(使用传入的 checkpointer)""" - # 先初始化 RAG 检索系统 - if HAS_RAG and RAGPipeline is not None and RAGTool is not None: - try: - info("🔄 正在初始化 RAG 检索系统...") - embeddings = self._create_embeddings() - self.rag = RAGPipeline(embeddings=embeddings) - self.rag_tool = RAGTool(self.rag).get_tool() - info("✅ RAG 检索系统初始化成功") - except Exception as e: - warning(f"⚠️ RAG 检索系统初始化失败: {e}") - self.rag = None - self.rag_tool = None - else: - info("⏭️ RAG 检索系统不可用,跳过初始化") - self.rag = None - self.rag_tool = None - - model_configs = { - "local": self._create_local_llm, # 本地模型作为第一个 - "deepseek": self._create_deepseek_llm, # DeepSeek 作为中间 - "zhipu": self._create_zhipu_llm, # GLM-4.7 作为最后一个 - } - - for model_name, llm_creator in model_configs.items(): - try: - info(f"🔄 正在初始化模型 '{model_name}'...") - llm = llm_creator() - - # 构建工具列表:基础工具 + RAG工具(如果可用) - tools = AVAILABLE_TOOLS.copy() - tools_by_name = TOOLS_BY_NAME.copy() - - if self.rag_tool is not None: - tools.append(self.rag_tool) - tools_by_name[self.rag_tool.name] = self.rag_tool - - builder = GraphBuilder(llm, tools, tools_by_name).build() - graph = builder.compile(checkpointer=self.checkpointer) - self.graphs[model_name] = graph - info(f"✅ 模型 '{model_name}' 初始化成功") - except Exception as e: - import traceback - error_detail = traceback.format_exc() - warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}") - debug(f" 详细错误:\n{error_detail}") - - if not self.graphs: - raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n" - "1. ZHIPUAI_API_KEY 未配置或无效\n" - "2. DEEPSEEK_API_KEY 未配置或无效\n" - "3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n" - "4. 网络连接问题") - - return self - - async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict: - """ - 处理用户消息,返回包含回复、token统计和耗时的字典 - - Returns: - dict: { - "reply": str, # AI 回复内容 - "token_usage": dict, # Token 使用详情 - "elapsed_time": float # 调用耗时(秒) - } - """ - # 尝试使用指定模型,如果不可用则循环尝试其他模型 - if model not in self.graphs: - warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型") - found = False - for available_model in self.graphs.keys(): - try: - # 这里可以添加额外的模型可用性检查逻辑 - model = available_model - found = True - info(f"已切换到可用模型: '{model}'") - break - except Exception as e: - warning(f"模型 '{available_model}' 也不可用: {str(e)}") - continue - - if not found: - raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}") - - graph = self.graphs[model] - config = { - "configurable": {"thread_id": thread_id}, - "metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用 - } - input_state = {"messages": [{"role": "user", "content": message}]} - context = GraphContext(user_id=user_id) - - result = await graph.ainvoke(input_state, config=config, context=context) - - reply = result["messages"][-1].content - token_usage = result.get("last_token_usage", {}) - elapsed_time = result.get("last_elapsed_time", 0.0) - - return { - "reply": reply, - "token_usage": token_usage, - "elapsed_time": elapsed_time - } - - def _serialize_value(self, value): - """递归将 LangChain 对象转换为可 JSON 序列化的格式""" - if hasattr(value, 'content'): - # LangChain 消息对象 - msg_type = getattr(value, 'type', 'message') - return { - "role": msg_type, - "content": getattr(value, 'content', ''), - "additional_kwargs": getattr(value, 'additional_kwargs', {}), - "tool_calls": getattr(value, 'tool_calls', []) - } - elif isinstance(value, dict): - return {k: self._serialize_value(v) for k, v in value.items()} - elif isinstance(value, (list, tuple)): - return [self._serialize_value(item) for item in value] - else: - try: - json.dumps(value) - return value - except (TypeError, ValueError): - return str(value) - - async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"): - """ - 流式处理消息,返回异步生成器 - - Args: - message: 用户消息 - thread_id: 线程 ID - model_name: 模型名称 - user_id: 用户 ID - - Yields: - 字典,包含事件类型和数据 - """ - graph = self.graphs.get(model_name) - - if not graph: - raise ValueError(f"模型 '{model_name}' 未找到或未初始化") - - config = { - "configurable": {"thread_id": thread_id}, - "metadata": {"user_id": user_id} - } - input_state = {"messages": [{"role": "user", "content": message}]} - context = GraphContext(user_id=user_id) - - async for chunk in graph.astream( - input_state, - config=config, - context=context, - stream_mode=["messages", "updates", "custom"], # 组合多种模式,添加 custom - version="v2", # 使用统一的v2格式 - subgraphs=True # 如果你使用了子图,请开启此项 - ): - chunk_type = chunk["type"] - processed_event = {} - - # 1. 处理 LLM Token 流 (实现打字机效果) - if chunk_type == "messages": - message_chunk, metadata = chunk["data"] - - # 提取元数据 - node_name = metadata.get("langgraph_node", "unknown") - # 使用 getattr 安全地获取内容,因为 message_chunk 可能不是字符串 - token_content = getattr(message_chunk, 'content', str(message_chunk)) - - # 提取 DeepSeek reasoner 的思考过程 token - reasoning_token = "" - if hasattr(message_chunk, 'additional_kwargs'): - reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") - - # [DEBUG] 临时添加:只在 reasoning_token 不为空时打印,方便你直观地看到它 - if reasoning_token: - import logging - logging.debug(f"💡 [Reasoning Token 捕获]: {repr(reasoning_token)}") - - processed_event = { - "type": "llm_token", - "node": node_name, - "token": token_content, - "reasoning_token": reasoning_token, - "metadata": metadata # 可选的元数据 - } - - # 2. 处理状态更新 (节点执行完成) - elif chunk_type == "updates": - updates_data = chunk["data"] - # 序列化 updates 中的所有数据 - serialized_data = self._serialize_value(updates_data) - processed_event = { - "type": "state_update", - "data": serialized_data - } - # 为了兼容前端旧字段,也保留 messages 字段(可选) - if "messages" in serialized_data: - processed_event["messages"] = serialized_data["messages"] - - # 3. 处理自定义数据 (如果需要) - elif chunk_type == "custom": - # 自定义事件同样需要序列化 - serialized_data = self._serialize_value(chunk["data"]) - processed_event = { - "type": "custom", - "data": serialized_data - } - - # 4. 其他类型(debug, tasks等)按需处理 - else: - # 对于不需要的类型,直接跳过 - continue - - # 确保事件有数据再发送 - if processed_event: - yield processed_event \ No newline at end of file diff --git a/app/agent/agent.py b/app/agent/agent.py new file mode 100644 index 0000000..cefb43f --- /dev/null +++ b/app/agent/agent.py @@ -0,0 +1,166 @@ +""" +AI Agent 服务类 - 支持多模型动态切换 +接收外部传入的 checkpointer,不负责管理连接生命周期 +""" + +import os +import json +from dotenv import load_dotenv + +from langchain_community.chat_models import ChatZhipuAI +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from pydantic import SecretStr +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver + +# 本地模块 +from app.graph_builder import GraphBuilder, GraphContext +from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from app.rag import RAGPipeline +from app.rag.tools import create_rag_tool_sync +from rag_core import create_parent_retriever +from app.llm_factory import LLMFactory +from app.rag_initializer import init_rag_tool +from app.logger import debug, info, warning, error +load_dotenv() + + +class AIAgentService: + def __init__(self, checkpointer): + self.checkpointer = checkpointer + self.graphs = {} + self.tools = AVAILABLE_TOOLS.copy() + self.tools_by_name = TOOLS_BY_NAME.copy() + + async def initialize(self): + # 1. 初始化 RAG 工具(如果需要) + rag_tool = await init_rag_tool(LLMFactory.create_local) + if rag_tool: + self.tools.append(rag_tool) + self.tools_by_name[rag_tool.name] = rag_tool + + # 2. 构建各模型的 Graph + for name, creator in LLMFactory.CREATORS.items(): + try: + info(f"🔄 初始化模型 '{name}'...") + llm = creator() + builder = GraphBuilder(llm, self.tools, self.tools_by_name).build() + graph = builder.compile(checkpointer=self.checkpointer) + self.graphs[name] = graph + info(f"✅ 模型 '{name}' 初始化成功") + except Exception as e: + warning(f"⚠️ 模型 '{name}' 初始化失败: {e}") + + if not self.graphs: + raise RuntimeError("没有可用的模型") + return self + + async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict: + """处理用户消息,返回包含回复、token统计和耗时的字典""" + if model not in self.graphs: + # 回退到第一个可用模型 + available = list(self.graphs.keys()) + if not available: + raise RuntimeError("没有可用的模型") + model = available[0] + warning(f"模型 '{model}' 不可用,已回退到 '{model}'") + + graph = self.graphs[model] + config = { + "configurable": {"thread_id": thread_id}, + "metadata": {"user_id": user_id} + } + input_state = {"messages": [{"role": "user", "content": message}]} + context = GraphContext(user_id=user_id) + + result = await graph.ainvoke(input_state, config=config, context=context) + + reply = result["messages"][-1].content + token_usage = result.get("last_token_usage", {}) + elapsed_time = result.get("last_elapsed_time", 0.0) + + return { + "reply": reply, + "token_usage": token_usage, + "elapsed_time": elapsed_time + } + + def _serialize_value(self, value): + """递归将 LangChain 对象转换为可 JSON 序列化的格式""" + if hasattr(value, 'content'): + msg_type = getattr(value, 'type', 'message') + return { + "role": msg_type, + "content": getattr(value, 'content', ''), + "additional_kwargs": getattr(value, 'additional_kwargs', {}), + "tool_calls": getattr(value, 'tool_calls', []) + } + elif isinstance(value, dict): + return {k: self._serialize_value(v) for k, v in value.items()} + elif isinstance(value, (list, tuple)): + return [self._serialize_value(item) for item in value] + else: + try: + json.dumps(value) + return value + except (TypeError, ValueError): + return str(value) + + async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"): + """流式处理消息,返回异步生成器""" + graph = self.graphs.get(model_name) + if not graph: + raise ValueError(f"模型 '{model_name}' 未找到或未初始化") + + config = { + "configurable": {"thread_id": thread_id}, + "metadata": {"user_id": user_id} + } + input_state = {"messages": [{"role": "user", "content": message}]} + context = GraphContext(user_id=user_id) + + async for chunk in graph.astream( + input_state, + config=config, + context=context, + stream_mode=["messages", "updates", "custom"], + version="v2", + subgraphs=True + ): + chunk_type = chunk["type"] + processed_event = {} + + if chunk_type == "messages": + message_chunk, metadata = chunk["data"] + node_name = metadata.get("langgraph_node", "unknown") + token_content = getattr(message_chunk, 'content', str(message_chunk)) + reasoning_token = "" + if hasattr(message_chunk, 'additional_kwargs'): + reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") + + processed_event = { + "type": "llm_token", + "node": node_name, + "token": token_content, + "reasoning_token": reasoning_token, + "metadata": metadata + } + elif chunk_type == "updates": + updates_data = chunk["data"] + serialized_data = self._serialize_value(updates_data) + processed_event = { + "type": "state_update", + "data": serialized_data + } + if "messages" in serialized_data: + processed_event["messages"] = serialized_data["messages"] + elif chunk_type == "custom": + serialized_data = self._serialize_value(chunk["data"]) + processed_event = { + "type": "custom", + "data": serialized_data + } + else: + continue + + if processed_event: + yield processed_event \ No newline at end of file diff --git a/app/history.py b/app/agent/history.py similarity index 100% rename from app/history.py rename to app/agent/history.py diff --git a/app/agent/llm_factory.py b/app/agent/llm_factory.py new file mode 100644 index 0000000..9a1a22a --- /dev/null +++ b/app/agent/llm_factory.py @@ -0,0 +1,56 @@ +# app/llm_factory.py +import os +from langchain_community.chat_models import ChatZhipuAI +from langchain_openai import ChatOpenAI +from pydantic import SecretStr + +class LLMFactory: + @staticmethod + def create_zhipu(): + api_key = os.getenv("ZHIPUAI_API_KEY") + if not api_key: + raise ValueError("ZHIPUAI_API_KEY not set") + return ChatZhipuAI( + model="glm-4.7-flash", + api_key=api_key, + temperature=0.1, + max_tokens=4096, + timeout=120.0, + max_retries=3, + streaming=True, + ) + + @staticmethod + def create_deepseek(): + api_key = os.getenv("DEEPSEEK_API_KEY") + if not api_key: + raise ValueError("DEEPSEEK_API_KEY not set") + return ChatOpenAI( + base_url="https://api.deepseek.com", + api_key=SecretStr(api_key), + model="deepseek-reasoner", + temperature=0.1, + max_tokens=4096, + timeout=60.0, + max_retries=2, + streaming=True, + ) + + @staticmethod + def create_local(): + base_url = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1") + return ChatOpenAI( + base_url=base_url, + api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), + model="gemma-4-E4B-it", + timeout=60.0, + max_retries=2, + streaming=True, + ) + + # 模型创建器映射 + CREATORS = { + "local": create_local, + "deepseek": create_deepseek, + "zhipu": create_zhipu, + } \ No newline at end of file diff --git a/app/prompts.py b/app/agent/prompts.py similarity index 62% rename from app/prompts.py rename to app/agent/prompts.py index f0f2c74..990e634 100644 --- a/app/prompts.py +++ b/app/agent/prompts.py @@ -1,18 +1,21 @@ -""" -提示模板管理模块 -所有系统提示和对话模板统一定义 -""" - +# app/prompts.py from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.tools import BaseTool +def create_system_prompt(tools: list = None) -> ChatPromptTemplate: + """ + 创建系统提示模板,可选择动态注入工具描述。 + """ + tools_section = "" + if tools: + tool_descs = [] + for tool in tools: + # 提取工具名称和描述的第一行 + name = getattr(tool, 'name', tool.__name__) + desc = (tool.description or "").split('\n')[0] + tool_descs.append(f"- {name}: {desc}") + tools_section = "\n".join(tool_descs) -def create_system_prompt() -> ChatPromptTemplate: - """ - 创建系统提示模板 - - Returns: - ChatPromptTemplate: 包含系统指令和消息占位符的提示模板 - """ system_template = ( "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n" "【用户背景信息】\n" @@ -20,15 +23,11 @@ def create_system_prompt() -> ChatPromptTemplate: "{memory_context}\n" "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n" "【可用工具与使用规则】\n" - "- 获取温度/天气:`get_current_temperature`\n" - "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n" - "- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n" - "- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n" - "- 抓取网页内容:`fetch_webpage_content`\n" + f"{tools_section}\n" "工具调用时请直接返回所需参数,无需额外说明。\n\n" "【回答要求(必须遵守)】\n" "1. 回答必须简洁、直接。\n" - "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。例如:这里是我的思考过程...这里是最终回答。\n" + "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。\n" "3. 优先利用已知用户信息进行个性化回复。\n" "4. 若无信息可依,礼貌询问或提供通用帮助。" ) @@ -36,4 +35,4 @@ def create_system_prompt() -> ChatPromptTemplate: return ChatPromptTemplate.from_messages([ ("system", system_template), MessagesPlaceholder(variable_name="messages") - ]) + ]) \ No newline at end of file diff --git a/app/agent/rag_initializer.py b/app/agent/rag_initializer.py new file mode 100644 index 0000000..f391b8f --- /dev/null +++ b/app/agent/rag_initializer.py @@ -0,0 +1,23 @@ +# app/rag_initializer.py +from app.rag.tools import create_rag_tool_sync +from rag_core import create_parent_retriever +from app.logger import info, warning + +async def init_rag_tool(local_llm_creator): + """初始化 RAG 工具,失败返回 None""" + try: + info("🔄 正在初始化 RAG 检索系统...") + retriever = create_parent_retriever( + collection_name="rag_documents", + search_k=5, + ) + rewrite_llm = local_llm_creator() + rag_tool = create_rag_tool_sync( + retriever, rewrite_llm, + num_queries=3, rerank_top_n=5 + ) + info("✅ RAG 检索工具初始化成功") + return rag_tool + except Exception as e: + warning(f"⚠️ RAG 检索工具初始化失败: {e}") + return None \ No newline at end of file diff --git a/app/backend.py b/app/backend.py index f531a36..20da307 100644 --- a/app/backend.py +++ b/app/backend.py @@ -231,6 +231,6 @@ async def websocket_endpoint( if __name__ == "__main__": import uvicorn - # 使用环境变量或默认端口 8083(避免与 llama.cpp 的 8081 端口冲突) - port = int(os.getenv("BACKEND_PORT", "8083")) + # 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突) + port = int(os.getenv("BACKEND_PORT", "8079")) uvicorn.run(app, host="0.0.0.0", port=port) diff --git a/app/graph/graph_builder.py b/app/graph/graph_builder.py new file mode 100644 index 0000000..11f9c9d --- /dev/null +++ b/app/graph/graph_builder.py @@ -0,0 +1,79 @@ +""" +LangGraph 状态图构建模块 - 精简版,仅负责组装图 +所有节点逻辑已拆分到独立模块 +""" + +from langchain_core.language_models import BaseLLM +from langgraph.graph import StateGraph, START, END + +# 本地模块 +from app.graph.state import MessagesState, GraphContext +from app.nodes import ( + create_llm_call_node, + create_tool_call_node, + create_retrieve_memory_node, + create_summarize_node, + should_continue +) +from app.memory import Mem0Client +from app.nodes.finalize import finalize_node + + +class GraphBuilder: + """LangGraph 状态图构建器 - 仅负责组装图""" + + def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict): + """ + 初始化构建器 + + Args: + llm: 大语言模型实例 + tools: 工具列表 + tools_by_name: 名称到工具函数的映射 + """ + self.llm = llm + self.tools = tools + self.tools_by_name = tools_by_name + + # ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化) + self.mem0_client = Mem0Client(llm) + + def build(self) -> StateGraph: + """ + 构建未编译的状态图 + + Returns: + StateGraph 实例 + """ + builder = StateGraph(MessagesState, context_schema=GraphContext) + + # ⭐ 通过工厂函数创建节点(依赖注入) + retrieve_memory_node = create_retrieve_memory_node(self.mem0_client) + llm_call_node = create_llm_call_node(self.llm, self.tools) + tool_call_node = create_tool_call_node(self.tools_by_name) + summarize_node = create_summarize_node(self.mem0_client) + + # 添加节点 + builder.add_node("retrieve_memory", retrieve_memory_node) + builder.add_node("llm_call", llm_call_node) + builder.add_node("tool_node", tool_call_node) + builder.add_node("summarize", summarize_node) + builder.add_node("finalize", finalize_node) + + # 添加边 + builder.add_edge(START, "retrieve_memory") + builder.add_edge("retrieve_memory", "llm_call") + builder.add_conditional_edges( + "llm_call", + should_continue, + { + "tool_node": "tool_node", + "summarize": "summarize", + "finalize": "finalize" + } + ) + builder.add_edge("tool_node", "llm_call") + builder.add_edge("summarize", "finalize") + builder.add_edge("finalize", END) + + return builder \ No newline at end of file diff --git a/app/graph_tools.py b/app/graph/graph_tools.py similarity index 100% rename from app/graph_tools.py rename to app/graph/graph_tools.py diff --git a/app/nodes/retrieve_memory.py b/app/graph/retrieve_memory.py similarity index 97% rename from app/nodes/retrieve_memory.py rename to app/graph/retrieve_memory.py index 0313ca8..61434d0 100644 --- a/app/nodes/retrieve_memory.py +++ b/app/graph/retrieve_memory.py @@ -7,7 +7,7 @@ from typing import Any, Dict from langgraph.runtime import Runtime # 本地模块 -from app.state import MessagesState, GraphContext +from app.graph.state import MessagesState, GraphContext from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug diff --git a/app/state.py b/app/graph/state.py similarity index 100% rename from app/state.py rename to app/graph/state.py diff --git a/app/graph_builder.py b/app/graph_builder.py index 2af230a..11f9c9d 100644 --- a/app/graph_builder.py +++ b/app/graph_builder.py @@ -7,7 +7,7 @@ from langchain_core.language_models import BaseLLM from langgraph.graph import StateGraph, START, END # 本地模块 -from app.state import MessagesState, GraphContext +from app.graph.state import MessagesState, GraphContext from app.nodes import ( create_llm_call_node, create_tool_call_node, diff --git a/app/nodes/__init__.py b/app/nodes/__init__.py index 8d279db..d9eb644 100644 --- a/app/nodes/__init__.py +++ b/app/nodes/__init__.py @@ -5,7 +5,7 @@ from app.nodes.router import should_continue from app.nodes.llm_call import create_llm_call_node from app.nodes.tool_call import create_tool_call_node -from app.nodes.retrieve_memory import create_retrieve_memory_node +from app.graph.retrieve_memory import create_retrieve_memory_node from app.nodes.summarize import create_summarize_node from app.nodes.finalize import finalize_node diff --git a/app/nodes/finalize.py b/app/nodes/finalize.py index a283587..ea6e1dc 100644 --- a/app/nodes/finalize.py +++ b/app/nodes/finalize.py @@ -8,7 +8,7 @@ from langgraph.runtime import Runtime from langgraph.config import get_stream_writer # 本地模块 -from app.state import MessagesState, GraphContext +from app.graph.state import MessagesState, GraphContext from app.utils.logging import log_state_change from app.logger import info, error diff --git a/app/nodes/llm_call.py b/app/nodes/llm_call.py index 79361cd..22abfed 100644 --- a/app/nodes/llm_call.py +++ b/app/nodes/llm_call.py @@ -12,7 +12,7 @@ from langchain_core.runnables import RunnableLambda from langgraph.runtime import Runtime # 本地模块 -from app.state import MessagesState, GraphContext +from app.graph.state import MessagesState, GraphContext from app.prompts import create_system_prompt from app.utils.logging import log_state_change, print_llm_input from app.logger import debug, info, error @@ -30,7 +30,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list): 异步节点函数 """ # 构建调用链 - prompt = create_system_prompt() + prompt = create_system_prompt(tools) llm_with_tools = llm.bind_tools(tools) # 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历 diff --git a/app/nodes/router.py b/app/nodes/router.py index 81d2e7e..cabc275 100644 --- a/app/nodes/router.py +++ b/app/nodes/router.py @@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage # 本地模块 from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL -from app.state import MessagesState +from app.graph.state import MessagesState from app.logger import info diff --git a/app/nodes/summarize.py b/app/nodes/summarize.py index a49742a..8b39baa 100644 --- a/app/nodes/summarize.py +++ b/app/nodes/summarize.py @@ -7,7 +7,7 @@ from typing import Any, Dict from langgraph.runtime import Runtime # 本地模块 -from app.state import MessagesState, GraphContext +from app.graph.state import MessagesState, GraphContext from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug, info, error, warning diff --git a/app/nodes/tool_call.py b/app/nodes/tool_call.py index 12648ff..1c9d55f 100644 --- a/app/nodes/tool_call.py +++ b/app/nodes/tool_call.py @@ -10,7 +10,7 @@ from langgraph.runtime import Runtime from langgraph.config import get_stream_writer # 本地模块 -from app.state import MessagesState, GraphContext +from app.graph.state import MessagesState, GraphContext from app.utils.logging import log_state_change from app.logger import debug, info diff --git a/app/rag/__init__.py b/app/rag/__init__.py index 8b4868f..1438604 100644 --- a/app/rag/__init__.py +++ b/app/rag/__init__.py @@ -39,7 +39,7 @@ from .retriever import ( create_hybrid_retriever, create_qdrant_client, ) -from .reranker import CrossEncoderReranker +from .reranker import LLaMaCPPReranker from .query_transform import MultiQueryGenerator from .fusion import reciprocal_rank_fusion from .pipeline import RAGPipeline @@ -53,7 +53,7 @@ __all__ = [ "create_qdrant_client", # 重排序器 - "CrossEncoderReranker", + "LLaMaCPPReranker", # 查询改写生成器 "MultiQueryGenerator", diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py index c0b4e6f..c8d95d6 100644 --- a/app/rag/pipeline.py +++ b/app/rag/pipeline.py @@ -1,6 +1,7 @@ # rag/pipeline.py import asyncio +import os from typing import List, Optional from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel @@ -23,7 +24,6 @@ class RAGPipeline: llm: BaseLanguageModel, num_queries: int = 3, rerank_top_n: int = 5, - rerank_model: str = "BAAI/bge-reranker-base", ): """ Args: @@ -41,9 +41,9 @@ class RAGPipeline: # 初始化组件 self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) self.reranker = LLaMaCPPReranker( - base_url="http://127.0.0.1:8083", + base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"), + api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"), top_n=rerank_top_n, - api_key="huang1998" ) async def aretrieve(self, query: str) -> List[Document]: @@ -68,9 +68,9 @@ class RAGPipeline: fused_docs = reciprocal_rank_fusion(doc_lists) # Step 4: 重排序 - if self.reranker.model is not None: + try: final_docs = self.reranker.compress_documents(fused_docs, query) - else: + except Exception: # 若重排序器不可用,直接返回融合后的前 N 条 final_docs = fused_docs[:self.rerank_top_n] diff --git a/app/rag/reranker.py b/app/rag/reranker.py index 7a53806..b6f7d4a 100644 --- a/app/rag/reranker.py +++ b/app/rag/reranker.py @@ -2,32 +2,33 @@ 重排序器模块 (适配版) 使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder """ +import os import requests -from typing import List +from typing import List, Optional from langchain_core.documents import Document class LLaMaCPPReranker: """使用远程 llama.cpp 服务对检索结果重排序。""" def __init__(self, - base_url: str = "http://127.0.0.1:8083", + base_url: str, + api_key: str, top_n: int = 5, - api_key: str = "huang1998", # 你设置的 LLAMA_ARG_API_KEY timeout: int = 60): """ 初始化远程重排序器 Args: - base_url: llama.cpp 服务的地址和端口。 + base_url: llama.cpp 服务的地址和端口,默认为环境变量 LLAMACPP_RERANKER_URL 或 "http://127.0.0.1:8083"。 top_n: 返回前 N 个结果。 - api_key: 在容器中设置的 API 密钥。 + api_key: API 密钥,默认为环境变量 LLAMACPP_API_KEY 或 "huang1998"。 timeout: 请求超时时间(秒)。 """ - self.base_url = base_url.rstrip('/') + self.base_url = base_url + self.api_key = api_key self.top_n = top_n - self.api_key = api_key self.timeout = timeout - self.endpoint = f"{self.base_url}/v1/rerank" + self.endpoint = f"{self.base_url}/rerank" def compress_documents( self, documents: List[Document], query: str diff --git a/app/rag/tools.py b/app/rag/tools.py index 32268ed..4934101 100644 --- a/app/rag/tools.py +++ b/app/rag/tools.py @@ -4,74 +4,12 @@ RAG 工具模块 将检索功能封装为 LangChain Tool,供 Agent 调用。 采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 """ - from typing import Optional, Callable from langchain_core.tools import tool from langchain_core.language_models import BaseLanguageModel from langchain_core.retrievers import BaseRetriever - from .pipeline import RAGPipeline - -def create_rag_tool( - retriever: BaseRetriever, - llm: BaseLanguageModel, - num_queries: int = 3, - rerank_top_n: int = 5, - collection_name: str = "rag_documents", -) -> Callable: - """ - 创建一个配置好的 RAG 检索工具(异步)。 - - Args: - retriever: 基础检索器(例如 ParentDocumentRetriever 实例) - llm: 用于多路查询改写的语言模型 - num_queries: 生成查询变体数量 - rerank_top_n: 最终返回的文档数量 - collection_name: 集合名称(仅用于日志/描述) - - Returns: - LangChain Tool 可调用对象(异步) - """ - # 初始化流水线(所有组件一次创建,后续复用) - pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=num_queries, - rerank_top_n=rerank_top_n, - ) - - @tool - async def search_knowledge_base(query: str) -> str: - """在知识库中搜索与查询相关的文档片段。 - - 该工具会: - 1. 将用户问题改写成多个不同角度的查询 - 2. 并行检索每个查询的相关父文档 - 3. 使用倒数排名融合(RRF)合并结果 - 4. 用 Cross-Encoder 重排序模型精选最相关的片段 - - 适用于需要精确、全面答案的事实性问题或背景知识查询。 - - Args: - query: 用户提出的问题或查询字符串 - - Returns: - 格式化后的相关文档内容,若无结果则返回提示信息。 - """ - try: - documents = await pipeline.aretrieve(query) - if not documents: - return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" - - context = pipeline.format_context(documents) - return context - except Exception as e: - return f"检索过程中发生错误: {str(e)}" - - return search_knowledge_base - - def create_rag_tool_sync( retriever: BaseRetriever, llm: BaseLanguageModel, diff --git a/frontend/config.py b/frontend/config.py index b9b8c9d..da4c995 100644 --- a/frontend/config.py +++ b/frontend/config.py @@ -5,6 +5,7 @@ import os from dataclasses import dataclass +from typing import Optional from dotenv import load_dotenv # 加载 .env 文件 @@ -25,7 +26,7 @@ class FrontendConfig: # ==================== 模型配置 ==================== default_model: str = "local" # 更改为local作为默认模型 - model_options: dict = None + model_options: Optional[dict] = None # ==================== 用户配置 ==================== default_user_id: str = "default_user" @@ -53,7 +54,7 @@ class FrontendConfig: """从环境变量加载配置(优先级最高)""" # API 地址(移除 /chat 后缀) # 优先级:环境变量 API_URL > 默认值 - api_url = os.getenv("API_URL", "http://127.0.0.1:8083") + api_url = os.getenv("API_URL", "http://127.0.0.1:8079") self.api_base = api_url.replace("/chat", "").rstrip("/") diff --git a/rag_core/client.py b/rag_core/client.py index 3f313ca..109958a 100644 --- a/rag_core/client.py +++ b/rag_core/client.py @@ -9,16 +9,19 @@ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") def create_qdrant_client( url: Optional[str] = None, api_key: Optional[str] = None, - timeout: int = 120, # 索引构建需要较长超时 + timeout: int = 300, # 索引构建需要较长超时 ) -> QdrantClient: effective_url = url or QDRANT_URL effective_api_key = api_key or QDRANT_API_KEY - + if not effective_url: raise ValueError("Qdrant URL 未配置") - - client_kwargs = {"url": effective_url, "timeout": timeout} + + client_kwargs = { + "url": effective_url, + "timeout": timeout, + } if effective_api_key: client_kwargs["api_key"] = effective_api_key - + return QdrantClient(**client_kwargs) \ No newline at end of file diff --git a/rag_core/vector_store.py b/rag_core/vector_store.py index 7fd3080..13cbfbb 100644 --- a/rag_core/vector_store.py +++ b/rag_core/vector_store.py @@ -4,12 +4,15 @@ Qdrant 向量数据库包装器。 import logging import os +import time from typing import List, Optional, Dict, Any from langchain_core.documents import Document from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS from qdrant_client import QdrantClient from qdrant_client.http.models import Distance, VectorParams +from httpx import RemoteProtocolError +from qdrant_client.http.exceptions import ResponseHandlingException from .client import create_qdrant_client logger = logging.getLogger(__name__) @@ -28,6 +31,8 @@ class QdrantVectorStore: ): self.collection_name = collection_name self._client: Optional[QdrantClient] = None + self._connection_attempts = 0 + self._last_connection_time: Optional[float] = None if embeddings is None: from .embedders import LlamaCppEmbedder @@ -46,14 +51,47 @@ class QdrantVectorStore: def get_client(self) -> QdrantClient: if self._client is None: - self._client = create_qdrant_client(timeout=120) + self._client = create_qdrant_client(timeout=300) + self._connection_attempts += 1 + self._last_connection_time = time.time() + logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts) return self._client def refresh_client(self): """关闭旧连接,创建新连接。""" if self._client is not None: - self._client.close() - self._client = None + try: + self._client.close() + logger.debug("Qdrant 旧连接已关闭") + except Exception as e: + logger.warning("关闭 Qdrant 连接时出现异常: %s", e) + finally: + self._client = None + self._last_connection_time = None + + def check_connection_health(self) -> bool: + """检查连接健康状态,如果连接已失效则自动重建。""" + if self._client is None: + logger.info("Qdrant 客户端未初始化,将创建新连接") + return False + + try: + client = self.get_client() + client.get_collections() + logger.debug("Qdrant 连接健康检查通过") + return True + except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: + logger.warning("Qdrant 连接健康检查失败: %s", e) + self.refresh_client() + return False + + def get_connection_stats(self) -> Dict[str, Any]: + """获取连接统计信息。""" + return { + "connection_attempts": self._connection_attempts, + "last_connection_time": self._last_connection_time, + "client_initialized": self._client is not None, + } def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): """创建集合,设置合适的向量维度。""" @@ -62,22 +100,40 @@ class QdrantVectorStore: embedder = LlamaCppEmbedder() vector_size = embedder.get_embedding_dimension() - client = self.get_client() - collections = client.get_collections().collections - exists = any(c.name == self.collection_name for c in collections) + max_retries = 3 + base_delay = 2 + for attempt in range(max_retries): + try: + client = self.get_client() + collections = client.get_collections().collections + exists = any(c.name == self.collection_name for c in collections) - if exists and force_recreate: - client.delete_collection(self.collection_name) - exists = False + if exists and force_recreate: + client.delete_collection(self.collection_name) + exists = False - if not exists: - client.create_collection( - collection_name=self.collection_name, - vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), - ) - logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size) - else: - logger.info("集合 '%s' 已存在", self.collection_name) + if not exists: + client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size) + else: + logger.info("集合 '%s' 已存在", self.collection_name) + return + except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: + if attempt == max_retries - 1: + logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e) + raise + wait_time = base_delay * (2 ** attempt) + error_type = type(e).__name__ + logger.warning( + "创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s", + self.collection_name, error_type, wait_time, attempt + 1, max_retries, e + ) + self.refresh_client() + logger.debug("已刷新 Qdrant 客户端连接") + time.sleep(wait_time) def add_documents(self, documents: List[Document], batch_size: int = 100): """将文档添加到向量数据库。""" @@ -102,9 +158,10 @@ class QdrantVectorStore: info = self.get_client().get_collection(self.collection_name) vectors_config = info.config.params.vectors if isinstance(vectors_config, dict): - vector_size = next(iter(vectors_config.values())).size + first_config = next(iter(vectors_config.values()), None) + vector_size = first_config.size if first_config else 0 else: - vector_size = vectors_config.size + vector_size = vectors_config.size if vectors_config else 0 return { "name": self.collection_name, "vectors_count": info.points_count or 0, diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index a585970..137a674 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -16,6 +16,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter from langchain_classic.retrievers import ParentDocumentRetriever +from qdrant_client.http.exceptions import ResponseHandlingException from .loaders import DocumentLoader from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter @@ -223,18 +224,26 @@ class IndexBuilder: async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None: """添加批次,失败时自动重试(处理网络波动)。""" - max_retries = 3 + max_retries = 5 + base_delay = 2 for attempt in range(max_retries): try: await self.retriever.aadd_documents(batch) # type: ignore[union-attr] + logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch)) return - except (RemoteProtocolError, ConnectionError, OSError) as e: + except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e: if attempt == max_retries - 1: + logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e) raise - logger.warning("批次 %d 连接断开,重试 (%d/%d): %s", - batch_no, attempt + 1, max_retries, e) + wait_time = base_delay * (2 ** attempt) + error_type = type(e).__name__ + logger.warning( + "批次 %d 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s", + batch_no, error_type, wait_time, attempt + 1, max_retries, e + ) self.vector_store.refresh_client() - await asyncio.sleep(1) + logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no) + await asyncio.sleep(wait_time) # ---------- 信息获取方法 ---------- def get_collection_info(self) -> Any: diff --git a/scripts/start.sh b/scripts/start.sh index 6ee87f5..3243bc2 100755 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -288,7 +288,7 @@ start_backend() { set +a export PYTHONPATH="$PROJECT_DIR" - export BACKEND_PORT=8083 + export BACKEND_PORT=8079 python app/backend.py & BACKEND_PID=$! echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"