Compare commits

...

3 Commits

Author SHA1 Message Date
c210bcdb0b 修复长期记忆bug
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 1m16s
2026-04-20 17:30:39 +08:00
3143e0e4e6 修改引用逻辑,修改长期记忆bug 2026-04-20 15:55:58 +08:00
4e981e9dcf 文件变更 2026-04-20 14:05:57 +08:00
46 changed files with 851 additions and 668 deletions

View File

@@ -37,6 +37,9 @@ VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding 服务 (embeddinggemma-300M GGUF) - 端口 8082
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
# Reranker 服务 (bge-reranker-v2-m3) - 端口 8083
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
# -----------------------------------------------------------------------------
# Mem0 记忆层配置
# -----------------------------------------------------------------------------

View File

@@ -2,7 +2,7 @@
AI Agent 应用模块
"""
from .agent import AIAgentService
from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent import AIAgentService
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]

View File

@@ -1,351 +0,0 @@
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import os
import json
from dotenv import load_dotenv
try:
from langchain_community.chat_models import ChatZhipuAI
HAS_ZHIPUAI = True
except ImportError:
HAS_ZHIPUAI = False
ChatZhipuAI = None
try:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
ChatOpenAI = None
OpenAIEmbeddings = None
from pydantic import SecretStr
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
HAS_POSTGRES_CHECKPOINT = True
except ImportError:
HAS_POSTGRES_CHECKPOINT = False
AsyncPostgresSaver = None
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
try:
from app.rag import RAGPipeline
from app.rag.tools import RAGTool
HAS_RAG = True
except ImportError as e:
HAS_RAG = False
RAGPipeline = None
RAGTool = None
from app.logger import debug, info, warning, error
load_dotenv()
class AIAgentService:
"""异步 AI Agent 服务,支持多模型动态切换,使用外部传入的 checkpointer"""
def __init__(self, checkpointer: AsyncPostgresSaver):
"""
初始化服务
Args:
checkpointer: 已经初始化的 AsyncPostgresSaver 实例
"""
self.checkpointer = checkpointer
self.graphs = {} # 存储不同模型对应的 graph 实例
self.rag = None # RAG 检索实例
self.rag_tool = None # RAG 工具实例
def _create_zhipu_llm(self):
"""创建智谱在线 LLM"""
if not HAS_ZHIPUAI:
raise ImportError("智谱AI支持不可用请安装langchain-community包")
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
return ChatZhipuAI(
model="glm-4.7-flash",
api_key=api_key,
temperature=0.1,
max_tokens=4096,
timeout=120.0, # 增加请求超时时间原为60秒
max_retries=3, # 增加重试次数原为2次
streaming=True, # 确保开启流式输出
)
def _create_deepseek_llm(self):
"""创建 DeepSeek LLM使用 OpenAI 兼容 API"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set in environment")
return ChatOpenAI(
base_url="https://api.deepseek.com",
api_key=SecretStr(api_key),
model="deepseek-reasoner", # deepseek-chat: 非思考模式, deepseek-reasoner: 思考模式
temperature=0.1,
max_tokens=4096,
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
streaming=True, # 确保开启流式输出
)
def _create_local_llm(self):
"""创建本地 vLLM 服务 LLM"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
"http://127.0.0.1:8081/v1"
)
return ChatOpenAI(
base_url=vllm_base_url,
api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
model="gemma-4-E2B-it",
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
streaming=True, # 确保开启流式输出
)
def _create_embeddings(self):
"""创建嵌入模型"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
embedding_url = os.getenv(
"LLAMACPP_EMBEDDING_URL",
"http://127.0.0.1:8082/v1"
)
return OpenAIEmbeddings(
openai_api_base=embedding_url,
openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
model="text-embedding-ada-002", # 模型名称不重要,兼容即可
)
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
# 先初始化 RAG 检索系统
if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
try:
info("🔄 正在初始化 RAG 检索系统...")
embeddings = self._create_embeddings()
self.rag = RAGPipeline(embeddings=embeddings)
self.rag_tool = RAGTool(self.rag).get_tool()
info("✅ RAG 检索系统初始化成功")
except Exception as e:
warning(f"⚠️ RAG 检索系统初始化失败: {e}")
self.rag = None
self.rag_tool = None
else:
info("⏭️ RAG 检索系统不可用,跳过初始化")
self.rag = None
self.rag_tool = None
model_configs = {
"local": self._create_local_llm, # 本地模型作为第一个
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
"zhipu": self._create_zhipu_llm, # GLM-4.7 作为最后一个
}
for model_name, llm_creator in model_configs.items():
try:
info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator()
# 构建工具列表:基础工具 + RAG工具如果可用
tools = AVAILABLE_TOOLS.copy()
tools_by_name = TOOLS_BY_NAME.copy()
if self.rag_tool is not None:
tools.append(self.rag_tool)
tools_by_name[self.rag_tool.name] = self.rag_tool
builder = GraphBuilder(llm, tools, tools_by_name).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
info(f"✅ 模型 '{model_name}' 初始化成功")
except Exception as e:
import traceback
error_detail = traceback.format_exc()
warning(f"⚠️ 模型 '{model_name}' 初始化失败: {e}")
debug(f" 详细错误:\n{error_detail}")
if not self.graphs:
raise RuntimeError("没有可用的模型,请检查配置。可能的原因:\n"
"1. ZHIPUAI_API_KEY 未配置或无效\n"
"2. DEEPSEEK_API_KEY 未配置或无效\n"
"3. vLLM 服务未启动或地址错误 (VLLM_BASE_URL)\n"
"4. 网络连接问题")
return self
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
"""
处理用户消息返回包含回复、token统计和耗时的字典
Returns:
dict: {
"reply": str, # AI 回复内容
"token_usage": dict, # Token 使用详情
"elapsed_time": float # 调用耗时(秒)
}
"""
# 尝试使用指定模型,如果不可用则循环尝试其他模型
if model not in self.graphs:
warning(f"警告: 模型 '{model}' 不可用,尝试切换到其他可用模型")
found = False
for available_model in self.graphs.keys():
try:
# 这里可以添加额外的模型可用性检查逻辑
model = available_model
found = True
info(f"已切换到可用模型: '{model}'")
break
except Exception as e:
warning(f"模型 '{available_model}' 也不可用: {str(e)}")
continue
if not found:
raise RuntimeError(f"错误: 没有任何可用的模型。当前注册的模型: {list(self.graphs.keys())}")
graph = self.graphs[model]
config = {
"configurable": {"thread_id": thread_id},
"metadata": {"user_id": user_id} # 写入 metadata 供历史查询使用
}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
result = await graph.ainvoke(input_state, config=config, context=context)
reply = result["messages"][-1].content
token_usage = result.get("last_token_usage", {})
elapsed_time = result.get("last_elapsed_time", 0.0)
return {
"reply": reply,
"token_usage": token_usage,
"elapsed_time": elapsed_time
}
def _serialize_value(self, value):
"""递归将 LangChain 对象转换为可 JSON 序列化的格式"""
if hasattr(value, 'content'):
# LangChain 消息对象
msg_type = getattr(value, 'type', 'message')
return {
"role": msg_type,
"content": getattr(value, 'content', ''),
"additional_kwargs": getattr(value, 'additional_kwargs', {}),
"tool_calls": getattr(value, 'tool_calls', [])
}
elif isinstance(value, dict):
return {k: self._serialize_value(v) for k, v in value.items()}
elif isinstance(value, (list, tuple)):
return [self._serialize_value(item) for item in value]
else:
try:
json.dumps(value)
return value
except (TypeError, ValueError):
return str(value)
async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
"""
流式处理消息,返回异步生成器
Args:
message: 用户消息
thread_id: 线程 ID
model_name: 模型名称
user_id: 用户 ID
Yields:
字典,包含事件类型和数据
"""
graph = self.graphs.get(model_name)
if not graph:
raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
config = {
"configurable": {"thread_id": thread_id},
"metadata": {"user_id": user_id}
}
input_state = {"messages": [{"role": "user", "content": message}]}
context = GraphContext(user_id=user_id)
async for chunk in graph.astream(
input_state,
config=config,
context=context,
stream_mode=["messages", "updates", "custom"], # 组合多种模式,添加 custom
version="v2", # 使用统一的v2格式
subgraphs=True # 如果你使用了子图,请开启此项
):
chunk_type = chunk["type"]
processed_event = {}
# 1. 处理 LLM Token 流 (实现打字机效果)
if chunk_type == "messages":
message_chunk, metadata = chunk["data"]
# 提取元数据
node_name = metadata.get("langgraph_node", "unknown")
# 使用 getattr 安全地获取内容,因为 message_chunk 可能不是字符串
token_content = getattr(message_chunk, 'content', str(message_chunk))
# 提取 DeepSeek reasoner 的思考过程 token
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
# [DEBUG] 临时添加:只在 reasoning_token 不为空时打印,方便你直观地看到它
if reasoning_token:
import logging
logging.debug(f"💡 [Reasoning Token 捕获]: {repr(reasoning_token)}")
processed_event = {
"type": "llm_token",
"node": node_name,
"token": token_content,
"reasoning_token": reasoning_token,
"metadata": metadata # 可选的元数据
}
# 2. 处理状态更新 (节点执行完成)
elif chunk_type == "updates":
updates_data = chunk["data"]
# 序列化 updates 中的所有数据
serialized_data = self._serialize_value(updates_data)
processed_event = {
"type": "state_update",
"data": serialized_data
}
# 为了兼容前端旧字段,也保留 messages 字段(可选)
if "messages" in serialized_data:
processed_event["messages"] = serialized_data["messages"]
# 3. 处理自定义数据 (如果需要)
elif chunk_type == "custom":
# 自定义事件同样需要序列化
serialized_data = self._serialize_value(chunk["data"])
processed_event = {
"type": "custom",
"data": serialized_data
}
# 4. 其他类型debug, tasks等按需处理
else:
# 对于不需要的类型,直接跳过
continue
# 确保事件有数据再发送
if processed_event:
yield processed_event

7
app/agent/__init__.py Normal file
View File

@@ -0,0 +1,7 @@
"""
Agent 子模块
"""
from app.agent.service import AIAgentService
__all__ = ["AIAgentService"]

View File

@@ -3,11 +3,9 @@
利用 LangGraph checkpointer 获取对话历史和摘要
"""
from typing import List, Dict, Any, Optional
import logging
from typing import List, Dict, Any
from app.logger import error # 保持兼容,或者替换为 logger
class ThreadHistoryService:
"""线程历史查询服务"""

56
app/agent/llm_factory.py Normal file
View File

@@ -0,0 +1,56 @@
# app/llm_factory.py
import os
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
class LLMFactory:
    """Factory for the chat models supported by the agent service.

    Each ``create_*`` method builds a fresh, streaming-enabled chat model.
    ``CREATORS`` maps the public model name to its creator callable and is
    consumed by ``AIAgentService.initialize``.
    """

    @staticmethod
    def create_zhipu():
        """Create the Zhipu online LLM (glm-4.7-flash).

        Raises:
            ValueError: if ZHIPUAI_API_KEY is not set in the environment.
        """
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,   # generous timeout for the remote API
            max_retries=3,
            streaming=True,
        )

    @staticmethod
    def create_deepseek():
        """Create the DeepSeek LLM via its OpenAI-compatible API.

        Raises:
            ValueError: if DEEPSEEK_API_KEY is not set in the environment.
        """
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DEEPSEEK_API_KEY not set")
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(api_key),
            model="deepseek-reasoner",  # reasoning-mode model
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )

    @staticmethod
    def create_local():
        """Create the local vLLM-served model (OpenAI-compatible endpoint)."""
        base_url = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
        return ChatOpenAI(
            base_url=base_url,
            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
            model="gemma-4-E4B-it",
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )

    # Model-name -> creator mapping.
    # `.__func__` unwraps the staticmethod descriptor so the stored values are
    # plain callables on every Python version (bare staticmethod objects are
    # only directly callable since Python 3.10).
    CREATORS = {
        "local": create_local.__func__,
        "deepseek": create_deepseek.__func__,
        "zhipu": create_zhipu.__func__,
    }

View File

@@ -1,18 +1,20 @@
"""
提示模板管理模块
所有系统提示和对话模板统一定义
"""
# app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
"""
创建系统提示模板可选择动态注入工具描述
"""
tools_section = ""
if tools:
tool_descs = []
for tool in tools:
# 提取工具名称和描述的第一行
name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
desc = (tool.description or "").split('\n')[0]
tool_descs.append(f"- {name}: {desc}")
tools_section = "\n".join(tool_descs)
def create_system_prompt() -> ChatPromptTemplate:
"""
创建系统提示模板
Returns:
ChatPromptTemplate: 包含系统指令和消息占位符的提示模板
"""
system_template = (
"你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
"【用户背景信息】\n"
@@ -20,15 +22,11 @@ def create_system_prompt() -> ChatPromptTemplate:
"{memory_context}\n"
"若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
"【可用工具与使用规则】\n"
"- 获取温度/天气:`get_current_temperature`\n"
"- 读取文本文件:`read_local_file`(限定目录 `./user_docs`\n"
"- 读取PDF摘要`read_pdf_summary`(限定目录 `./user_docs`\n"
"- 读取Excel表格`read_excel_as_markdown`(限定目录 `./user_docs`\n"
"- 抓取网页内容:`fetch_webpage_content`\n"
f"{tools_section}\n"
"工具调用时请直接返回所需参数,无需额外说明。\n\n"
"【回答要求(必须遵守)】\n"
"1. 回答必须简洁、直接。\n"
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。例如:<think>这里是我的思考过程...</think>这里是最终回答。\n"
"2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
"3. 优先利用已知用户信息进行个性化回复。\n"
"4. 若无信息可依,礼貌询问或提供通用帮助。"
)
@@ -36,4 +34,4 @@ def create_system_prompt() -> ChatPromptTemplate:
return ChatPromptTemplate.from_messages([
("system", system_template),
MessagesPlaceholder(variable_name="messages")
])
])

View File

@@ -0,0 +1,23 @@
# app/rag_initializer.py
from app.rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from app.logger import info, warning
async def init_rag_tool(local_llm_creator):
    """Build the RAG retrieval tool; return None if anything fails.

    Args:
        local_llm_creator: zero-arg callable producing the LLM used for
            query rewriting.
    """
    try:
        info("🔄 正在初始化 RAG 检索系统...")
        parent_retriever = create_parent_retriever(
            collection_name="rag_documents",
            search_k=5,
        )
        tool = create_rag_tool_sync(
            parent_retriever,
            local_llm_creator(),
            num_queries=3,
            rerank_top_n=5,
        )
    except Exception as exc:
        warning(f"⚠️ RAG 检索工具初始化失败: {exc}")
        return None
    info("✅ RAG 检索工具初始化成功")
    return tool

156
app/agent/service.py Normal file
View File

@@ -0,0 +1,156 @@
"""
AI Agent 服务类 - 支持多模型动态切换
接收外部传入的 checkpointer不负责管理连接生命周期
"""
import json
from dotenv import load_dotenv
# 本地模块
from app.graph.graph_builder import GraphBuilder, GraphContext
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent.llm_factory import LLMFactory
from app.agent.rag_initializer import init_rag_tool
from app.logger import info, warning
load_dotenv()
class AIAgentService:
    """Async AI Agent service with dynamic multi-model switching.

    The checkpointer is created externally and injected; this class never
    manages its connection lifecycle. ``initialize`` pre-compiles one
    LangGraph per configured model.
    """

    def __init__(self, checkpointer):
        """
        Args:
            checkpointer: already-initialized saver (e.g. AsyncPostgresSaver),
                owned and closed by the caller.
        """
        self.checkpointer = checkpointer
        self.graphs = {}  # model name -> compiled graph
        # Copy the shared registries so the RAG tool can be appended without
        # mutating the module-level AVAILABLE_TOOLS / TOOLS_BY_NAME constants.
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()

    async def initialize(self):
        """Pre-compile a graph for every configured model; returns self.

        Raises:
            RuntimeError: if no model could be initialized at all.
        """
        # 1. Initialize the RAG tool (optional; init_rag_tool returns None on failure).
        rag_tool = await init_rag_tool(LLMFactory.create_local)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool
        # 2. Build one graph per model; a single model failing must not
        #    prevent the remaining models from loading.
        for name, creator in LLMFactory.CREATORS.items():
            try:
                info(f"🔄 初始化模型 '{name}'...")
                llm = creator()
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[name] = graph
                info(f"✅ 模型 '{name}' 初始化成功")
            except Exception as e:
                warning(f"⚠️ 模型 '{name}' 初始化失败: {e}")
        if not self.graphs:
            raise RuntimeError("没有可用的模型")
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Process one user message.

        Returns:
            dict with keys "reply" (str), "token_usage" (dict) and
            "elapsed_time" (float, seconds).

        Raises:
            RuntimeError: if no model is available at all.
        """
        if model not in self.graphs:
            # Fall back to the first available model.
            available = list(self.graphs.keys())
            if not available:
                raise RuntimeError("没有可用的模型")
            requested = model
            model = available[0]
            # BUG FIX: log the originally requested model name; previously
            # `model` was reassigned first, so the fallback name was printed twice.
            warning(f"模型 '{requested}' 不可用,已回退到 '{model}'")
        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}  # consumed by history queries
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        result = await graph.ainvoke(input_state, config=config, context=context)
        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)
        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable data.

        Message-like objects (anything with a ``content`` attribute) become
        dicts; containers are walked recursively; anything else is kept if
        json-serializable, otherwise stringified.
        """
        if hasattr(value, 'content'):
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
                "tool_calls": getattr(value, 'tool_calls', [])
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                return str(value)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Stream a message through the graph, yielding processed event dicts.

        Yields dicts with a "type" of "llm_token", "state_update" or "custom".

        Raises:
            ValueError: if ``model_name`` is unknown or not initialized.
        """
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"模型 '{model_name}' 未找到或未初始化")
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)
        # NOTE(review): with a list stream_mode, recent LangGraph versions yield
        # (mode, data) tuples rather than dicts, and `version`/`subgraphs` are
        # astream_events-style kwargs — confirm against the pinned langgraph
        # version before changing.
        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True
        ):
            chunk_type = chunk["type"]
            processed_event = {}
            # 1. LLM token stream (typewriter effect on the frontend).
            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                node_name = metadata.get("langgraph_node", "unknown")
                # message_chunk may not be a plain string; read content defensively.
                token_content = getattr(message_chunk, 'content', str(message_chunk))
                # DeepSeek reasoner exposes its chain-of-thought tokens here.
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata
                }
            # 2. State updates (a node finished executing).
            elif chunk_type == "updates":
                updates_data = chunk["data"]
                serialized_data = self._serialize_value(updates_data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data
                }
                # Keep the legacy top-level "messages" field for older frontends.
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]
            # 3. Custom events emitted via get_stream_writer().
            elif chunk_type == "custom":
                serialized_data = self._serialize_value(chunk["data"])
                processed_event = {
                    "type": "custom",
                    "data": serialized_data
                }
            # 4. Any other chunk type (debug, tasks, ...) is skipped.
            else:
                continue
            if processed_event:
                yield processed_event

View File

@@ -15,8 +15,8 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.history import ThreadHistoryService
from app.logger import debug, info, warning, error
from app.agent.history import ThreadHistoryService
from app.logger import info, error
# 加载 .env 文件
load_dotenv()
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理:创建并注入全局服务"""
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
# 5. 关闭时自动清理数据库连接async with 负责)
info("🛑 应用关闭,数据库连接池已释放")
app = FastAPI(lifespan=lifespan)
# CORS 中间件(允许前端跨域)
@@ -65,14 +63,12 @@ app.add_middleware(
allow_headers=["*"],
)
# ========== 健康检查端点 ==========
@app.get("/health")
async def health_check():
"""健康检查端点,用于 Docker 和 CI/CD 监控"""
return {"status": "ok", "service": "ai-agent-backend"}
# ========== Pydantic 模型 ==========
class ChatRequest(BaseModel):
message: str
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
model: str = "zhipu"
user_id: str = "default_user"
class ChatResponse(BaseModel):
reply: str
thread_id: str
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
total_tokens: int = 0
elapsed_time: float = 0.0
# ========== 依赖注入函数 ==========
def get_agent_service(request: Request) -> AIAgentService:
"""从 app.state 中获取全局 AIAgentService 实例"""
return request.app.state.agent_service
def get_history_service(request: Request) -> ThreadHistoryService:
"""从 app.state 中获取全局 ThreadHistoryService 实例"""
return request.app.state.history_service
# ========== HTTP 端点 ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
@@ -135,7 +127,6 @@ async def chat_endpoint(
elapsed_time=elapsed_time
)
# ========== 历史查询接口 ==========
@app.get("/threads")
async def list_threads(
@@ -147,7 +138,6 @@ async def list_threads(
threads = await history_service.get_user_threads(user_id, limit)
return {"threads": threads}
@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
thread_id: str,
@@ -158,7 +148,6 @@ async def get_thread_messages(
messages = await history_service.get_thread_messages(thread_id)
return {"messages": messages}
@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
thread_id: str,
@@ -169,7 +158,6 @@ async def get_thread_summary(
summary = await history_service.get_thread_summary(thread_id)
return summary
# ========== 流式对话接口 ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
}
)
# ========== WebSocket 端点(可选) ==========
@app.websocket("/ws")
async def websocket_endpoint(
@@ -228,9 +215,8 @@ async def websocket_endpoint(
except WebSocketDisconnect:
pass
if __name__ == "__main__":
import uvicorn
# 使用环境变量或默认端口 8083(避免与 llama.cpp 的 8081 端口冲突)
port = int(os.getenv("BACKEND_PORT", "8083"))
# 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突)
port = int(os.getenv("BACKEND_PORT", "8079"))
uvicorn.run(app, host="0.0.0.0", port=port)

View File

@@ -18,6 +18,12 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
# ========== llm 配置 ==========
# LLM 模型配置
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")

8
app/graph/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Graph 子模块
"""
from app.graph.graph_builder import GraphBuilder
from app.graph.state import MessagesState, GraphContext
__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]

View File

@@ -5,18 +5,17 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图
from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END
# 本地模块
from app.state import MessagesState, GraphContext
from app.graph.state import MessagesState, GraphContext
from app.nodes import (
should_continue,
create_llm_call_node,
create_tool_call_node,
create_retrieve_memory_node,
create_summarize_node,
should_continue
finalize_node,
)
from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client
from app.memory import Mem0Client
from app.nodes.finalize import finalize_node
class GraphBuilder:
@@ -45,6 +44,9 @@ class GraphBuilder:
Returns:
StateGraph 实例
"""
# 注入全局客户端
set_mem0_client(self.mem0_client)
builder = StateGraph(MessagesState, context_schema=GraphContext)
# ⭐ 通过工厂函数创建节点(依赖注入)
@@ -55,6 +57,7 @@ class GraphBuilder:
# 添加节点
builder.add_node("retrieve_memory", retrieve_memory_node)
builder.add_node("memory_trigger", memory_trigger_node)
builder.add_node("llm_call", llm_call_node)
builder.add_node("tool_node", tool_call_node)
builder.add_node("summarize", summarize_node)
@@ -62,7 +65,8 @@ class GraphBuilder:
# 添加边
builder.add_edge(START, "retrieve_memory")
builder.add_edge("retrieve_memory", "llm_call")
builder.add_edge("retrieve_memory", "memory_trigger")
builder.add_edge("memory_trigger", "llm_call")
builder.add_conditional_edges(
"llm_call",
should_continue,

View File

@@ -3,7 +3,6 @@
"""
# 标准库
import os
from pathlib import Path
# 第三方库
@@ -13,7 +12,6 @@ import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool
def _file_allow_check(filename: str) -> Path:
"""检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。"""
allowed_dir = Path("./user_docs").resolve()
@@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path:
return file_path
@tool
def get_current_temperature(location: str) -> str:
"""获取指定地点的当前温度。"""
return f'当前{location}的温度为25℃'
@tool
def read_local_file(filename: str) -> str:
"""读取用户指定名称的本地文本文件内容并返回摘要。"""
@@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str:
except Exception as e:
return f"读取文件时出错:{str(e)}"
@tool
def read_pdf_summary(filename: str) -> str:
"""读取PDF文件并返回内容文本摘要。"""
@@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str:
except Exception as e:
return f"读取PDF出错{e}"
@tool
def read_excel_as_markdown(filename: str) -> str:
"""读取Excel文件并将其主要数据转换为Markdown表格格式。"""
@@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str:
except Exception as e:
return f"读取Excel出错{e}"
@tool
def fetch_webpage_content(url: str) -> str:
"""抓取给定URL的网页正文内容并返回清晰的纯文本。"""
@@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str:
except Exception as e:
return f"抓取网页时出错:{str(e)}"
# 工具列表和映射(全局常量)
AVAILABLE_TOOLS = [
get_current_temperature,

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug
def create_retrieve_memory_node(mem0_client: Mem0Client):
"""
工厂函数创建记忆检索节点

View File

@@ -4,12 +4,11 @@ LangGraph 状态定义模块
"""
import operator
from typing import Annotated, Any
from typing import Annotated
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage
class MessagesState(TypedDict):
"""对话状态类型定义"""
messages: Annotated[list[AnyMessage], operator.add]
@@ -19,7 +18,6 @@ class MessagesState(TypedDict):
last_elapsed_time: float # 本次调用耗时(秒)
turns_since_last_summary: int # 距离上次生成摘要的轮数
@dataclass
class GraphContext:
"""图执行上下文"""

View File

@@ -1,16 +1,22 @@
from app.config import LLM_API_KEY
from app.config import VLLM_BASE_URL
import time
"""
Mem0 记忆层客户端封装模块
负责 Mem0 的初始化、检索和存储
"""
import asyncio
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict
from mem0 import AsyncMemory
from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
from app.config import (
QDRANT_URL,QDRANT_COLLECTION_NAME,QDRANT_API_KEY,
VLLM_BASE_URL, LLM_API_KEY,
LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
)
from app.logger import info, warning, error
class Mem0Client:
"""Mem0 异步客户端封装类"""
@@ -37,20 +43,25 @@ class Mem0Client:
"provider": "qdrant",
"config": {
"url": QDRANT_URL, # 直接使用完整 URL
"api_key": QDRANT_API_KEY,
"collection_name": QDRANT_COLLECTION_NAME,
"embedding_model_dims": 768,
"embedding_model_dims": 1024,
}
},
"llm": {
"provider": "langchain",
"provider": "openai",
"config": {
"model": self.llm
"model": "LLM_MODEL",
"api_key": LLM_API_KEY,
"openai_base_url": VLLM_BASE_URL,
"temperature": 0.1,
"max_tokens": 2000,
}
},
"embedder": {
"provider": "openai",
"config": {
"model": "embeddinggemma-300M-Q8_0",
"model": "Qwen3-Embedding-0.6B-Q8_0",
"api_key": LLAMACPP_API_KEY,
"openai_base_url": LLAMACPP_EMBEDDING_URL,
},
@@ -118,36 +129,18 @@ class Mem0Client:
warning(f"⚠️ Mem0 检索失败: {e}")
return []
async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool:
"""
添加记忆(自动提取事实并存储)
Args:
messages: 消息列表,格式为 [{"role": "user/assistant/system", "content": "..."}]
user_id: 用户 ID
Returns:
bool: 是否成功
"""
if not self.mem0:
warning("⚠️ Mem0 未初始化,跳过记忆添加")
return False
try:
await asyncio.wait_for(
self.mem0.add(
messages,
user_id=user_id,
metadata={"type": "conversation"}
),
timeout=60.0
)
info("📝 [记忆添加] 已提交给 Mem0 进行事实提取")
return True
except asyncio.TimeoutError:
error("❌ Mem0 记忆添加超时 (60s)")
return False
except Exception as e:
error(f"❌ Mem0 记忆添加失败: {e}")
return False
async def add_memories(self, messages, user_id):
    """Submit messages to Mem0 for fact extraction and storage.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.
        user_id: owner of the extracted memories.

    Returns:
        bool: True on success, False if Mem0 is unavailable or the add failed.
    """
    if not self.mem0:
        return False
    # Hoisted out of try so it is always defined in the handlers below.
    start = time.time()
    try:
        info(f"📝 开始 Mem0 add消息数: {len(messages)}")
        await asyncio.wait_for(
            self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
            timeout=60.0
        )
        info(f"✅ Mem0 add 完成,耗时: {time.time() - start:.2f}s")
        return True
    except asyncio.TimeoutError:
        error(f"❌ Mem0 记忆添加超时 (60s),已等待 {time.time() - start:.2f}s")
        return False
    except Exception as e:
        # Restored from the previous implementation: without this handler any
        # non-timeout Mem0 failure propagates and crashes the caller instead
        # of being logged as a best-effort miss.
        error(f"❌ Mem0 记忆添加失败: {e}")
        return False

View File

@@ -5,7 +5,7 @@
from app.nodes.router import should_continue
from app.nodes.llm_call import create_llm_call_node
from app.nodes.tool_call import create_tool_call_node
from app.nodes.retrieve_memory import create_retrieve_memory_node
from app.graph.retrieve_memory import create_retrieve_memory_node
from app.nodes.summarize import create_summarize_node
from app.nodes.finalize import finalize_node

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# 本地模块
from app.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.utils.logging import log_state_change
from app.logger import info, error
from langchain_core.runnables.config import RunnableConfig
async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:

View File

@@ -3,21 +3,17 @@ LLM 调用节点模块
负责调用大语言模型并处理响应
"""
import asyncio
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.prompts import create_system_prompt
from app.utils.logging import log_state_change, print_llm_input
from app.graph.state import MessagesState
from app.agent.prompts import create_system_prompt
from app.utils.logging import log_state_change
from app.logger import debug, info, error
def create_llm_call_node(llm: BaseLLM, tools: list):
"""
工厂函数:创建 LLM 调用节点
@@ -30,7 +26,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
异步节点函数
"""
# 构建调用链
prompt = create_system_prompt()
prompt = create_system_prompt(tools)
llm_with_tools = llm.bind_tools(tools)
# 恢复带 RunnableLambda 的链,并在下方使用 astream 遍历

View File

@@ -0,0 +1,38 @@
from typing import Any, Dict, Optional
from langchain_core.runnables.config import RunnableConfig
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.logger import info

# Module-level client, injected by GraphBuilder via set_mem0_client().
# Optional because the node must degrade to a no-op before injection.
_mem0_client: Optional[Mem0Client] = None


def set_mem0_client(client: Mem0Client):
    """Inject the shared Mem0 client used by memory_trigger_node."""
    global _mem0_client
    _mem0_client = client


async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """Detect explicit "remember this" commands in the latest user message
    and, if found, proactively store the message via Mem0.

    Always returns an empty dict: this node never mutates graph state.
    """
    if _mem0_client is None:
        return {}
    messages = state.get("messages", [])
    if not messages:
        return {}
    last_msg = messages[-1]
    content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg)
    # Trigger words (extend as needed).
    trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
    if any(word in content for word in trigger_words):
        # Guard against metadata being explicitly set to None in the config.
        user_id = (config.get("metadata") or {}).get("user_id", "default_user")
        # Lazily initialize Mem0 on first use. getattr-guard: _initialized is a
        # private flag of Mem0Client — TODO confirm it stays the canonical check.
        if not getattr(_mem0_client, "_initialized", False):
            await _mem0_client.initialize()
        # Plain string (was an f-string with no placeholders).
        info("📌 检测到记忆指令,已主动触发 Mem0 存储")
        mem0_messages = [{"role": "user", "content": content}]
        await _mem0_client.add_memories(mem0_messages, user_id=user_id)
    return {}

View File

@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage
# 本地模块
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
from app.state import MessagesState
from app.graph.state import MessagesState
from app.logger import info

View File

@@ -4,15 +4,13 @@
"""
from typing import Any, Dict
from langgraph.runtime import Runtime
# 本地模块
from app.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning
def create_summarize_node(mem0_client: Mem0Client):
"""
工厂函数:创建记忆存储节点

View File

@@ -6,15 +6,13 @@
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer
# 本地模块
from app.state import MessagesState, GraphContext
from app.graph.state import MessagesState
from app.utils.logging import log_state_change
from app.logger import debug, info
def create_tool_call_node(tools_by_name: Dict[str, Any]):
"""
工厂函数:创建工具执行节点

View File

@@ -13,7 +13,7 @@ RAG 检索与生成模块
用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档
示例用法:
>>> from app.rag import RAGPipeline, create_rag_tool
>>> from app.rag.rag import RAGPipeline, create_rag_tool
>>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
>>> from langchain_openai import ChatOpenAI
>>>
@@ -34,16 +34,16 @@ RAG 检索与生成模块
>>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
"""
from .retriever import (
from app.rag.retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import CrossEncoderReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from .pipeline import RAGPipeline
from .tools import create_rag_tool, create_rag_tool_sync
from app.rag.reranker import LLaMaCPPReranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
from app.rag.pipeline import RAGPipeline
from app.rag.tools import create_rag_tool_sync
__all__ = [
@@ -53,7 +53,7 @@ __all__ = [
"create_qdrant_client",
# 重排序器
"CrossEncoderReranker",
"LLaMaCPPReranker",
# 查询改写生成器
"MultiQueryGenerator",
@@ -65,6 +65,5 @@ __all__ = [
"RAGPipeline",
# 工具创建(供 Agent 使用)
"create_rag_tool",
"create_rag_tool_sync",
]

View File

@@ -1,6 +1,6 @@
# rag/fusion.py
from typing import List, Dict, Tuple
from typing import List, Dict
from langchain_core.documents import Document
def reciprocal_rank_fusion(

View File

@@ -1,15 +1,14 @@
# rag/pipeline.py
import asyncio
from typing import List, Optional
import os
from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from .retriever import create_qdrant_client # 可能不需要直接使用
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
from .fusion import reciprocal_rank_fusion
from app.rag.reranker import LLaMaCPPReranker
from app.rag.query_transform import MultiQueryGenerator
from app.rag.fusion import reciprocal_rank_fusion
class RAGPipeline:
"""
@@ -23,7 +22,6 @@ class RAGPipeline:
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
rerank_model: str = "BAAI/bge-reranker-base",
):
"""
Args:
@@ -41,9 +39,9 @@ class RAGPipeline:
# 初始化组件
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
self.reranker = LLaMaCPPReranker(
base_url="http://127.0.0.1:8083",
base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"),
api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"),
top_n=rerank_top_n,
api_key="huang1998"
)
async def aretrieve(self, query: str) -> List[Document]:
@@ -68,9 +66,9 @@ class RAGPipeline:
fused_docs = reciprocal_rank_fusion(doc_lists)
# Step 4: 重排序
if self.reranker.model is not None:
try:
final_docs = self.reranker.compress_documents(fused_docs, query)
else:
except Exception:
# 若重排序器不可用,直接返回融合后的前 N 条
final_docs = fused_docs[:self.rerank_top_n]

View File

@@ -10,24 +10,24 @@ class LLaMaCPPReranker:
"""使用远程 llama.cpp 服务对检索结果重排序。"""
def __init__(self,
base_url: str = "http://127.0.0.1:8083",
base_url: str,
api_key: str,
top_n: int = 5,
api_key: str = "huang1998", # 你设置的 LLAMA_ARG_API_KEY
timeout: int = 60):
"""
初始化远程重排序器
Args:
base_url: llama.cpp 服务的地址和端口。
base_url: llama.cpp 服务的地址和端口,默认为环境变量 LLAMACPP_RERANKER_URL 或 "http://127.0.0.1:8083"
top_n: 返回前 N 个结果。
api_key: 在容器中设置的 API 密钥。
api_key: API 密钥,默认为环境变量 LLAMACPP_API_KEY 或 "huang1998"
timeout: 请求超时时间(秒)。
"""
self.base_url = base_url.rstrip('/')
self.base_url = base_url
self.api_key = api_key
self.top_n = top_n
self.api_key = api_key
self.timeout = timeout
self.endpoint = f"{self.base_url}/v1/rerank"
self.endpoint = f"{self.base_url}/rerank"
def compress_documents(
self, documents: List[Document], query: str

View File

@@ -11,7 +11,6 @@ RAG 系统使用示例(重构版)
import asyncio
import sys
import os
from pathlib import Path
from dotenv import load_dotenv
@@ -19,12 +18,12 @@ from dotenv import load_dotenv
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from rag.pipeline import RAGPipeline
from rag.tools import create_rag_tool
from app.rag.pipeline import RAGPipeline
from app.rag.tools import create_rag_tool_sync
from pydantic import SecretStr
# 使用本地 LLM通过 OpenAI 兼容接口)
from langchain_openai import ChatOpenAI
@@ -32,7 +31,6 @@ from rag_core.retriever_factory import create_parent_retriever
load_dotenv()
def create_llm():
"""创建本地 vLLM 服务 LLM"""
vllm_base_url = os.getenv(
@@ -60,8 +58,7 @@ async def demonstrate_full_pipeline():
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
print("=" * 60)
retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
if retriever is None:
print("错误:检索器未初始化,请确保索引已构建。")
@@ -103,7 +100,6 @@ async def demonstrate_full_pipeline():
import traceback
traceback.print_exc()
async def demonstrate_tool_creation():
"""
演示创建 RAG 工具(供 Agent 使用)
@@ -119,12 +115,11 @@ async def demonstrate_tool_creation():
)
retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM
llm = create_llm()
# 3. 创建工具
rag_tool = create_rag_tool(
rag_tool = create_rag_tool_sync(
retriever=retriever,
llm=llm,
num_queries=3,
@@ -136,18 +131,16 @@ async def demonstrate_tool_creation():
print(f"工具描述: {rag_tool.description[:100]}...")
# 4. 模拟 Agent 调用工具
query = "请告诉我 RAG 系统的核心组件有哪些"
query = "请告诉我 打虎英雄是谁"
print(f"\n模拟调用: {query}")
print("-" * 40)
result = await rag_tool.ainvoke({"query": query})
print(result[:800] + "..." if len(result) > 800 else result)
async def main():
await demonstrate_full_pipeline()
await demonstrate_tool_creation()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -4,73 +4,11 @@ RAG 工具模块
将检索功能封装为 LangChain Tool供 Agent 调用。
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
"""
from typing import Optional, Callable
from typing import Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from .pipeline import RAGPipeline
def create_rag_tool(
retriever: BaseRetriever,
llm: BaseLanguageModel,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(异步)。
Args:
retriever: 基础检索器(例如 ParentDocumentRetriever 实例)
llm: 用于多路查询改写的语言模型
num_queries: 生成查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: 集合名称(仅用于日志/描述)
Returns:
LangChain Tool 可调用对象(异步)
"""
# 初始化流水线(所有组件一次创建,后续复用)
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
)
@tool
async def search_knowledge_base(query: str) -> str:
"""在知识库中搜索与查询相关的文档片段。
该工具会:
1. 将用户问题改写成多个不同角度的查询
2. 并行检索每个查询的相关父文档
3. 使用倒数排名融合RRF合并结果
4. 用 Cross-Encoder 重排序模型精选最相关的片段
适用于需要精确、全面答案的事实性问题或背景知识查询。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容,若无结果则返回提示信息。
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base
from app.rag.pipeline import RAGPipeline
def create_rag_tool_sync(
retriever: BaseRetriever,

307
app/test_backend.py Normal file
View File

@@ -0,0 +1,307 @@
#!/usr/bin/env python3
"""
完整后端测试 - 验证 Agent 所有功能
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
"""
import asyncio
import os
import sys
import uuid
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
load_dotenv()
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
from app.agent.history import ThreadHistoryService
from app.logger import info, warning, error
# PostgreSQL 连接字符串
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
)
async def print_section(title):
    """Print a banner-style header for a test section."""
    bar = "=" * 70
    print(f"\n{bar}")
    print(f" {title}")
    print(bar)
async def test_short_term_memory(agent_service):
    """Short-term memory check: two turns on the same thread_id.

    Introduce a name/age first, then ask for them back and verify the
    reply still carries the information.
    """
    await print_section("测试 1: 短期记忆Short-term Memory")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_memory"
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    # Turn 1: provide facts to remember
    print("\n[第一轮] 发送消息: '我叫张三今年28岁'")
    first = await agent_service.process_message(
        "我叫张三今年28岁", thread_id, "local", user_id
    )
    print(f"回复: {first['reply'][:100]}...")
    # Turn 2: ask the facts back on the same thread
    print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
    second = await agent_service.process_message(
        "我叫什么名字?今年多大?", thread_id, "local", user_id
    )
    print(f"回复: {second['reply']}")
    # Pass if either fact surfaces in the second reply
    remembered = any(mark in second["reply"] for mark in ("张三", "28"))
    print("\n✅ 短期记忆测试通过!" if remembered else "\n❌ 短期记忆测试失败!")
    return remembered
async def test_tool_calling(agent_service):
    """Tool-calling check: a question that should trigger RAG search."""
    await print_section("测试 2: 工具调用Tool Calling")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_tools"
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    # Ask a question that requires knowledge-base retrieval
    print("\n发送消息: '请告诉我,打虎英雄是谁?'")
    result = await agent_service.process_message(
        "请告诉我,打虎英雄是谁?", thread_id, "local", user_id
    )
    print(f"回复: {result['reply'][:200]}...")
    # A RAG-backed answer should mention Water-Margin related names
    reply = result["reply"]
    if any(hint in reply for hint in ("武松", "李忠", "水浒传")):
        print("\n✅ 工具调用测试通过!")
        return True
    print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
    return None
async def test_streaming(agent_service):
    """Streaming check: consume the token stream and count chunks."""
    await print_section("测试 3: 流式对话Streaming")
    thread_id = str(uuid.uuid4())
    user_id = "test_user_stream"
    print(f"\n使用 thread_id: {thread_id[:8]}...")
    print(f"使用 user_id: {user_id}")
    print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
    print("流式输出: ", end="", flush=True)
    tokens = []
    chunk_count = 0
    try:
        stream = agent_service.process_message_stream(
            "用100字介绍一下AI人工智能", thread_id, "local", user_id
        )
        async for chunk in stream:
            chunk_count += 1
            kind = chunk.get("type")
            if kind == "llm_token":
                piece = chunk.get("token", "")
                print(piece, end="", flush=True)
                tokens.append(piece)
            elif kind == "state_update":
                pass  # state updates are not displayed
        full_reply = "".join(tokens)
        print(f"\n\n共收到 {chunk_count} 个 chunk")
        print(f"完整回复长度: {len(full_reply)}")
        if chunk_count > 0 and len(full_reply) > 10:
            print("\n✅ 流式对话测试通过!")
            return True
        print("\n❌ 流式对话测试失败!")
        return False
    except Exception as e:
        print(f"\n❌ 流式对话异常: {e}")
        return False
async def test_history_service(agent_service, history_service):
    """History-service check: thread list, message history, summary."""
    await print_section("测试 4: 历史查询服务History Service")
    user_id = "test_user_history"
    # Seed a few conversations first
    print(f"\n为 user_id={user_id} 创建测试对话...")
    thread_ids = []
    for seq in range(1, 4):
        tid = str(uuid.uuid4())
        thread_ids.append(tid)
        await agent_service.process_message(
            f"这是第 {seq} 个测试对话", tid, "local", user_id
        )
        print(f" 创建线程 {seq}: {tid[:8]}...")
    # 1. Thread list for the user
    print("\n[4.1] 测试获取用户线程列表...")
    threads = await history_service.get_user_threads(user_id, limit=10)
    print(f" 找到 {len(threads)} 个线程")
    print(" ✅ 线程列表查询通过" if len(threads) >= 3 else " ⚠️ 线程数量少于预期")
    # 2. Message history of a single thread
    if thread_ids:
        first_tid = thread_ids[0]
        print(f"\n[4.2] 测试获取线程消息历史 (thread_id={first_tid[:8]}...)")
        messages = await history_service.get_thread_messages(first_tid)
        print(f" 找到 {len(messages)} 条消息")
        # Expect at least one question/answer pair
        print(" ✅ 消息历史查询通过" if len(messages) >= 2 else " ⚠️ 消息数量少于预期")
        # 3. Thread summary
        print("\n[4.3] 测试获取线程摘要...")
        summary = await history_service.get_thread_summary(first_tid)
        print(f" 摘要: {summary.get('summary', '')[:50]}...")
        print(f" 消息数: {summary.get('message_count', 0)}")
        if summary.get('message_count', 0) > 0:
            print(" ✅ 线程摘要查询通过")
        else:
            print(" ⚠️ 摘要查询结果不确定")
    return len(threads) >= 3
async def test_long_term_memory(agent_service):
    """Long-term memory (mem0) check across two different threads.

    Store a fact on thread 1, then ask for it on thread 2 (same user_id,
    different thread_id) — recall proves memory is user-scoped, not
    thread-scoped.

    Returns:
        True when the fact is recalled, None when the result is
        inconclusive (mem0 may be disabled).
    """
    await print_section("测试 5: 长期记忆Long-term Memory - mem0")
    thread_id1 = str(uuid.uuid4())
    thread_id2 = str(uuid.uuid4())  # a different thread on purpose
    user_id = "test_user_longterm"
    print(f"\n使用 user_id: {user_id}")
    print(f"线程 1: {thread_id1[:8]}...")
    print(f"线程 2: {thread_id2[:8]}...")
    # Store the fact on the first thread
    print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
    result1 = await agent_service.process_message(
        "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
    )
    print(f"回复: {result1['reply'][:100]}...")
    # Give mem0 a moment to persist
    await asyncio.sleep(1)
    # Ask on the second thread (different thread_id)
    print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
    result2 = await agent_service.process_message(
        "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
    )
    print(f"回复: {result2['reply']}")
    # Verify recall. BUG FIX: the original checked `"" in reply`, which is
    # always True (the empty string is a substring of every string), so the
    # test could never fail; check the intended marker "猫" instead.
    if "小白" in result2['reply'] or "猫" in result2['reply']:
        print("\n✅ 长期记忆测试通过!")
        return True
    print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
    return None
async def main():
    """Run every backend test in sequence and print a summary.

    Returns 0 when no test failed outright, otherwise 1.
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print(" 后端完整功能测试")
    print(banner)
    results = {}
    try:
        # Open the checkpoint database and build the services
        print("\n正在初始化数据库连接...")
        async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            await checkpointer.setup()
            print("✅ 数据库连接成功")
            print("\n正在初始化 Agent 服务...")
            agent_service = AIAgentService(checkpointer)
            await agent_service.initialize()
            print("✅ Agent 服务初始化成功")
            history_service = ThreadHistoryService(checkpointer)
            print("✅ 历史服务初始化成功")
            print(f"\n可用模型: {list(agent_service.graphs.keys())}")
            # Run the tests one by one, pausing briefly between them
            results["短期记忆"] = await test_short_term_memory(agent_service)
            await asyncio.sleep(1)
            results["工具调用"] = await test_tool_calling(agent_service)
            await asyncio.sleep(1)
            results["流式对话"] = await test_streaming(agent_service)
            await asyncio.sleep(1)
            results["历史查询"] = await test_history_service(agent_service, history_service)
            await asyncio.sleep(1)
            results["长期记忆"] = await test_long_term_memory(agent_service)
            await asyncio.sleep(1)
            # Print the summary table
            await print_section("测试总结")
            print("\n测试结果:")
            print("-" * 40)
            pass_count = fail_count = skip_count = 0
            for test_name, outcome in results.items():
                if outcome is True:
                    status = "✅ 通过"
                    pass_count += 1
                elif outcome is False:
                    status = "❌ 失败"
                    fail_count += 1
                else:
                    # None means "needs manual verification"
                    status = "⚠️ 待验证"
                    skip_count += 1
                print(f" {test_name:12s}: {status}")
            print("-" * 40)
            print(f"总计: {len(results)} 个测试")
            print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")
            if fail_count == 0:
                print("\n🎉 所有核心测试通过!")
            else:
                print(f"\n⚠️ 有 {fail_count} 个测试失败")
    except Exception as e:
        error(f"\n❌ 测试运行异常: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0 if fail_count == 0 else 1
# Script entry point: run the async test suite and exit with its status
# code (0 = all core tests passed, 1 = at least one failure or fatal error).
if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

View File

@@ -3,7 +3,7 @@ AI Agent 前端模块
采用分层架构设计包含配置、状态、API客户端和UI组件
"""
from .logger import debug, info, warning, error
from frontend.logger import debug, info, warning, error
__version__ = "2.0.0"
__all__ = ["debug", "info", "warning", "error"]

View File

@@ -9,8 +9,6 @@ from datetime import datetime
# 使用绝对导入
from frontend.state import AppState
from frontend.api_client import api_client
from frontend.config import config
def render_sidebar():
"""渲染左侧栏"""
@@ -25,7 +23,6 @@ def render_sidebar():
st.divider()
_render_user_section()
def _render_user_section():
"""渲染用户登录区域"""
# st.header("👤 用户") # 移除显眼的标题,改用更柔和的 caption
@@ -36,7 +33,6 @@ def _render_user_section():
else:
_render_user_info()
def _render_login_form():
"""渲染登录表单"""
username = st.text_input(
@@ -54,7 +50,6 @@ def _render_login_form():
# st.info("💡 建议登录以隔离对话历史") # 移除多余色块
def _render_user_info():
"""渲染用户信息"""
st.markdown(f"**当前用户**: `{AppState.get_user_id()}`")
@@ -64,7 +59,6 @@ def _render_user_info():
_refresh_threads()
st.rerun()
def _render_history_section():
"""渲染历史对话列表"""
col1, col2 = st.columns([3, 1])
@@ -76,7 +70,6 @@ def _render_history_section():
_render_thread_list()
def _render_history_actions():
"""渲染历史操作按钮"""
# 移除了 type="primary",让它变成普通的线框按钮,不再是大红块
@@ -84,7 +77,6 @@ def _render_history_actions():
AppState.start_new_thread()
st.rerun()
def _render_thread_list():
"""渲染线程列表"""
# 仅在初次加载时拉取,或由外部主动调用 _refresh_threads() 更新
@@ -101,7 +93,6 @@ def _render_thread_list():
for thread in threads:
_render_thread_item(thread)
def _render_thread_item(thread: dict):
"""
渲染单个线程项
@@ -130,7 +121,6 @@ def _render_thread_item(thread: dict):
):
_load_thread(thread_id)
def _format_time(time_str: str) -> str:
"""
格式化时间字符串
@@ -150,13 +140,11 @@ def _format_time(time_str: str) -> str:
except Exception:
return time_str[:10]
def _refresh_threads():
"""刷新历史线程列表"""
threads = api_client.get_user_threads(AppState.get_user_id())
AppState.set_threads(threads)
def _load_thread(thread_id: str):
"""
加载指定线程的消息历史

View File

@@ -5,6 +5,7 @@
import os
from dataclasses import dataclass
from typing import Optional
from dotenv import load_dotenv
# 加载 .env 文件
@@ -25,7 +26,7 @@ class FrontendConfig:
# ==================== 模型配置 ====================
default_model: str = "local" # 更改为local作为默认模型
model_options: dict = None
model_options: Optional[dict] = None
# ==================== 用户配置 ====================
default_user_id: str = "default_user"
@@ -53,7 +54,7 @@ class FrontendConfig:
"""从环境变量加载配置(优先级最高)"""
# API 地址(移除 /chat 后缀)
# 优先级:环境变量 API_URL > 默认值
api_url = os.getenv("API_URL", "http://127.0.0.1:8083")
api_url = os.getenv("API_URL", "http://127.0.0.1:8079")
self.api_base = api_url.replace("/chat", "").rstrip("/")

View File

@@ -7,7 +7,7 @@ import uuid
from typing import List, Dict, Any
import streamlit as st
from .config import config
from frontend.config import config
class AppState:

View File

@@ -4,10 +4,10 @@ RAG Core - 公共 RAG 组件包
提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。
"""
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from rag_core.embedders import LlamaCppEmbedder
from rag_core.vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from rag_core.store import PostgresDocStore, create_docstore
from rag_core.retriever_factory import create_parent_retriever
__all__ = [

View File

@@ -3,22 +3,26 @@ import os
from typing import Optional
from qdrant_client import QdrantClient
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 120, # 索引构建需要较长超时
timeout: int = 300, # 索引构建需要较长超时
) -> QdrantClient:
effective_url = url or QDRANT_URL
effective_api_key = api_key or QDRANT_API_KEY
if not effective_url:
raise ValueError("Qdrant URL 未配置")
client_kwargs = {"url": effective_url, "timeout": timeout}
client_kwargs = {
"url": effective_url,
"timeout": timeout,
}
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
return QdrantClient(**client_kwargs)

View File

@@ -5,11 +5,9 @@
import os
import httpx
from typing import List, Optional
from urllib.parse import urljoin
from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder:
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
@@ -17,7 +15,7 @@ class LlamaCppEmbedder:
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "embeddinggemma-300M-Q8_0",
model: str = "Qwen3-Embedding-0.6B-Q8_0",
):
self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
@@ -71,7 +69,6 @@ class LlamaCppEmbedder:
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""

View File

@@ -2,14 +2,7 @@
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from rag_indexer.splitters import SplitterType, get_splitter
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@@ -17,7 +10,6 @@ from langchain_classic.retrievers import ParentDocumentRetriever
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
embeddings: Optional[Embeddings] = None,

View File

@@ -15,8 +15,8 @@
"""
from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
from rag_core.store.postgres import PostgresDocStore
from rag_core.store.factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
__version__ = "2.0.0"

View File

@@ -9,9 +9,11 @@ import logging
from typing import Optional, Tuple
from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
from rag_core.store.postgres import PostgresDocStore
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
# 默认连接字符串(从环境变量读取)
DEFAULT_DB_URI = os.getenv(

View File

@@ -4,12 +4,10 @@
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence, cast
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
@@ -18,7 +16,6 @@ import asyncpg
logger = logging.getLogger(__name__)
class PostgresDocStore(BaseStore[str, Any]):
"""
异步 PostgreSQL 文档存储实现。

View File

@@ -4,13 +4,16 @@ Qdrant 向量数据库包装器。
import logging
import os
import time
from typing import List, Optional, Dict, Any
from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from .client import create_qdrant_client
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from rag_core.client import create_qdrant_client
logger = logging.getLogger(__name__)
@@ -28,9 +31,11 @@ class QdrantVectorStore:
):
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
self._connection_attempts = 0
self._last_connection_time: Optional[float] = None
if embeddings is None:
from .embedders import LlamaCppEmbedder
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
@@ -46,38 +51,89 @@ class QdrantVectorStore:
def get_client(self) -> QdrantClient:
if self._client is None:
self._client = create_qdrant_client(timeout=120)
self._client = create_qdrant_client(timeout=300)
self._connection_attempts += 1
self._last_connection_time = time.time()
logger.debug("Qdrant 客户端已创建 (第 %d 次连接)", self._connection_attempts)
return self._client
def refresh_client(self):
"""关闭旧连接,创建新连接。"""
if self._client is not None:
self._client.close()
self._client = None
try:
self._client.close()
logger.debug("Qdrant 旧连接已关闭")
except Exception as e:
logger.warning("关闭 Qdrant 连接时出现异常: %s", e)
finally:
self._client = None
self._last_connection_time = None
def check_connection_health(self) -> bool:
"""检查连接健康状态,如果连接已失效则自动重建。"""
if self._client is None:
logger.info("Qdrant 客户端未初始化,将创建新连接")
return False
try:
client = self.get_client()
client.get_collections()
logger.debug("Qdrant 连接健康检查通过")
return True
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
logger.warning("Qdrant 连接健康检查失败: %s", e)
self.refresh_client()
return False
def get_connection_stats(self) -> Dict[str, Any]:
"""获取连接统计信息。"""
return {
"connection_attempts": self._connection_attempts,
"last_connection_time": self._last_connection_time,
"client_initialized": self._client is not None,
}
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
from .embedders import LlamaCppEmbedder
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
client = self.get_client()
collections = client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
max_retries = 3
base_delay = 2
for attempt in range(max_retries):
try:
client = self.get_client()
collections = client.get_collections().collections
exists = any(c.name == self.collection_name for c in collections)
if exists and force_recreate:
client.delete_collection(self.collection_name)
exists = False
if exists and force_recreate:
client.delete_collection(self.collection_name)
exists = False
if not exists:
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
else:
logger.info("集合 '%s' 已存在", self.collection_name)
if not exists:
client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info("集合 '%s' 已创建(维度=%d", self.collection_name, vector_size)
else:
logger.info("集合 '%s' 已存在", self.collection_name)
return
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
if attempt == max_retries - 1:
logger.error("创建集合 '%s' 重试 %d 次后仍然失败: %s", self.collection_name, max_retries, e)
raise
wait_time = base_delay * (2 ** attempt)
error_type = type(e).__name__
logger.warning(
"创建集合 '%s' 遇到网络异常 [%s]%d秒后重试 (%d/%d): %s",
self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
)
self.refresh_client()
logger.debug("已刷新 Qdrant 客户端连接")
time.sleep(wait_time)
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""将文档添加到向量数据库。"""
@@ -102,9 +158,10 @@ class QdrantVectorStore:
info = self.get_client().get_collection(self.collection_name)
vectors_config = info.config.params.vectors
if isinstance(vectors_config, dict):
vector_size = next(iter(vectors_config.values())).size
first_config = next(iter(vectors_config.values()), None)
vector_size = first_config.size if first_config else 0
else:
vector_size = vectors_config.size
vector_size = vectors_config.size if vectors_config else 0
return {
"name": self.collection_name,
"vectors_count": info.points_count or 0,

View File

@@ -23,9 +23,9 @@ Offline RAG Indexer module.
>>> await builder.build_from_file("document.pdf")
"""
from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter
# 从 rag_core 重新导出常用组件
from rag_core import (

View File

@@ -8,23 +8,21 @@ import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from typing import List, Union, Optional, Any, Dict
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from qdrant_client.http.exceptions import ResponseHandlingException
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
logger = logging.getLogger(__name__)
# ---------- 配置数据类 ----------
@dataclass
class DocstoreConfig:
@@ -35,7 +33,6 @@ class DocstoreConfig:
# 若要从外部注入已创建好的 docstore可直接设置此字段
instance: Optional[BaseStore] = None
@dataclass
class IndexBuilderConfig:
"""索引构建器配置。"""
@@ -59,7 +56,6 @@ class IndexBuilderConfig:
# 其他切分器参数(当 splitter_type 非父子块时使用)
extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)
# ---------- 索引构建器 ----------
class IndexBuilder:
"""RAG 索引构建主流水线,支持单块切分与父子块切分。"""
@@ -223,18 +219,26 @@ class IndexBuilder:
async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
"""添加批次,失败时自动重试(处理网络波动)。"""
max_retries = 3
max_retries = 5
base_delay = 2
for attempt in range(max_retries):
try:
await self.retriever.aadd_documents(batch) # type: ignore[union-attr]
logger.info("批次 %d 成功添加 %d 个文档", batch_no, len(batch))
return
except (RemoteProtocolError, ConnectionError, OSError) as e:
except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
if attempt == max_retries - 1:
logger.error("批次 %d 重试 %d 次后仍然失败: %s", batch_no, max_retries, e)
raise
logger.warning("批次 %d 连接断开,重试 (%d/%d): %s",
batch_no, attempt + 1, max_retries, e)
wait_time = base_delay * (2 ** attempt)
error_type = type(e).__name__
logger.warning(
"批次 %d 遇到网络异常 [%s]%d秒后重试 (%d/%d): %s",
batch_no, error_type, wait_time, attempt + 1, max_retries, e
)
self.vector_store.refresh_client()
await asyncio.sleep(1)
logger.debug("批次 %d 已刷新 Qdrant 客户端连接", batch_no)
await asyncio.sleep(wait_time)
# ---------- 信息获取方法 ----------
def get_collection_info(self) -> Any:

View File

@@ -250,7 +250,7 @@ start_embedding() {
echo -e "${BLUE}🚀 启动 llama.cpp Embedding 容器...${NC}"
# 检查模型文件
if [ ! -f "/home/huang/Study/AIModel/GGUF/embeddinggemma-300M-Q8_0.gguf" ]; then
if [ ! -f "/home/huang/Study/AIModel/GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf" ]; then
echo -e "${RED}✗ 错误Embedding 模型文件不存在${NC}"
exit 1
fi
@@ -263,13 +263,16 @@ start_embedding() {
--device=/dev/dri \
-v /home/huang/Study/AIModel/GGUF:/models \
-p 8082:8080 \
-e LLAMA_ARG_CTX_SIZE=16384 \
-e LLAMA_ARG_N_PARALLEL=1 \
-e LLAMA_ARG_BATCH=512 \
-e LLAMA_ARG_N_GPU_LAYERS=99 \
-e LLAMA_ARG_API_KEY=huang1998 \
ghcr.io/ggml-org/llama.cpp:server-rocm \
-m /models/embeddinggemma-300M-Q8_0.gguf \
-m /models/Qwen3-Embedding-0.6B-Q8_0.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 99 \
--embeddings \
-c 512
--embeddings
echo -e "${GREEN}✓ llama.cpp Embedding 容器已启动 (端口 8082)${NC}"
sleep 5
@@ -288,7 +291,7 @@ start_backend() {
set +a
export PYTHONPATH="$PROJECT_DIR"
export BACKEND_PORT=8083
export BACKEND_PORT=8079
python app/backend.py &
BACKEND_PID=$!
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"