From 4e981e9dcf52e90cb4c19969a5eede1afe247b91 Mon Sep 17 00:00:00 2001
From: root <953994191@qq.com>
Date: Mon, 20 Apr 2026 14:05:57 +0800
Subject: [PATCH] File changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.env.docker | 3 +
app/agent.py | 351 ------------------------
app/agent/agent.py | 166 +++++++++++
app/{ => agent}/history.py | 0
app/agent/llm_factory.py | 56 ++++
app/{ => agent}/prompts.py | 37 ++-
app/agent/rag_initializer.py | 23 ++
app/backend.py | 4 +-
app/graph/graph_builder.py | 79 ++++++
app/{ => graph}/graph_tools.py | 0
app/{nodes => graph}/retrieve_memory.py | 2 +-
app/{ => graph}/state.py | 0
app/graph_builder.py | 2 +-
app/nodes/__init__.py | 2 +-
app/nodes/finalize.py | 2 +-
app/nodes/llm_call.py | 4 +-
app/nodes/router.py | 2 +-
app/nodes/summarize.py | 2 +-
app/nodes/tool_call.py | 2 +-
app/rag/__init__.py | 4 +-
app/rag/pipeline.py | 10 +-
app/rag/reranker.py | 17 +-
app/rag/tools.py | 62 -----
frontend/config.py | 5 +-
rag_core/client.py | 13 +-
rag_core/vector_store.py | 95 +++++--
rag_indexer/index_builder.py | 19 +-
scripts/start.sh | 2 +-
28 files changed, 474 insertions(+), 490 deletions(-)
delete mode 100644 app/agent.py
create mode 100644 app/agent/agent.py
rename app/{ => agent}/history.py (100%)
create mode 100644 app/agent/llm_factory.py
rename app/{ => agent}/prompts.py (62%)
create mode 100644 app/agent/rag_initializer.py
create mode 100644 app/graph/graph_builder.py
rename app/{ => graph}/graph_tools.py (100%)
rename app/{nodes => graph}/retrieve_memory.py (97%)
rename app/{ => graph}/state.py (100%)
diff --git a/.env.docker b/.env.docker
index 49f53dc..d08f87b 100644
--- a/.env.docker
+++ b/.env.docker
@@ -37,6 +37,9 @@ VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding service (embeddinggemma-300M GGUF) - port 8082
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
+# Reranker service (bge-reranker-v2-m3) - port 8083
+LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
+
# -----------------------------------------------------------------------------
# Mem0 memory layer configuration
# -----------------------------------------------------------------------------
diff --git a/app/agent.py b/app/agent.py
deleted file mode 100644
index 29b91b5..0000000
--- a/app/agent.py
+++ /dev/null
@@ -1,351 +0,0 @@
-"""
-AI Agent service class - supports dynamic switching between multiple models
-Receives an externally provided checkpointer; does not manage the connection lifecycle
-"""
-
-import os
-import json
-from dotenv import load_dotenv
-try:
- from langchain_community.chat_models import ChatZhipuAI
- HAS_ZHIPUAI = True
-except ImportError:
- HAS_ZHIPUAI = False
- ChatZhipuAI = None
-
-try:
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
- HAS_OPENAI = True
-except ImportError:
- HAS_OPENAI = False
- ChatOpenAI = None
- OpenAIEmbeddings = None
-
-from pydantic import SecretStr
-try:
- from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
- HAS_POSTGRES_CHECKPOINT = True
-except ImportError:
- HAS_POSTGRES_CHECKPOINT = False
- AsyncPostgresSaver = None
-
-# Local modules
-from app.graph_builder import GraphBuilder, GraphContext
-from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
-try:
- from app.rag import RAGPipeline
- from app.rag.tools import RAGTool
- HAS_RAG = True
-except ImportError as e:
- HAS_RAG = False
- RAGPipeline = None
- RAGTool = None
-from app.logger import debug, info, warning, error
-
-
-load_dotenv()
-
-
-class AIAgentService:
- """异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
-
-    def __init__(self, checkpointer: AsyncPostgresSaver):
-        """
-        Initialize the service
-        Args:
-            checkpointer: an already-initialized AsyncPostgresSaver instance
-        """
-        self.checkpointer = checkpointer
-        self.graphs = {}  # graph instances keyed by model name
-        self.rag = None  # RAG pipeline instance
-        self.rag_tool = None  # RAG tool instance
-
-    def _create_zhipu_llm(self):
-        """Create the hosted ZhipuAI LLM"""
-        if not HAS_ZHIPUAI:
-            raise ImportError("ZhipuAI support unavailable; install the langchain-community package")
-        api_key = os.getenv("ZHIPUAI_API_KEY")
-        if not api_key:
-            raise ValueError("ZHIPUAI_API_KEY not set in environment")
-        return ChatZhipuAI(
-            model="glm-4.7-flash",
-            api_key=api_key,
-            temperature=0.1,
-            max_tokens=4096,
-            timeout=120.0,  # request timeout in seconds, raised from 60
-            max_retries=3,  # retry count, raised from 2
-            streaming=True,  # ensure streaming output is enabled
-        )
-
-    def _create_deepseek_llm(self):
-        """Create the DeepSeek LLM (via the OpenAI-compatible API)"""
-        if not HAS_OPENAI:
-            raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
-        api_key = os.getenv("DEEPSEEK_API_KEY")
-        if not api_key:
-            raise ValueError("DEEPSEEK_API_KEY not set in environment")
-        return ChatOpenAI(
-            base_url="https://api.deepseek.com",
-            api_key=SecretStr(api_key),
-            model="deepseek-reasoner",  # deepseek-chat: non-reasoning mode, deepseek-reasoner: reasoning mode
-            temperature=0.1,
-            max_tokens=4096,
-            timeout=60.0,  # request timeout in seconds
-            max_retries=2,  # automatic retries on failure
-            streaming=True,  # ensure streaming output is enabled
-        )
-
-    def _create_local_llm(self):
-        """Create the LLM backed by the local vLLM service"""
-        if not HAS_OPENAI:
-            raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
-        # vLLM service address: read from the environment first, to suit Docker, FRP tunnels and local dev
-        vllm_base_url = os.getenv(
-            "VLLM_BASE_URL",
-            "http://127.0.0.1:8081/v1"
-        )
-
-        return ChatOpenAI(
-            base_url=vllm_base_url,
-            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
-            model="gemma-4-E2B-it",
-            timeout=60.0,  # request timeout in seconds
-            max_retries=2,  # automatic retries on failure
-            streaming=True,  # ensure streaming output is enabled
-        )
-
-    def _create_embeddings(self):
-        """Create the embedding model"""
-        if not HAS_OPENAI:
-            raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
-        embedding_url = os.getenv(
-            "LLAMACPP_EMBEDDING_URL",
-            "http://127.0.0.1:8082/v1"
-        )
-        return OpenAIEmbeddings(
-            openai_api_base=embedding_url,
-            openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
-            model="text-embedding-ada-002",  # model name is irrelevant; any compatible value works
-        )
-
-    async def initialize(self):
-        """Pre-compile the graph for every model (using the provided checkpointer)"""
-        # Initialize the RAG retrieval system first
-        if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
-            try:
-                info("🔄 Initializing RAG retrieval system...")
-                embeddings = self._create_embeddings()
-                self.rag = RAGPipeline(embeddings=embeddings)
-                self.rag_tool = RAGTool(self.rag).get_tool()
-                info("✅ RAG retrieval system initialized")
-            except Exception as e:
-                warning(f"⚠️ RAG retrieval system failed to initialize: {e}")
-                self.rag = None
-                self.rag_tool = None
-        else:
-            info("⏭️ RAG retrieval system unavailable, skipping initialization")
-            self.rag = None
-            self.rag_tool = None
-
-        model_configs = {
-            "local": self._create_local_llm,  # local model first
-            "deepseek": self._create_deepseek_llm,  # DeepSeek in the middle
-            "zhipu": self._create_zhipu_llm,  # GLM-4.7 last
-        }
-
-        for model_name, llm_creator in model_configs.items():
-            try:
-                info(f"🔄 Initializing model '{model_name}'...")
-                llm = llm_creator()
-
-                # Build the tool list: base tools + the RAG tool (if available)
-                tools = AVAILABLE_TOOLS.copy()
-                tools_by_name = TOOLS_BY_NAME.copy()
-
-                if self.rag_tool is not None:
-                    tools.append(self.rag_tool)
-                    tools_by_name[self.rag_tool.name] = self.rag_tool
-
-                builder = GraphBuilder(llm, tools, tools_by_name).build()
-                graph = builder.compile(checkpointer=self.checkpointer)
-                self.graphs[model_name] = graph
-                info(f"✅ Model '{model_name}' initialized")
-            except Exception as e:
-                import traceback
-                error_detail = traceback.format_exc()
-                warning(f"⚠️ Model '{model_name}' failed to initialize: {e}")
-                debug(f"   Details:\n{error_detail}")
-
-        if not self.graphs:
-            raise RuntimeError("No models available; check the configuration. Possible causes:\n"
-                               "1. ZHIPUAI_API_KEY missing or invalid\n"
-                               "2. DEEPSEEK_API_KEY missing or invalid\n"
-                               "3. vLLM service down or wrong address (VLLM_BASE_URL)\n"
-                               "4. network connectivity issues")
-
-        return self
-
- async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
- """
- 处理用户消息,返回包含回复、token统计和耗时的字典
-
- Returns:
- dict: {
- "reply": str, # AI 回复内容
- "token_usage": dict, # Token 使用详情
- "elapsed_time": float # 调用耗时(秒)
- }
- """
- # 尝试使用指定模型,如果不可用则循环尝试其他模型
- if model not in self.graphs:
- warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型")
- found = False
- for available_model in self.graphs.keys():
- try:
- # 这里可以添加额外的模型可用性检查逻辑
- model = available_model
- found = True
- info(f"已切换到可用模型: '{model}'")
- break
- except Exception as e:
- warning(f"模型 '{available_model}' 也不可用: {str(e)}")
- continue
-
- if not found:
- raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
-
- graph = self.graphs[model]
- config = {
- "configurable": {"thread_id": thread_id},
- "metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用
- }
- input_state = {"messages": [{"role": "user", "content": message}]}
- context = GraphContext(user_id=user_id)
-
- result = await graph.ainvoke(input_state, config=config, context=context)
-
- reply = result["messages"][-1].content
- token_usage = result.get("last_token_usage", {})
- elapsed_time = result.get("last_elapsed_time", 0.0)
-
- return {
- "reply": reply,
- "token_usage": token_usage,
- "elapsed_time": elapsed_time
- }
-
- def _serialize_value(self, value):
- """递归将 LangChain 对象转换为可 JSON 序列化的格式"""
- if hasattr(value, 'content'):
- # LangChain 消息对象
- msg_type = getattr(value, 'type', 'message')
- return {
- "role": msg_type,
- "content": getattr(value, 'content', ''),
- "additional_kwargs": getattr(value, 'additional_kwargs', {}),
- "tool_calls": getattr(value, 'tool_calls', [])
- }
- elif isinstance(value, dict):
- return {k: self._serialize_value(v) for k, v in value.items()}
- elif isinstance(value, (list, tuple)):
- return [self._serialize_value(item) for item in value]
- else:
- try:
- json.dumps(value)
- return value
- except (TypeError, ValueError):
- return str(value)
-
- async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
- """
- 流式处理消息,返回异步生成器
-
- Args:
- message: 用户消息
- thread_id: 线程 ID
- model_name: 模型名称
- user_id: 用户 ID
-
- Yields:
- 字典,包含事件类型和数据
- """
- graph = self.graphs.get(model_name)
-
- if not graph:
- raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
-
- config = {
- "configurable": {"thread_id": thread_id},
- "metadata": {"user_id": user_id}
- }
- input_state = {"messages": [{"role": "user", "content": message}]}
- context = GraphContext(user_id=user_id)
-
- async for chunk in graph.astream(
- input_state,
- config=config,
- context=context,
-            stream_mode=["messages", "updates", "custom"],  # combine several modes, adding custom
-            version="v2",  # use the unified v2 format
-            subgraphs=True  # enable this if you use subgraphs
- ):
- chunk_type = chunk["type"]
- processed_event = {}
-
-            # 1. Handle the LLM token stream (typewriter effect)
-            if chunk_type == "messages":
-                message_chunk, metadata = chunk["data"]
-
-                # Extract metadata
-                node_name = metadata.get("langgraph_node", "unknown")
-                # Use getattr to fetch content safely, since message_chunk may not be a string
-                token_content = getattr(message_chunk, 'content', str(message_chunk))
-
-                # Extract DeepSeek reasoner's reasoning tokens
-                reasoning_token = ""
-                if hasattr(message_chunk, 'additional_kwargs'):
-                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
-
-                # [DEBUG] temporary: log only when reasoning_token is non-empty, so it is easy to spot
-                if reasoning_token:
-                    import logging
-                    logging.debug(f"💡 [Reasoning token captured]: {repr(reasoning_token)}")
-
- processed_event = {
- "type": "llm_token",
- "node": node_name,
- "token": token_content,
- "reasoning_token": reasoning_token,
- "metadata": metadata # 可选的元数据
- }
-
-            # 2. Handle state updates (a node finished executing)
-            elif chunk_type == "updates":
-                updates_data = chunk["data"]
-                # Serialize everything inside updates
- serialized_data = self._serialize_value(updates_data)
- processed_event = {
- "type": "state_update",
- "data": serialized_data
- }
-                # keep the messages field for frontend backward compatibility (optional)
- if "messages" in serialized_data:
- processed_event["messages"] = serialized_data["messages"]
-
-            # 3. Handle custom data (if needed)
-            elif chunk_type == "custom":
-                # Custom events need serialization as well
- serialized_data = self._serialize_value(chunk["data"])
- processed_event = {
- "type": "custom",
- "data": serialized_data
- }
-
-            # 4. Other types (debug, tasks, etc.) handled as needed
-            else:
-                # Skip types we do not need
- continue
-
-            # only emit events that carry data
- if processed_event:
- yield processed_event
\ No newline at end of file
diff --git a/app/agent/agent.py b/app/agent/agent.py
new file mode 100644
index 0000000..cefb43f
--- /dev/null
+++ b/app/agent/agent.py
@@ -0,0 +1,166 @@
+"""
+AI Agent service class - supports dynamic switching between multiple models
+Receives an externally provided checkpointer; does not manage the connection lifecycle
+"""
+
+import os
+import json
+from dotenv import load_dotenv
+
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from pydantic import SecretStr
+from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
+
+# Local modules
+from app.graph_builder import GraphBuilder, GraphContext
+from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
+from app.rag import RAGPipeline
+from app.rag.tools import create_rag_tool_sync
+from rag_core import create_parent_retriever
+from app.agent.llm_factory import LLMFactory
+from app.agent.rag_initializer import init_rag_tool
+from app.logger import debug, info, warning, error
+load_dotenv()
+
+
+class AIAgentService:
+ def __init__(self, checkpointer):
+ self.checkpointer = checkpointer
+ self.graphs = {}
+ self.tools = AVAILABLE_TOOLS.copy()
+ self.tools_by_name = TOOLS_BY_NAME.copy()
+
+ async def initialize(self):
+        # 1. Initialize the RAG tool (if needed)
+        rag_tool = await init_rag_tool(LLMFactory.create_local)
+        if rag_tool:
+            self.tools.append(rag_tool)
+            self.tools_by_name[rag_tool.name] = rag_tool
+
+        # 2. Build a graph per model
+        for name, creator in LLMFactory.CREATORS.items():
+            try:
+                info(f"🔄 Initializing model '{name}'...")
+                llm = creator()
+                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
+                graph = builder.compile(checkpointer=self.checkpointer)
+                self.graphs[name] = graph
+                info(f"✅ Model '{name}' initialized")
+            except Exception as e:
+                warning(f"⚠️ Model '{name}' failed to initialize: {e}")
+
+        if not self.graphs:
+            raise RuntimeError("No models available")
+ return self
+
+ async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
+ """处理用户消息,返回包含回复、token统计和耗时的字典"""
+ if model not in self.graphs:
+ # 回退到第一个可用模型
+ available = list(self.graphs.keys())
+ if not available:
+ raise RuntimeError("没有可用的模型")
+ model = available[0]
+ warning(f"模型 '{model}' 不可用,已回退到 '{model}'")
+
+ graph = self.graphs[model]
+ config = {
+ "configurable": {"thread_id": thread_id},
+ "metadata": {"user_id": user_id}
+ }
+ input_state = {"messages": [{"role": "user", "content": message}]}
+ context = GraphContext(user_id=user_id)
+
+ result = await graph.ainvoke(input_state, config=config, context=context)
+
+ reply = result["messages"][-1].content
+ token_usage = result.get("last_token_usage", {})
+ elapsed_time = result.get("last_elapsed_time", 0.0)
+
+ return {
+ "reply": reply,
+ "token_usage": token_usage,
+ "elapsed_time": elapsed_time
+ }
+
+ def _serialize_value(self, value):
+ """递归将 LangChain 对象转换为可 JSON 序列化的格式"""
+ if hasattr(value, 'content'):
+ msg_type = getattr(value, 'type', 'message')
+ return {
+ "role": msg_type,
+ "content": getattr(value, 'content', ''),
+ "additional_kwargs": getattr(value, 'additional_kwargs', {}),
+ "tool_calls": getattr(value, 'tool_calls', [])
+ }
+ elif isinstance(value, dict):
+ return {k: self._serialize_value(v) for k, v in value.items()}
+ elif isinstance(value, (list, tuple)):
+ return [self._serialize_value(item) for item in value]
+ else:
+ try:
+ json.dumps(value)
+ return value
+ except (TypeError, ValueError):
+ return str(value)
+
+ async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
+ """流式处理消息,返回异步生成器"""
+ graph = self.graphs.get(model_name)
+ if not graph:
+ raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
+
+ config = {
+ "configurable": {"thread_id": thread_id},
+ "metadata": {"user_id": user_id}
+ }
+ input_state = {"messages": [{"role": "user", "content": message}]}
+ context = GraphContext(user_id=user_id)
+
+ async for chunk in graph.astream(
+ input_state,
+ config=config,
+ context=context,
+ stream_mode=["messages", "updates", "custom"],
+ version="v2",
+ subgraphs=True
+ ):
+ chunk_type = chunk["type"]
+ processed_event = {}
+
+ if chunk_type == "messages":
+ message_chunk, metadata = chunk["data"]
+ node_name = metadata.get("langgraph_node", "unknown")
+ token_content = getattr(message_chunk, 'content', str(message_chunk))
+ reasoning_token = ""
+ if hasattr(message_chunk, 'additional_kwargs'):
+ reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
+
+ processed_event = {
+ "type": "llm_token",
+ "node": node_name,
+ "token": token_content,
+ "reasoning_token": reasoning_token,
+ "metadata": metadata
+ }
+ elif chunk_type == "updates":
+ updates_data = chunk["data"]
+ serialized_data = self._serialize_value(updates_data)
+ processed_event = {
+ "type": "state_update",
+ "data": serialized_data
+ }
+ if "messages" in serialized_data:
+ processed_event["messages"] = serialized_data["messages"]
+ elif chunk_type == "custom":
+ serialized_data = self._serialize_value(chunk["data"])
+ processed_event = {
+ "type": "custom",
+ "data": serialized_data
+ }
+ else:
+ continue
+
+ if processed_event:
+ yield processed_event
\ No newline at end of file
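For reference, a minimal sketch of how a caller might drive the streaming API above. The MemorySaver checkpointer and the literal inputs are illustrative assumptions; only the event shapes ("llm_token" / "state_update" / "custom") come from this file:

    # Sketch, not part of the patch: consume process_message_stream events.
    import asyncio
    from langgraph.checkpoint.memory import MemorySaver  # stand-in for AsyncPostgresSaver

    async def demo():
        service = await AIAgentService(MemorySaver()).initialize()
        async for event in service.process_message_stream(
            message="hello", thread_id="thread-1", model_name="local"
        ):
            if event["type"] == "llm_token":
                print(event["token"], end="", flush=True)  # typewriter effect
            elif event["type"] == "state_update":
                pass  # a node finished; event["data"] holds the serialized update

    asyncio.run(demo())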
diff --git a/app/history.py b/app/agent/history.py
similarity index 100%
rename from app/history.py
rename to app/agent/history.py
diff --git a/app/agent/llm_factory.py b/app/agent/llm_factory.py
new file mode 100644
index 0000000..9a1a22a
--- /dev/null
+++ b/app/agent/llm_factory.py
@@ -0,0 +1,56 @@
+# app/agent/llm_factory.py
+import os
+from langchain_community.chat_models import ChatZhipuAI
+from langchain_openai import ChatOpenAI
+from pydantic import SecretStr
+
+class LLMFactory:
+ @staticmethod
+ def create_zhipu():
+ api_key = os.getenv("ZHIPUAI_API_KEY")
+ if not api_key:
+ raise ValueError("ZHIPUAI_API_KEY not set")
+ return ChatZhipuAI(
+ model="glm-4.7-flash",
+ api_key=api_key,
+ temperature=0.1,
+ max_tokens=4096,
+ timeout=120.0,
+ max_retries=3,
+ streaming=True,
+ )
+
+ @staticmethod
+ def create_deepseek():
+ api_key = os.getenv("DEEPSEEK_API_KEY")
+ if not api_key:
+ raise ValueError("DEEPSEEK_API_KEY not set")
+ return ChatOpenAI(
+ base_url="https://api.deepseek.com",
+ api_key=SecretStr(api_key),
+ model="deepseek-reasoner",
+ temperature=0.1,
+ max_tokens=4096,
+ timeout=60.0,
+ max_retries=2,
+ streaming=True,
+ )
+
+ @staticmethod
+ def create_local():
+ base_url = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
+ return ChatOpenAI(
+ base_url=base_url,
+ api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
+ model="gemma-4-E4B-it",
+ timeout=60.0,
+ max_retries=2,
+ streaming=True,
+ )
+
+    # Model creator registry (staticmethod objects are directly callable on Python 3.10+)
+ CREATORS = {
+ "local": create_local,
+ "deepseek": create_deepseek,
+ "zhipu": create_zhipu,
+ }
\ No newline at end of file
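One subtlety in the CREATORS registry above: it stores the raw staticmethod objects defined in the same class body, and calling those directly (creator()) only works on Python 3.10+, where staticmethod objects became callable. A version-agnostic sketch that resolves creators through the class instead; the NAME_TO_METHOD mapping is illustrative, not part of the patch:

    # Sketch: getattr on the class applies the descriptor protocol,
    # so this works on any Python 3.x.
    NAME_TO_METHOD = {"local": "create_local", "deepseek": "create_deepseek", "zhipu": "create_zhipu"}

    def create_llm(name: str):
        return getattr(LLMFactory, NAME_TO_METHOD[name])()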
diff --git a/app/prompts.py b/app/agent/prompts.py
similarity index 62%
rename from app/prompts.py
rename to app/agent/prompts.py
index f0f2c74..990e634 100644
--- a/app/prompts.py
+++ b/app/agent/prompts.py
@@ -1,18 +1,21 @@
-"""
-Prompt template management module
-All system prompts and conversation templates are defined here
-"""
-
+# app/agent/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.tools import BaseTool
+def create_system_prompt(tools: list | None = None) -> ChatPromptTemplate:
+    """
+    Create the system prompt template, optionally injecting tool descriptions dynamically.
+    """
+    tools_section = ""
+    if tools:
+        tool_descs = []
+        for tool in tools:
+            # Extract the tool name and the first line of its description
+            name = getattr(tool, 'name', None) or getattr(tool, '__name__', repr(tool))
+            desc = (getattr(tool, 'description', "") or "").split('\n')[0]
+            tool_descs.append(f"- {name}: {desc}")
+        tools_section = "\n".join(tool_descs)
-def create_system_prompt() -> ChatPromptTemplate:
-    """
-    Create the system prompt template
-
-    Returns:
-        ChatPromptTemplate: prompt template with system instructions and a messages placeholder
-    """
system_template = (
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"【用户背景信息】\n"
@@ -20,15 +23,11 @@ def create_system_prompt() -> ChatPromptTemplate:
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
- "- 获取温度/天气:`get_current_temperature`\n"
- "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
- "- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
- "- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
- "- 抓取网页内容:`fetch_webpage_content`\n"
+ f"{tools_section}\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接。\n"
- "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。例如:这里是我的思考过程...这里是最终回答。\n"
+ "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `` 和 `` 标签包裹起来,放在回答的最前面。\n"
"3. 优先利用已知用户信息进行个性化回复。\n"
"4. 若无信息可依,礼貌询问或提供通用帮助。"
)
@@ -36,4 +35,4 @@ def create_system_prompt() -> ChatPromptTemplate:
return ChatPromptTemplate.from_messages([
("system", system_template),
MessagesPlaceholder(variable_name="messages")
- ])
+ ])
\ No newline at end of file
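A quick sketch of how the dynamic tool injection above renders; the sample tool is hypothetical:

    from langchain_core.tools import tool

    @tool
    def get_current_temperature(city: str) -> str:
        """Get the current temperature for a city.
        Only this first line appears in the system prompt."""
        return "20°C"

    prompt = create_system_prompt([get_current_temperature])
    # The tools section then reads:
    # - get_current_temperature: Get the current temperature for a city.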
diff --git a/app/agent/rag_initializer.py b/app/agent/rag_initializer.py
new file mode 100644
index 0000000..f391b8f
--- /dev/null
+++ b/app/agent/rag_initializer.py
@@ -0,0 +1,23 @@
+# app/agent/rag_initializer.py
+from app.rag.tools import create_rag_tool_sync
+from rag_core import create_parent_retriever
+from app.logger import info, warning
+
+async def init_rag_tool(local_llm_creator):
+ """初始化 RAG 工具,失败返回 None"""
+ try:
+ info("🔄 正在初始化 RAG 检索系统...")
+ retriever = create_parent_retriever(
+ collection_name="rag_documents",
+ search_k=5,
+ )
+ rewrite_llm = local_llm_creator()
+ rag_tool = create_rag_tool_sync(
+ retriever, rewrite_llm,
+ num_queries=3, rerank_top_n=5
+ )
+ info("✅ RAG 检索工具初始化成功")
+ return rag_tool
+ except Exception as e:
+ warning(f"⚠️ RAG 检索工具初始化失败: {e}")
+ return None
\ No newline at end of file
diff --git a/app/backend.py b/app/backend.py
index f531a36..20da307 100644
--- a/app/backend.py
+++ b/app/backend.py
@@ -231,6 +231,6 @@ async def websocket_endpoint(
if __name__ == "__main__":
import uvicorn
-    # Use env var or default port 8083 (avoids clashing with llama.cpp on port 8081)
- port = int(os.getenv("BACKEND_PORT", "8083"))
+    # Use env var or default port 8079 (8083 is now taken by the reranker service)
+ port = int(os.getenv("BACKEND_PORT", "8079"))
uvicorn.run(app, host="0.0.0.0", port=port)
diff --git a/app/graph/graph_builder.py b/app/graph/graph_builder.py
new file mode 100644
index 0000000..11f9c9d
--- /dev/null
+++ b/app/graph/graph_builder.py
@@ -0,0 +1,79 @@
+"""
+LangGraph state-graph assembly module - slimmed down, only wires the graph together
+All node logic has been split into separate modules
+"""
+
+from langchain_core.language_models import BaseLLM
+from langgraph.graph import StateGraph, START, END
+
+# Local modules
+from app.graph.state import MessagesState, GraphContext
+from app.nodes import (
+ create_llm_call_node,
+ create_tool_call_node,
+ create_retrieve_memory_node,
+ create_summarize_node,
+ should_continue
+)
+from app.memory import Mem0Client
+from app.nodes.finalize import finalize_node
+
+
+class GraphBuilder:
+ """LangGraph 状态图构建器 - 仅负责组装图"""
+
+ def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict):
+ """
+ 初始化构建器
+
+ Args:
+ llm: 大语言模型实例
+ tools: 工具列表
+ tools_by_name: 名称到工具函数的映射
+ """
+ self.llm = llm
+ self.tools = tools
+ self.tools_by_name = tools_by_name
+
+ # ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化)
+ self.mem0_client = Mem0Client(llm)
+
+ def build(self) -> StateGraph:
+ """
+ 构建未编译的状态图
+
+ Returns:
+ StateGraph 实例
+ """
+ builder = StateGraph(MessagesState, context_schema=GraphContext)
+
+        # ⭐ Create nodes via factory functions (dependency injection)
+ retrieve_memory_node = create_retrieve_memory_node(self.mem0_client)
+ llm_call_node = create_llm_call_node(self.llm, self.tools)
+ tool_call_node = create_tool_call_node(self.tools_by_name)
+ summarize_node = create_summarize_node(self.mem0_client)
+
+        # Add nodes
+ builder.add_node("retrieve_memory", retrieve_memory_node)
+ builder.add_node("llm_call", llm_call_node)
+ builder.add_node("tool_node", tool_call_node)
+ builder.add_node("summarize", summarize_node)
+ builder.add_node("finalize", finalize_node)
+
+        # Add edges
+ builder.add_edge(START, "retrieve_memory")
+ builder.add_edge("retrieve_memory", "llm_call")
+ builder.add_conditional_edges(
+ "llm_call",
+ should_continue,
+ {
+ "tool_node": "tool_node",
+ "summarize": "summarize",
+ "finalize": "finalize"
+ }
+ )
+ builder.add_edge("tool_node", "llm_call")
+ builder.add_edge("summarize", "finalize")
+ builder.add_edge("finalize", END)
+
+ return builder
\ No newline at end of file
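For local testing, the builder above can be compiled with an in-memory checkpointer. A minimal sketch, assuming llm, tools and tools_by_name are already constructed (MemorySaver stands in for the AsyncPostgresSaver used in production):

    from langgraph.checkpoint.memory import MemorySaver

    builder = GraphBuilder(llm, tools, tools_by_name).build()
    graph = builder.compile(checkpointer=MemorySaver())
    # Wiring: START → retrieve_memory → llm_call → {tool_node (loops back) | summarize → finalize | finalize} → END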
diff --git a/app/graph_tools.py b/app/graph/graph_tools.py
similarity index 100%
rename from app/graph_tools.py
rename to app/graph/graph_tools.py
diff --git a/app/nodes/retrieve_memory.py b/app/graph/retrieve_memory.py
similarity index 97%
rename from app/nodes/retrieve_memory.py
rename to app/graph/retrieve_memory.py
index 0313ca8..61434d0 100644
--- a/app/nodes/retrieve_memory.py
+++ b/app/graph/retrieve_memory.py
@@ -7,7 +7,7 @@ from typing import Any, Dict
from langgraph.runtime import Runtime
# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
diff --git a/app/state.py b/app/graph/state.py
similarity index 100%
rename from app/state.py
rename to app/graph/state.py
diff --git a/app/graph_builder.py b/app/graph_builder.py
index 2af230a..11f9c9d 100644
--- a/app/graph_builder.py
+++ b/app/graph_builder.py
@@ -7,7 +7,7 @@ from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
from app.nodes import (
create_llm_call_node,
create_tool_call_node,
diff --git a/app/nodes/__init__.py b/app/nodes/__init__.py
index 8d279db..d9eb644 100644
--- a/app/nodes/__init__.py
+++ b/app/nodes/__init__.py
@@ -5,7 +5,7 @@
from app.nodes.router import should_continue
from app.nodes.llm_call import create_llm_call_node
from app.nodes.tool_call import create_tool_call_node
-from app.nodes.retrieve_memory import create_retrieve_memory_node
+from app.graph.retrieve_memory import create_retrieve_memory_node
from app.nodes.summarize import create_summarize_node
from app.nodes.finalize import finalize_node
diff --git a/app/nodes/finalize.py b/app/nodes/finalize.py
index a283587..ea6e1dc 100644
--- a/app/nodes/finalize.py
+++ b/app/nodes/finalize.py
@@ -8,7 +8,7 @@ from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
from app.utils.logging import log_state_change
from app.logger import info, error
diff --git a/app/nodes/llm_call.py b/app/nodes/llm_call.py
index 79361cd..22abfed 100644
--- a/app/nodes/llm_call.py
+++ b/app/nodes/llm_call.py
@@ -12,7 +12,7 @@ from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime
# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
-from app.prompts import create_system_prompt
+from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change, print_llm_input
from app.logger import debug, info, error
@@ -30,7 +30,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
异步节点函数
"""
# Build the invocation chain
- prompt = create_system_prompt()
+ prompt = create_system_prompt(tools)
llm_with_tools = llm.bind_tools(tools)
# Restore the RunnableLambda chain and iterate it below with astream
diff --git a/app/nodes/router.py b/app/nodes/router.py
index 81d2e7e..cabc275 100644
--- a/app/nodes/router.py
+++ b/app/nodes/router.py
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
# Local modules
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
-from app.state import MessagesState
+from app.graph.state import MessagesState
from app.logger import info
diff --git a/app/nodes/summarize.py b/app/nodes/summarize.py
index a49742a..8b39baa 100644
--- a/app/nodes/summarize.py
+++ b/app/nodes/summarize.py
@@ -7,7 +7,7 @@ from typing import Any, Dict
from langgraph.runtime import Runtime
# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
diff --git a/app/nodes/tool_call.py b/app/nodes/tool_call.py
index 12648ff..1c9d55f 100644
--- a/app/nodes/tool_call.py
+++ b/app/nodes/tool_call.py
@@ -10,7 +10,7 @@ from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
from app.utils.logging import log_state_change
from app.logger import debug, info
diff --git a/app/rag/__init__.py b/app/rag/__init__.py
index 8b4868f..1438604 100644
--- a/app/rag/__init__.py
+++ b/app/rag/__init__.py
@@ -39,7 +39,7 @@ from .retriever import (
create_hybrid_retriever,
create_qdrant_client,
)
-from .reranker import CrossEncoderReranker
+from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline
@@ -53,7 +53,7 @@ __all__ = [
"create_qdrant_client",
# Reranker
- "CrossEncoderReranker",
+ "LLaMaCPPReranker",
# Query rewrite generator
"MultiQueryGenerator",
diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py
index c0b4e6f..c8d95d6 100644
--- a/app/rag/pipeline.py
+++ b/app/rag/pipeline.py
@@ -1,6 +1,7 @@
# rag/pipeline.py
import asyncio
+import os
from typing import List, Optional
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
@@ -23,7 +24,6 @@ class RAGPipeline:
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
- rerank_model: str = "BAAI/bge-reranker-base",
):
"""
Args:
@@ -41,9 +41,9 @@ class RAGPipeline:
# Initialize components
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
self.reranker = LLaMaCPPReranker(
- base_url="http://127.0.0.1:8083",
+ base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"),
+ api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"),
top_n=rerank_top_n,
- api_key="huang1998"
)
async def aretrieve(self, query: str) -> List[Document]:
@@ -68,9 +68,9 @@ class RAGPipeline:
fused_docs = reciprocal_rank_fusion(doc_lists)
# Step 4: rerank
- if self.reranker.model is not None:
+ try:
final_docs = self.reranker.compress_documents(fused_docs, query)
- else:
+ except Exception:
# If the reranker is unavailable, return the top N fused docs directly
final_docs = fused_docs[:self.rerank_top_n]
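For context, reciprocal rank fusion scores each document as score(d) = Σ 1/(k + rank_i(d)) over the ranked lists it appears in. A minimal sketch of the idea; the k=60 constant and keying on page_content are common conventions, not taken from app/rag/fusion.py:

    from collections import defaultdict
    from typing import List
    from langchain_core.documents import Document

    def rrf_sketch(doc_lists: List[List[Document]], k: int = 60) -> List[Document]:
        # Accumulate 1/(k + rank) per document across all ranked lists.
        scores = defaultdict(float)
        by_key = {}
        for docs in doc_lists:
            for rank, doc in enumerate(docs, start=1):
                key = doc.page_content
                by_key[key] = doc
                scores[key] += 1.0 / (k + rank)
        return [by_key[key] for key in sorted(scores, key=scores.get, reverse=True)]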
diff --git a/app/rag/reranker.py b/app/rag/reranker.py
index 7a53806..b6f7d4a 100644
--- a/app/rag/reranker.py
+++ b/app/rag/reranker.py
@@ -2,32 +2,33 @@
Reranker module (adapted)
Uses a remote llama.cpp service (OpenAI Rerank API compatible) instead of a local Cross-Encoder
"""
+import os
import requests
-from typing import List
+from typing import List, Optional
from langchain_core.documents import Document
class LLaMaCPPReranker:
"""使用远程 llama.cpp 服务对检索结果重排序。"""
def __init__(self,
- base_url: str = "http://127.0.0.1:8083",
+ base_url: str,
+ api_key: str,
top_n: int = 5,
- api_key: str = "huang1998", # 你设置的 LLAMA_ARG_API_KEY
timeout: int = 60):
"""
Initialize the remote reranker
Args:
-        base_url: address and port of the llama.cpp service.
+        base_url: address and port of the llama.cpp service; the caller supplies it (e.g. from LLAMACPP_RERANKER_URL).
top_n: number of top results to return.
-        api_key: the API key configured in the container.
+        api_key: API key; supplied by the caller (e.g. from LLAMACPP_API_KEY).
timeout: request timeout in seconds.
"""
- self.base_url = base_url.rstrip('/')
+ self.base_url = base_url
+ self.api_key = api_key
self.top_n = top_n
- self.api_key = api_key
self.timeout = timeout
- self.endpoint = f"{self.base_url}/v1/rerank"
+ self.endpoint = f"{self.base_url}/rerank"
def compress_documents(
self, documents: List[Document], query: str
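The request that compress_documents sends to that endpoint follows the Jina/Cohere-style rerank schema that llama.cpp's server exposes; a hedged sketch of the round-trip (field names are an assumption based on that convention, not verified against this class body):

    import requests

    def rerank_roundtrip(endpoint, api_key, query, texts, top_n=5, timeout=60):
        # POST {"query", "documents", "top_n"}; the server is expected to reply with
        # {"results": [{"index": i, "relevance_score": s}, ...]}.
        resp = requests.post(
            endpoint,
            headers={"Authorization": f"Bearer {api_key}"},
            json={"query": query, "documents": texts, "top_n": top_n},
            timeout=timeout,
        )
        resp.raise_for_status()
        return [(r["index"], r["relevance_score"]) for r in resp.json()["results"]]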
diff --git a/app/rag/tools.py b/app/rag/tools.py
index 32268ed..4934101 100644
--- a/app/rag/tools.py
+++ b/app/rag/tools.py
@@ -4,74 +4,12 @@ RAG tool module
Wraps retrieval as a LangChain Tool for the Agent to call.
Uses a fixed pipeline: multi-query rewrite → parallel retrieval → RRF fusion → rerank → return parent documents.
"""
-
from typing import Optional, Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
-
from .pipeline import RAGPipeline
-
-def create_rag_tool(
- retriever: BaseRetriever,
- llm: BaseLanguageModel,
- num_queries: int = 3,
- rerank_top_n: int = 5,
- collection_name: str = "rag_documents",
-) -> Callable:
- """
- 创建一个配置好的 RAG 检索工具(异步)。
-
- Args:
- retriever: 基础检索器(例如 ParentDocumentRetriever 实例)
- llm: 用于多路查询改写的语言模型
- num_queries: 生成查询变体数量
- rerank_top_n: 最终返回的文档数量
- collection_name: 集合名称(仅用于日志/描述)
-
- Returns:
- LangChain Tool 可调用对象(异步)
- """
- # 初始化流水线(所有组件一次创建,后续复用)
- pipeline = RAGPipeline(
- retriever=retriever,
- llm=llm,
- num_queries=num_queries,
- rerank_top_n=rerank_top_n,
- )
-
- @tool
- async def search_knowledge_base(query: str) -> str:
- """在知识库中搜索与查询相关的文档片段。
-
- 该工具会:
- 1. 将用户问题改写成多个不同角度的查询
- 2. 并行检索每个查询的相关父文档
- 3. 使用倒数排名融合(RRF)合并结果
- 4. 用 Cross-Encoder 重排序模型精选最相关的片段
-
- 适用于需要精确、全面答案的事实性问题或背景知识查询。
-
- Args:
- query: 用户提出的问题或查询字符串
-
- Returns:
- 格式化后的相关文档内容,若无结果则返回提示信息。
- """
- try:
- documents = await pipeline.aretrieve(query)
- if not documents:
- return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
-
- context = pipeline.format_context(documents)
- return context
- except Exception as e:
- return f"检索过程中发生错误: {str(e)}"
-
- return search_knowledge_base
-
-
def create_rag_tool_sync(
retriever: BaseRetriever,
llm: BaseLanguageModel,
diff --git a/frontend/config.py b/frontend/config.py
index b9b8c9d..da4c995 100644
--- a/frontend/config.py
+++ b/frontend/config.py
@@ -5,6 +5,7 @@
import os
from dataclasses import dataclass
+from typing import Optional
from dotenv import load_dotenv
# Load the .env file
@@ -25,7 +26,7 @@ class FrontendConfig:
# ==================== Model configuration ====================
default_model: str = "local"  # use local as the default model
- model_options: dict = None
+ model_options: Optional[dict] = None
# ==================== User configuration ====================
default_user_id: str = "default_user"
@@ -53,7 +54,7 @@ class FrontendConfig:
"""从环境变量加载配置(优先级最高)"""
# API 地址(移除 /chat 后缀)
# 优先级:环境变量 API_URL > 默认值
- api_url = os.getenv("API_URL", "http://127.0.0.1:8083")
+ api_url = os.getenv("API_URL", "http://127.0.0.1:8079")
self.api_base = api_url.replace("/chat", "").rstrip("/")
diff --git a/rag_core/client.py b/rag_core/client.py
index 3f313ca..109958a 100644
--- a/rag_core/client.py
+++ b/rag_core/client.py
@@ -9,16 +9,19 @@ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
-    timeout: int = 120,  # index building needs a long timeout
+    timeout: int = 300,  # index building needs a long timeout
) -> QdrantClient:
effective_url = url or QDRANT_URL
effective_api_key = api_key or QDRANT_API_KEY
-
+
if not effective_url:
raise ValueError("Qdrant URL 未配置")
-
- client_kwargs = {"url": effective_url, "timeout": timeout}
+
+ client_kwargs = {
+ "url": effective_url,
+ "timeout": timeout,
+ }
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
-
+
return QdrantClient(**client_kwargs)
\ No newline at end of file
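A small usage sketch for the factory above; the URL is a placeholder, not from the patch:

    import os
    os.environ.setdefault("QDRANT_URL", "http://127.0.0.1:6333")  # placeholder for local dev

    client = create_qdrant_client()                  # picks up QDRANT_URL / QDRANT_API_KEY
    long_client = create_qdrant_client(timeout=600)  # override for very large index builds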
diff --git a/rag_core/vector_store.py b/rag_core/vector_store.py
index 7fd3080..13cbfbb 100644
--- a/rag_core/vector_store.py
+++ b/rag_core/vector_store.py
@@ -4,12 +4,15 @@ Qdrant vector database wrapper.
import logging
import os
+import time
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
+from httpx import RemoteProtocolError
+from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
logger = logging.getLogger(__name__)
@@ -28,6 +31,8 @@ class QdrantVectorStore:
):
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
+ self._connection_attempts = 0
+ self._last_connection_time: Optional[float] = None
if embeddings is None:
from .embedders import LlamaCppEmbedder
@@ -46,14 +51,47 @@ class QdrantVectorStore:
def get_client(self) -> QdrantClient:
if self._client is None:
- self._client = create_qdrant_client(timeout=120)
+ self._client = create_qdrant_client(timeout=300)
+ self._connection_attempts += 1
+ self._last_connection_time = time.time()
+ logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
return self._client
def refresh_client(self):
"""关闭旧连接,创建新连接。"""
if self._client is not None:
- self._client.close()
- self._client = None
+ try:
+ self._client.close()
+ logger.debug("Qdrant 旧连接已关闭")
+ except Exception as e:
+ logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
+ finally:
+ self._client = None
+ self._last_connection_time = None
+
+ def check_connection_health(self) -> bool:
+ """检查连接健康状态,如果连接已失效则自动重建。"""
+ if self._client is None:
+ logger.info("Qdrant 客户端未初始化,将创建新连接")
+ return False
+
+ try:
+ client = self.get_client()
+ client.get_collections()
+ logger.debug("Qdrant 连接健康检查通过")
+ return True
+ except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
+ logger.warning("Qdrant 连接健康检查失败: %s", e)
+ self.refresh_client()
+ return False
+
+ def get_connection_stats(self) -> Dict[str, Any]:
+ """获取连接统计信息。"""
+ return {
+ "connection_attempts": self._connection_attempts,
+ "last_connection_time": self._last_connection_time,
+ "client_initialized": self._client is not None,
+ }
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
@@ -62,22 +100,40 @@ class QdrantVectorStore:
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
- client = self.get_client()
- collections = client.get_collections().collections
- exists = any(c.name == self.collection_name for c in collections)
+ max_retries = 3
+ base_delay = 2
+ for attempt in range(max_retries):
+ try:
+ client = self.get_client()
+ collections = client.get_collections().collections
+ exists = any(c.name == self.collection_name for c in collections)
- if exists and force_recreate:
- client.delete_collection(self.collection_name)
- exists = False
+ if exists and force_recreate:
+ client.delete_collection(self.collection_name)
+ exists = False
- if not exists:
- client.create_collection(
- collection_name=self.collection_name,
- vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
- )
- logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)
- else:
- logger.info("集合 '%s' 已存在", self.collection_name)
+ if not exists:
+ client.create_collection(
+ collection_name=self.collection_name,
+ vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
+ )
+ logger.info("集合 '%s' 已创建(维度=%d)", self.collection_name, vector_size)
+ else:
+ logger.info("集合 '%s' 已存在", self.collection_name)
+ return
+ except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
+ if attempt == max_retries - 1:
+ logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e)
+ raise
+ wait_time = base_delay * (2 ** attempt)
+ error_type = type(e).__name__
+ logger.warning(
+ "创建集合 '%s' 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
+ self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
+ )
+ self.refresh_client()
+ logger.debug("已刷新 Qdrant 客户端连接")
+ time.sleep(wait_time)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""将文档添加到向量数据库。"""
@@ -102,9 +158,10 @@ class QdrantVectorStore:
info = self.get_client().get_collection(self.collection_name)
vectors_config = info.config.params.vectors
if isinstance(vectors_config, dict):
- vector_size = next(iter(vectors_config.values())).size
+ first_config = next(iter(vectors_config.values()), None)
+ vector_size = first_config.size if first_config else 0
else:
- vector_size = vectors_config.size
+ vector_size = vectors_config.size if vectors_config else 0
return {
"name": self.collection_name,
"vectors_count": info.points_count or 0,
diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py
index a585970..137a674 100644
--- a/rag_indexer/index_builder.py
+++ b/rag_indexer/index_builder.py
@@ -16,6 +16,7 @@ from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
+from qdrant_client.http.exceptions import ResponseHandlingException
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
@@ -223,18 +224,26 @@ class IndexBuilder:
async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
"""添加批次,失败时自动重试(处理网络波动)。"""
- max_retries = 3
+ max_retries = 5
+ base_delay = 2
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch) # type: ignore[union-attr]
+ logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch))
return
- except (RemoteProtocolError, ConnectionError, OSError) as e:
+ except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
if attempt == max_retries - 1:
+ logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e)
raise
- logger.warning("批次 %d 连接断开,重试 (%d/%d): %s",
- batch_no, attempt + 1, max_retries, e)
+ wait_time = base_delay * (2 ** attempt)
+ error_type = type(e).__name__
+ logger.warning(
+ "批次 %d 遇到网络异常 [%s],%d秒后重试 (%d/%d): %s",
+ batch_no, error_type, wait_time, attempt + 1, max_retries, e
+ )
self.vector_store.refresh_client()
- await asyncio.sleep(1)
+ logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no)
+ await asyncio.sleep(wait_time)
# ---------- Info accessors ----------
def get_collection_info(self) -> Any:
diff --git a/scripts/start.sh b/scripts/start.sh
index 6ee87f5..3243bc2 100755
--- a/scripts/start.sh
+++ b/scripts/start.sh
@@ -288,7 +288,7 @@ start_backend() {
set +a
export PYTHONPATH="$PROJECT_DIR"
- export BACKEND_PORT=8083
+ export BACKEND_PORT=8079
python app/backend.py &
BACKEND_PID=$!
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"