Compare commits
3 Commits
3c906e91d9
...
c210bcdb0b
| Author | SHA1 | Date |
|---|---|---|
| | c210bcdb0b | |
| | 3143e0e4e6 | |
| | 4e981e9dcf | |
@@ -37,6 +37,9 @@ VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding service (embeddinggemma-300M GGUF) - port 8082
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1

# Reranker service (bge-reranker-v2-m3) - port 8083
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1

# -----------------------------------------------------------------------------
# Mem0 memory-layer configuration
# -----------------------------------------------------------------------------
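The hunk above registers three OpenAI-compatible local services. A minimal smoke-test sketch (not part of the diff); `GET {base}/models` is an assumption that holds for vLLM and for llama.cpp's OpenAI-compatible server:

```python
import os
import requests

ENDPOINTS = {
    "llm": os.getenv("VLLM_BASE_URL", "http://host.docker.internal:18000/v1"),
    "embedding": os.getenv("LLAMACPP_EMBEDDING_URL", "http://host.docker.internal:18001/v1"),
    "reranker": os.getenv("LLAMACPP_RERANKER_URL", "http://host.docker.internal:18002/v1"),
}

for name, base in ENDPOINTS.items():
    try:
        r = requests.get(f"{base}/models", timeout=5)
        print(f"{name}: {base} -> HTTP {r.status_code}")
    except requests.RequestException as exc:
        print(f"{name}: {base} -> unreachable ({exc})")
```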
@@ -2,7 +2,7 @@
AI Agent application module
"""

-from .agent import AIAgentService
-from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
+from app.agent import AIAgentService
+from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME

__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
app/agent.py (deleted file; 351 lines)
@@ -1,351 +0,0 @@
"""
AI Agent service class - supports dynamic switching between multiple models
Receives an externally provided checkpointer; does not manage the connection lifecycle
"""

import os
import json
from dotenv import load_dotenv
try:
    from langchain_community.chat_models import ChatZhipuAI
    HAS_ZHIPUAI = True
except ImportError:
    HAS_ZHIPUAI = False
    ChatZhipuAI = None

try:
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings
    HAS_OPENAI = True
except ImportError:
    HAS_OPENAI = False
    ChatOpenAI = None
    OpenAIEmbeddings = None

from pydantic import SecretStr
try:
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
    HAS_POSTGRES_CHECKPOINT = True
except ImportError:
    HAS_POSTGRES_CHECKPOINT = False
    AsyncPostgresSaver = None

# Local modules
from app.graph_builder import GraphBuilder, GraphContext
from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
try:
    from app.rag import RAGPipeline
    from app.rag.tools import RAGTool
    HAS_RAG = True
except ImportError as e:
    HAS_RAG = False
    RAGPipeline = None
    RAGTool = None
from app.logger import debug, info, warning, error


load_dotenv()


class AIAgentService:
    """Async AI Agent service with dynamic multi-model switching; uses an externally provided checkpointer"""

    def __init__(self, checkpointer: AsyncPostgresSaver):
        """
        Initialize the service
        Args:
            checkpointer: an already-initialized AsyncPostgresSaver instance
        """
        self.checkpointer = checkpointer
        self.graphs = {}  # graph instances keyed by model name
        self.rag = None  # RAG retrieval instance
        self.rag_tool = None  # RAG tool instance

    def _create_zhipu_llm(self):
        """Create the hosted Zhipu LLM"""
        if not HAS_ZHIPUAI:
            raise ImportError("Zhipu AI support unavailable; install the langchain-community package")
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set in environment")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,  # request timeout in seconds (raised from 60)
            max_retries=3,  # retry count (raised from 2)
            streaming=True,  # make sure streaming output is enabled
        )

    def _create_deepseek_llm(self):
        """Create the DeepSeek LLM (via the OpenAI-compatible API)"""
        if not HAS_OPENAI:
            raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DEEPSEEK_API_KEY not set in environment")
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(api_key),
            model="deepseek-reasoner",  # deepseek-chat: non-reasoning mode; deepseek-reasoner: reasoning mode
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries on failure
            streaming=True,  # make sure streaming output is enabled
        )

    def _create_local_llm(self):
        """Create the local vLLM-served LLM"""
        if not HAS_OPENAI:
            raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
        # vLLM service address: prefer the environment variable, covering Docker, FRP tunneling, and local dev
        vllm_base_url = os.getenv(
            "VLLM_BASE_URL",
            "http://127.0.0.1:8081/v1"
        )

        return ChatOpenAI(
            base_url=vllm_base_url,
            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
            model="gemma-4-E2B-it",
            timeout=60.0,  # request timeout in seconds
            max_retries=2,  # automatic retries on failure
            streaming=True,  # make sure streaming output is enabled
        )

    def _create_embeddings(self):
        """Create the embedding model"""
        if not HAS_OPENAI:
            raise ImportError("OpenAI-compatible support unavailable; install the langchain-openai package")
        embedding_url = os.getenv(
            "LLAMACPP_EMBEDDING_URL",
            "http://127.0.0.1:8082/v1"
        )
        return OpenAIEmbeddings(
            openai_api_base=embedding_url,
            openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
            model="text-embedding-ada-002",  # the model name does not matter, anything compatible works
        )

    async def initialize(self):
        """Pre-compile the graph for every model (using the provided checkpointer)"""
        # initialize the RAG retrieval system first
        if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
            try:
                info("🔄 Initializing RAG retrieval system...")
                embeddings = self._create_embeddings()
                self.rag = RAGPipeline(embeddings=embeddings)
                self.rag_tool = RAGTool(self.rag).get_tool()
                info("✅ RAG retrieval system initialized")
            except Exception as e:
                warning(f"⚠️ RAG retrieval system initialization failed: {e}")
                self.rag = None
                self.rag_tool = None
        else:
            info("⏭️ RAG retrieval system unavailable; skipping initialization")
            self.rag = None
            self.rag_tool = None

        model_configs = {
            "local": self._create_local_llm,  # the local model comes first
            "deepseek": self._create_deepseek_llm,  # DeepSeek in the middle
            "zhipu": self._create_zhipu_llm,  # GLM-4.7 last
        }

        for model_name, llm_creator in model_configs.items():
            try:
                info(f"🔄 Initializing model '{model_name}'...")
                llm = llm_creator()

                # build the tool list: base tools + the RAG tool (if available)
                tools = AVAILABLE_TOOLS.copy()
                tools_by_name = TOOLS_BY_NAME.copy()

                if self.rag_tool is not None:
                    tools.append(self.rag_tool)
                    tools_by_name[self.rag_tool.name] = self.rag_tool

                builder = GraphBuilder(llm, tools, tools_by_name).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[model_name] = graph
                info(f"✅ Model '{model_name}' initialized")
            except Exception as e:
                import traceback
                error_detail = traceback.format_exc()
                warning(f"⚠️ Model '{model_name}' initialization failed: {e}")
                debug(f"  Details:\n{error_detail}")

        if not self.graphs:
            raise RuntimeError("No usable model; check the configuration. Possible causes:\n"
                               "1. ZHIPUAI_API_KEY missing or invalid\n"
                               "2. DEEPSEEK_API_KEY missing or invalid\n"
                               "3. vLLM service down or wrong address (VLLM_BASE_URL)\n"
                               "4. network connectivity problems")

        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """
        Process a user message and return a dict with the reply, token stats, and elapsed time

        Returns:
            dict: {
                "reply": str,  # AI reply content
                "token_usage": dict,  # token usage details
                "elapsed_time": float  # call duration in seconds
            }
        """
        # try the requested model; if unavailable, iterate over the others
        if model not in self.graphs:
            warning(f"Warning: model '{model}' unavailable; trying to switch to another available model")
            found = False
            for available_model in self.graphs.keys():
                try:
                    # extra model-availability checks could be added here
                    model = available_model
                    found = True
                    info(f"Switched to available model: '{model}'")
                    break
                except Exception as e:
                    warning(f"Model '{available_model}' is also unavailable: {str(e)}")
                    continue

            if not found:
                raise RuntimeError(f"Error: no model is available at all. Registered models: {list(self.graphs.keys())}")

        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}  # written into metadata for history queries
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        result = await graph.ainvoke(input_state, config=config, context=context)

        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)

        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable form"""
        if hasattr(value, 'content'):
            # LangChain message object
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
                "tool_calls": getattr(value, 'tool_calls', [])
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                return str(value)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """
        Process a message as a stream; returns an async generator

        Args:
            message: user message
            thread_id: thread ID
            model_name: model name
            user_id: user ID

        Yields:
            dicts containing the event type and data
        """
        graph = self.graphs.get(model_name)

        if not graph:
            raise ValueError(f"Model '{model_name}' not found or not initialized")

        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],  # combine several modes, adding custom
            version="v2",  # use the unified v2 format
            subgraphs=True  # enable this if you use subgraphs
        ):
            chunk_type = chunk["type"]
            processed_event = {}

            # 1. handle the LLM token stream (typewriter effect)
            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]

                # extract metadata
                node_name = metadata.get("langgraph_node", "unknown")
                # use getattr to fetch the content safely, since message_chunk may not be a string
                token_content = getattr(message_chunk, 'content', str(message_chunk))

                # extract DeepSeek reasoner's chain-of-thought tokens
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

                # [DEBUG] temporary: log only when reasoning_token is non-empty so it is easy to spot
                if reasoning_token:
                    import logging
                    logging.debug(f"💡 [Reasoning token captured]: {repr(reasoning_token)}")

                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata  # optional metadata
                }

            # 2. handle state updates (a node finished executing)
            elif chunk_type == "updates":
                updates_data = chunk["data"]
                # serialize everything inside updates
                serialized_data = self._serialize_value(updates_data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data
                }
                # also keep a messages field for backward compatibility with the old frontend (optional)
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]

            # 3. handle custom data (if needed)
            elif chunk_type == "custom":
                # custom events need serialization too
                serialized_data = self._serialize_value(chunk["data"])
                processed_event = {
                    "type": "custom",
                    "data": serialized_data
                }

            # 4. other types (debug, tasks, etc.) handled as needed
            else:
                # skip the types we do not need
                continue

            # only emit events that actually carry data
            if processed_event:
                yield processed_event
app/agent/__init__.py (new file; 7 lines)
@@ -0,0 +1,7 @@
"""
Agent submodule
"""

from app.agent.service import AIAgentService

__all__ = ["AIAgentService"]
@@ -3,11 +3,9 @@
Uses LangGraph's checkpointer to fetch conversation history and summaries
"""

-from typing import List, Dict, Any, Optional
-import logging
+from typing import List, Dict, Any
+from app.logger import error  # kept for compatibility; could be swapped for a logger


class ThreadHistoryService:
    """Thread history query service"""

app/agent/llm_factory.py (new file; 56 lines)
@@ -0,0 +1,56 @@
# app/llm_factory.py
import os
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr

class LLMFactory:
    @staticmethod
    def create_zhipu():
        api_key = os.getenv("ZHIPUAI_API_KEY")
        if not api_key:
            raise ValueError("ZHIPUAI_API_KEY not set")
        return ChatZhipuAI(
            model="glm-4.7-flash",
            api_key=api_key,
            temperature=0.1,
            max_tokens=4096,
            timeout=120.0,
            max_retries=3,
            streaming=True,
        )

    @staticmethod
    def create_deepseek():
        api_key = os.getenv("DEEPSEEK_API_KEY")
        if not api_key:
            raise ValueError("DEEPSEEK_API_KEY not set")
        return ChatOpenAI(
            base_url="https://api.deepseek.com",
            api_key=SecretStr(api_key),
            model="deepseek-reasoner",
            temperature=0.1,
            max_tokens=4096,
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )

    @staticmethod
    def create_local():
        base_url = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
        return ChatOpenAI(
            base_url=base_url,
            api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
            model="gemma-4-E4B-it",
            timeout=60.0,
            max_retries=2,
            streaming=True,
        )

    # mapping of model-creator functions
    CREATORS = {
        "local": create_local,
        "deepseek": create_deepseek,
        "zhipu": create_zhipu,
    }
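A hypothetical dispatch sketch (not part of the diff). Note that `CREATORS` captures staticmethod objects in the class body; calling them directly like this relies on staticmethod objects being callable, which requires Python 3.10+.

```python
from app.agent.llm_factory import LLMFactory

def build_llm(name: str):
    creator = LLMFactory.CREATORS.get(name)
    if creator is None:
        raise KeyError(f"unknown model '{name}'; expected one of {list(LLMFactory.CREATORS)}")
    return creator()  # a configured ChatZhipuAI / ChatOpenAI instance

llm = build_llm("local")
```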
@@ -1,18 +1,20 @@
"""
Prompt template management module
All system prompts and conversation templates are defined in one place
"""

+# app/prompts.py
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

+def create_system_prompt(tools: list = None) -> ChatPromptTemplate:
+    """
+    Create the system prompt template, optionally injecting tool descriptions dynamically.
+    """
+    tools_section = ""
+    if tools:
+        tool_descs = []
+        for tool in tools:
+            # extract the tool name and the first line of its description
+            name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool')
+            desc = (tool.description or "").split('\n')[0]
+            tool_descs.append(f"- {name}: {desc}")
+        tools_section = "\n".join(tool_descs)

-def create_system_prompt() -> ChatPromptTemplate:
-    """
-    Create the system prompt template
-
-    Returns:
-        ChatPromptTemplate: a prompt template with the system instructions and a messages placeholder
-    """
    system_template = (
        "你是一个个人生活助手和数据分析助手,请使用中文交流。\n\n"
        "【用户背景信息】\n"
@@ -20,15 +22,11 @@ def create_system_prompt() -> ChatPromptTemplate:
        "{memory_context}\n"
        "若包含姓名、偏好等个人信息,请自然融入回应(例如称呼名字、提及偏好)。\n\n"
        "【可用工具与使用规则】\n"
-        "- 获取温度/天气:`get_current_temperature`\n"
-        "- 读取文本文件:`read_local_file`(限定目录 `./user_docs`)\n"
-        "- 读取PDF摘要:`read_pdf_summary`(限定目录 `./user_docs`)\n"
-        "- 读取Excel表格:`read_excel_as_markdown`(限定目录 `./user_docs`)\n"
-        "- 抓取网页内容:`fetch_webpage_content`\n"
+        f"{tools_section}\n"
        "工具调用时请直接返回所需参数,无需额外说明。\n\n"
        "【回答要求(必须遵守)】\n"
        "1. 回答必须简洁、直接。\n"
-        "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。例如:<think>这里是我的思考过程...</think>这里是最终回答。\n"
+        "2. 如果你认为该问题需要进行深入的推理或思考,请务必将你的思维链或推理过程用 `<think>` 和 `</think>` 标签包裹起来,放在回答的最前面。\n"
        "3. 优先利用已知用户信息进行个性化回复。\n"
        "4. 若无信息可依,礼貌询问或提供通用帮助。"
    )
@@ -36,4 +34,4 @@ def create_system_prompt() -> ChatPromptTemplate:
    return ChatPromptTemplate.from_messages([
        ("system", system_template),
        MessagesPlaceholder(variable_name="messages")
    ])
app/agent/rag_initializer.py (new file; 23 lines)
@@ -0,0 +1,23 @@
# app/rag_initializer.py
from app.rag.tools import create_rag_tool_sync
from rag_core import create_parent_retriever
from app.logger import info, warning

async def init_rag_tool(local_llm_creator):
    """Initialize the RAG tool; returns None on failure"""
    try:
        info("🔄 Initializing RAG retrieval system...")
        retriever = create_parent_retriever(
            collection_name="rag_documents",
            search_k=5,
        )
        rewrite_llm = local_llm_creator()
        rag_tool = create_rag_tool_sync(
            retriever, rewrite_llm,
            num_queries=3, rerank_top_n=5
        )
        info("✅ RAG retrieval tool initialized")
        return rag_tool
    except Exception as e:
        warning(f"⚠️ RAG retrieval tool initialization failed: {e}")
        return None
app/agent/service.py (new file; 156 lines)
@@ -0,0 +1,156 @@
"""
AI Agent service class - supports dynamic switching between multiple models
Receives an externally provided checkpointer; does not manage the connection lifecycle
"""

import json
from dotenv import load_dotenv

# Local modules
from app.graph.graph_builder import GraphBuilder, GraphContext
from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from app.agent.llm_factory import LLMFactory
from app.agent.rag_initializer import init_rag_tool
from app.logger import info, warning
load_dotenv()

class AIAgentService:
    def __init__(self, checkpointer):
        self.checkpointer = checkpointer
        self.graphs = {}
        self.tools = AVAILABLE_TOOLS.copy()
        self.tools_by_name = TOOLS_BY_NAME.copy()

    async def initialize(self):
        # 1. initialize the RAG tool (if needed)
        rag_tool = await init_rag_tool(LLMFactory.create_local)
        if rag_tool:
            self.tools.append(rag_tool)
            self.tools_by_name[rag_tool.name] = rag_tool

        # 2. build a graph for each model
        for name, creator in LLMFactory.CREATORS.items():
            try:
                info(f"🔄 Initializing model '{name}'...")
                llm = creator()
                builder = GraphBuilder(llm, self.tools, self.tools_by_name).build()
                graph = builder.compile(checkpointer=self.checkpointer)
                self.graphs[name] = graph
                info(f"✅ Model '{name}' initialized")
            except Exception as e:
                warning(f"⚠️ Model '{name}' initialization failed: {e}")

        if not self.graphs:
            raise RuntimeError("No usable model")
        return self

    async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
        """Process a user message and return a dict with the reply, token stats, and elapsed time"""
        if model not in self.graphs:
            # fall back to the first available model
            available = list(self.graphs.keys())
            if not available:
                raise RuntimeError("No usable model")
            requested, model = model, available[0]  # remember the requested name before overwriting it
            warning(f"Model '{requested}' unavailable; fell back to '{model}'")

        graph = self.graphs[model]
        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        result = await graph.ainvoke(input_state, config=config, context=context)

        reply = result["messages"][-1].content
        token_usage = result.get("last_token_usage", {})
        elapsed_time = result.get("last_elapsed_time", 0.0)

        return {
            "reply": reply,
            "token_usage": token_usage,
            "elapsed_time": elapsed_time
        }

    def _serialize_value(self, value):
        """Recursively convert LangChain objects into JSON-serializable form"""
        if hasattr(value, 'content'):
            msg_type = getattr(value, 'type', 'message')
            return {
                "role": msg_type,
                "content": getattr(value, 'content', ''),
                "additional_kwargs": getattr(value, 'additional_kwargs', {}),
                "tool_calls": getattr(value, 'tool_calls', [])
            }
        elif isinstance(value, dict):
            return {k: self._serialize_value(v) for k, v in value.items()}
        elif isinstance(value, (list, tuple)):
            return [self._serialize_value(item) for item in value]
        else:
            try:
                json.dumps(value)
                return value
            except (TypeError, ValueError):
                return str(value)

    async def process_message_stream(self, message: str, thread_id: str, model_name: str, user_id: str = "default_user"):
        """Process a message as a stream; returns an async generator"""
        graph = self.graphs.get(model_name)
        if not graph:
            raise ValueError(f"Model '{model_name}' not found or not initialized")

        config = {
            "configurable": {"thread_id": thread_id},
            "metadata": {"user_id": user_id}
        }
        input_state = {"messages": [{"role": "user", "content": message}]}
        context = GraphContext(user_id=user_id)

        async for chunk in graph.astream(
            input_state,
            config=config,
            context=context,
            stream_mode=["messages", "updates", "custom"],
            version="v2",
            subgraphs=True
        ):
            chunk_type = chunk["type"]
            processed_event = {}

            if chunk_type == "messages":
                message_chunk, metadata = chunk["data"]
                node_name = metadata.get("langgraph_node", "unknown")
                token_content = getattr(message_chunk, 'content', str(message_chunk))
                reasoning_token = ""
                if hasattr(message_chunk, 'additional_kwargs'):
                    reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")

                processed_event = {
                    "type": "llm_token",
                    "node": node_name,
                    "token": token_content,
                    "reasoning_token": reasoning_token,
                    "metadata": metadata
                }
            elif chunk_type == "updates":
                updates_data = chunk["data"]
                serialized_data = self._serialize_value(updates_data)
                processed_event = {
                    "type": "state_update",
                    "data": serialized_data
                }
                if "messages" in serialized_data:
                    processed_event["messages"] = serialized_data["messages"]
            elif chunk_type == "custom":
                serialized_data = self._serialize_value(chunk["data"])
                processed_event = {
                    "type": "custom",
                    "data": serialized_data
                }
            else:
                continue

            if processed_event:
                yield processed_event
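A hypothetical usage sketch (not from the diff), mirroring how the FastAPI lifespan below wires the service: the `AsyncPostgresSaver` is created externally and handed in, matching the class's "does not manage the connection lifecycle" contract. The connection string is a placeholder.

```python
import asyncio
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService

async def demo():
    dsn = "postgresql://user:password@localhost:5432/langgraph_db"  # placeholder
    async with AsyncPostgresSaver.from_conn_string(dsn) as checkpointer:
        await checkpointer.setup()  # create the checkpoint tables on first run
        service = await AIAgentService(checkpointer).initialize()
        result = await service.process_message("你好", thread_id="demo-thread", model="local")
        print(result["reply"], result["elapsed_time"])

asyncio.run(demo())
```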
@@ -15,8 +15,8 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from app.agent import AIAgentService
-from app.history import ThreadHistoryService
-from app.logger import debug, info, warning, error
+from app.agent.history import ThreadHistoryService
+from app.logger import info, error

# load the .env file
load_dotenv()
@@ -28,7 +28,6 @@ DB_URI = os.getenv(
    "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management: create and inject the global services"""
@@ -53,7 +52,6 @@ async def lifespan(app: FastAPI):
    # 5. database connections are cleaned up automatically on shutdown (handled by async with)
    info("🛑 Application shut down; database connection pool released")


app = FastAPI(lifespan=lifespan)

# CORS middleware (allow cross-origin requests from the frontend)
@@ -65,14 +63,12 @@ app.add_middleware(
    allow_headers=["*"],
)


# ========== Health-check endpoint ==========
@app.get("/health")
async def health_check():
    """Health-check endpoint for Docker and CI/CD monitoring"""
    return {"status": "ok", "service": "ai-agent-backend"}


# ========== Pydantic models ==========
class ChatRequest(BaseModel):
    message: str
@@ -80,7 +76,6 @@ class ChatRequest(BaseModel):
    model: str = "zhipu"
    user_id: str = "default_user"


class ChatResponse(BaseModel):
    reply: str
    thread_id: str
@@ -90,18 +85,15 @@ class ChatResponse(BaseModel):
    total_tokens: int = 0
    elapsed_time: float = 0.0


# ========== Dependency-injection helpers ==========
def get_agent_service(request: Request) -> AIAgentService:
    """Fetch the global AIAgentService instance from app.state"""
    return request.app.state.agent_service


def get_history_service(request: Request) -> ThreadHistoryService:
    """Fetch the global ThreadHistoryService instance from app.state"""
    return request.app.state.history_service


# ========== HTTP endpoints ==========
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(
@@ -135,7 +127,6 @@ async def chat_endpoint(
        elapsed_time=elapsed_time
    )


# ========== History-query endpoints ==========
@app.get("/threads")
async def list_threads(
@@ -147,7 +138,6 @@ async def list_threads(
    threads = await history_service.get_user_threads(user_id, limit)
    return {"threads": threads}


@app.get("/thread/{thread_id}/messages")
async def get_thread_messages(
    thread_id: str,
@@ -158,7 +148,6 @@ async def get_thread_messages(
    messages = await history_service.get_thread_messages(thread_id)
    return {"messages": messages}


@app.get("/thread/{thread_id}/summary")
async def get_thread_summary(
    thread_id: str,
@@ -169,7 +158,6 @@ async def get_thread_summary(
    summary = await history_service.get_thread_summary(thread_id)
    return summary


# ========== Streaming chat endpoint ==========
@app.post("/chat/stream")
async def chat_stream_endpoint(
@@ -204,7 +192,6 @@ async def chat_stream_endpoint(
    }
)


# ========== WebSocket endpoint (optional) ==========
@app.websocket("/ws")
async def websocket_endpoint(
@@ -228,9 +215,8 @@ async def websocket_endpoint(
    except WebSocketDisconnect:
        pass


if __name__ == "__main__":
    import uvicorn
-    # use BACKEND_PORT or default to port 8083 (avoids clashing with llama.cpp on 8081)
-    port = int(os.getenv("BACKEND_PORT", "8083"))
+    # use BACKEND_PORT or default to port 8079 (avoids clashing with llama.cpp on 8081)
+    port = int(os.getenv("BACKEND_PORT", "8079"))
    uvicorn.run(app, host="0.0.0.0", port=port)
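A hypothetical client sketch (not from the diff) against the `/chat` endpoint above, assuming the backend listens on the new default port 8079; the request fields mirror `ChatRequest` and the response mirrors `ChatResponse`, though both models are only partially shown in this diff.

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8079/chat",
    json={
        "message": "你好",
        "thread_id": "demo-thread",  # assumed field; ChatRequest is truncated above
        "model": "local",
        "user_id": "default_user",
    },
    timeout=120,
)
resp.raise_for_status()
data = resp.json()
print(data["reply"], data["elapsed_time"])
```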
@@ -18,6 +18,12 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# Qdrant vector database address
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
+
+# ========== LLM configuration ==========
+# LLM model configuration
+VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
+LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")

# llama.cpp embedding service address (used by Mem0 for vectorization)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")

app/graph/__init__.py (new file; 8 lines)
@@ -0,0 +1,8 @@
"""
Graph submodule
"""

from app.graph.graph_builder import GraphBuilder
from app.graph.state import MessagesState, GraphContext

__all__ = ["GraphBuilder", "MessagesState", "GraphContext"]
@@ -5,18 +5,17 @@ LangGraph state-graph construction module - trimmed down; only assembles the graph

from langchain_core.language_models import BaseLLM
from langgraph.graph import StateGraph, START, END

# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState, GraphContext
from app.nodes import (
    should_continue,
    create_llm_call_node,
    create_tool_call_node,
    create_retrieve_memory_node,
    create_summarize_node,
-    should_continue
+    finalize_node,
)
+from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client
+from app.memory import Mem0Client
-from app.nodes.finalize import finalize_node


class GraphBuilder:
@@ -45,6 +44,9 @@ class GraphBuilder:
        Returns:
            a StateGraph instance
        """
+        # inject the global client
+        set_mem0_client(self.mem0_client)
+
        builder = StateGraph(MessagesState, context_schema=GraphContext)

        # ⭐ create the nodes via factory functions (dependency injection)
@@ -55,6 +57,7 @@ class GraphBuilder:

        # add nodes
        builder.add_node("retrieve_memory", retrieve_memory_node)
+        builder.add_node("memory_trigger", memory_trigger_node)
        builder.add_node("llm_call", llm_call_node)
        builder.add_node("tool_node", tool_call_node)
        builder.add_node("summarize", summarize_node)
@@ -62,7 +65,8 @@ class GraphBuilder:

        # add edges
        builder.add_edge(START, "retrieve_memory")
-        builder.add_edge("retrieve_memory", "llm_call")
+        builder.add_edge("retrieve_memory", "memory_trigger")
+        builder.add_edge("memory_trigger", "llm_call")
        builder.add_conditional_edges(
            "llm_call",
            should_continue,
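The conditional edge above hinges on `should_continue`, whose body is not shown in this diff. A minimal sketch of the conventional LangGraph routing it presumably implements (an assumption, not the project's actual code): route to `tool_node` when the last AI message carries tool calls, otherwise continue to `summarize`.

```python
from langchain_core.messages import AIMessage

def should_continue(state: dict) -> str:
    last = state["messages"][-1]
    if isinstance(last, AIMessage) and last.tool_calls:
        return "tool_node"   # the model asked for a tool; execute it
    return "summarize"       # otherwise move on to memory summarization
```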
@@ -3,7 +3,6 @@
"""

# standard library
import os
from pathlib import Path

# third-party libraries
@@ -13,7 +12,6 @@ import requests
from bs4 import BeautifulSoup
from langchain_core.tools import tool


def _file_allow_check(filename: str) -> Path:
    """Check that a user-supplied filename resolves under the allowed './user_docs' directory, to prevent path-traversal attacks."""
    allowed_dir = Path("./user_docs").resolve()
@@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path:

    return file_path


@tool
def get_current_temperature(location: str) -> str:
    """Get the current temperature at the given location."""
    return f'当前{location}的温度为25℃'


@tool
def read_local_file(filename: str) -> str:
    """Read the named local text file and return a summary of its content."""
@@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str:
    except Exception as e:
        return f"读取文件时出错:{str(e)}"


@tool
def read_pdf_summary(filename: str) -> str:
    """Read a PDF file and return a text summary of its content."""
@@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str:
    except Exception as e:
        return f"读取PDF出错:{e}"


@tool
def read_excel_as_markdown(filename: str) -> str:
    """Read an Excel file and convert its main data into a Markdown table."""
@@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str:
    except Exception as e:
        return f"读取Excel出错:{e}"


@tool
def fetch_webpage_content(url: str) -> str:
    """Fetch the main body of the page at the given URL and return clean plain text."""
@@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str:
    except Exception as e:
        return f"抓取网页时出错:{str(e)}"


# tool list and mapping (module-level constants)
AVAILABLE_TOOLS = [
    get_current_temperature,
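Since the functions above are wrapped with LangChain's `@tool` decorator, they are invoked with an argument dict rather than positionally; a minimal sketch (not part of the diff):

```python
from app.graph.graph_tools import get_current_temperature, TOOLS_BY_NAME

# Tool objects are called through .invoke with a dict of arguments
print(get_current_temperature.invoke({"location": "上海"}))  # -> 当前上海的温度为25℃
print(sorted(TOOLS_BY_NAME))  # the names the LLM can reference in its tool_calls
```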
@@ -4,15 +4,13 @@
"""

from typing import Any, Dict
from langgraph.runtime import Runtime

# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
-from app.utils.logging import log_state_change
from app.logger import debug


def create_retrieve_memory_node(mem0_client: Mem0Client):
    """
    Factory function: create the memory-retrieval node
@@ -4,12 +4,11 @@ LangGraph state-definition module
"""

import operator
-from typing import Annotated, Any
+from typing import Annotated
from typing_extensions import TypedDict
from dataclasses import dataclass
from langchain_core.messages import AnyMessage


class MessagesState(TypedDict):
    """Conversation state type definition"""
    messages: Annotated[list[AnyMessage], operator.add]
@@ -19,7 +18,6 @@ class MessagesState(TypedDict):
    last_elapsed_time: float  # duration of this call, in seconds
    turns_since_last_summary: int  # turns since the last summary was generated


@dataclass
class GraphContext:
    """Graph execution context"""
@@ -1,16 +1,22 @@
+from app.config import LLM_API_KEY
+from app.config import VLLM_BASE_URL
+import time
"""
Mem0 memory-layer client wrapper module
Handles Mem0 initialization, retrieval, and storage
"""

import asyncio
-from typing import Optional, List, Dict, Any
+from typing import Optional, List, Dict
from mem0 import AsyncMemory

-from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
+from app.config import (
+    QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY,
+    VLLM_BASE_URL, LLM_API_KEY,
+    LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
+)
from app.logger import info, warning, error


class Mem0Client:
    """Mem0 async client wrapper class"""

@@ -37,20 +43,25 @@ class Mem0Client:
            "provider": "qdrant",
            "config": {
                "url": QDRANT_URL,  # use the full URL directly
+                "api_key": QDRANT_API_KEY,
                "collection_name": QDRANT_COLLECTION_NAME,
-                "embedding_model_dims": 768,
+                "embedding_model_dims": 1024,
            }
        },
        "llm": {
-            "provider": "langchain",
+            "provider": "openai",
            "config": {
-                "model": self.llm
+                "model": "LLM_MODEL",
+                "api_key": LLM_API_KEY,
+                "openai_base_url": VLLM_BASE_URL,
+                "temperature": 0.1,
+                "max_tokens": 2000,
            }
        },
        "embedder": {
            "provider": "openai",
            "config": {
-                "model": "embeddinggemma-300M-Q8_0",
+                "model": "Qwen3-Embedding-0.6B-Q8_0",
                "api_key": LLAMACPP_API_KEY,
                "openai_base_url": LLAMACPP_EMBEDDING_URL,
            },
@@ -118,36 +129,18 @@ class Mem0Client:
            warning(f"⚠️ Mem0 retrieval failed: {e}")
            return []

-    async def add_memories(self, messages: List[Dict[str, str]], user_id: str) -> bool:
-        """
-        Add memories (automatically extracts facts and stores them)
-
-        Args:
-            messages: message list of the form [{"role": "user/assistant/system", "content": "..."}]
-            user_id: user ID
-
-        Returns:
-            bool: whether the call succeeded
-        """
-        if not self.mem0:
-            warning("⚠️ Mem0 not initialized; skipping memory add")
-            return False
-
-        try:
-            await asyncio.wait_for(
-                self.mem0.add(
-                    messages,
-                    user_id=user_id,
-                    metadata={"type": "conversation"}
-                ),
-                timeout=60.0
-            )
-            info("📝 [memory add] submitted to Mem0 for fact extraction")
-            return True
-
-        except asyncio.TimeoutError:
-            error("❌ Mem0 memory add timed out (60s)")
-            return False
-        except Exception as e:
-            error(f"❌ Mem0 memory add failed: {e}")
-            return False
+    async def add_memories(self, messages, user_id):
+        if not self.mem0:
+            return False
+        try:
+            start = time.time()
+            info(f"📝 Starting Mem0 add; message count: {len(messages)}")
+            await asyncio.wait_for(
+                self.mem0.add(messages, user_id=user_id, metadata={"type": "conversation"}),
+                timeout=60.0
+            )
+            info(f"✅ Mem0 add finished in {time.time() - start:.2f}s")
+            return True
+        except asyncio.TimeoutError:
+            error(f"❌ Mem0 memory add timed out (60s) after {time.time() - start:.2f}s")
+            return False
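A hedged usage sketch of the wrapper above (not part of the diff): `initialize` and `add_memories` both appear in this diff, but the constructor's arguments are not shown here and may differ.

```python
import asyncio
from app.memory.mem0_client import Mem0Client

async def demo():
    client = Mem0Client()  # assumption: a no-arg constructor; the real signature is not in this diff
    await client.initialize()  # builds AsyncMemory from the config dict above
    ok = await client.add_memories(
        [{"role": "user", "content": "记住,我的宠物叫小白"}], user_id="u1"
    )
    print("stored:", ok)

asyncio.run(demo())
```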
@@ -5,7 +5,7 @@
from app.nodes.router import should_continue
from app.nodes.llm_call import create_llm_call_node
from app.nodes.tool_call import create_tool_call_node
-from app.nodes.retrieve_memory import create_retrieve_memory_node
+from app.graph.retrieve_memory import create_retrieve_memory_node
from app.nodes.summarize import create_summarize_node
from app.nodes.finalize import finalize_node


@@ -4,15 +4,13 @@
"""

from typing import Any, Dict
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer

# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState
-from app.utils.logging import log_state_change
from app.logger import info, error


+from langchain_core.runnables.config import RunnableConfig

async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:

@@ -3,21 +3,17 @@ LLM call node module
Responsible for invoking the large language model and handling its response
"""

import asyncio
import time
from typing import Any, Dict
from langchain_core.language_models import BaseLLM
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langgraph.runtime import Runtime

# Local modules
-from app.state import MessagesState, GraphContext
-from app.prompts import create_system_prompt
-from app.utils.logging import log_state_change, print_llm_input
+from app.graph.state import MessagesState
+from app.agent.prompts import create_system_prompt
+from app.utils.logging import log_state_change
from app.logger import debug, info, error


def create_llm_call_node(llm: BaseLLM, tools: list):
    """
    Factory function: create the LLM call node
@@ -30,7 +26,7 @@ def create_llm_call_node(llm: BaseLLM, tools: list):
        an async node function
    """
    # build the invocation chain
-    prompt = create_system_prompt()
+    prompt = create_system_prompt(tools)
    llm_with_tools = llm.bind_tools(tools)

    # restore the chain with RunnableLambda, and iterate over it with astream below

app/nodes/memory_trigger.py (new file; 38 lines)
@@ -0,0 +1,38 @@
from typing import Any, Dict
from langchain_core.runnables.config import RunnableConfig
from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
from app.logger import info

# global variable, injected by GraphBuilder
_mem0_client: Mem0Client = None

def set_mem0_client(client: Mem0Client):
    global _mem0_client
    _mem0_client = client

async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]:
    """Detect memory directives in the user message; when one is found, proactively store it via Mem0"""
    if _mem0_client is None:
        return {}

    messages = state.get("messages", [])
    if not messages:
        return {}

    last_msg = messages[-1]
    content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg)

    # trigger words (extend as needed); kept in Chinese because they are matched against user input
    trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
    if any(word in content for word in trigger_words):
        user_id = config.get("metadata", {}).get("user_id", "default_user")
        # make sure Mem0 is initialized
        if not _mem0_client._initialized:
            await _mem0_client.initialize()
        # submit the user message to Mem0 as the source of facts
        info("📌 Memory directive detected; proactively triggering Mem0 storage")
        mem0_messages = [{"role": "user", "content": content}]
        await _mem0_client.add_memories(mem0_messages, user_id=user_id)

    return {}  # do not modify the state
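A quick check of the trigger logic above (sketch, not part of the diff): `any(...)` does plain substring matching on the raw message text, so only the listed Chinese directives fire.

```python
trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"]
for text in ["请记住这件事", "remember this for me"]:
    print(text, "->", any(word in text for word in trigger_words))
# 请记住这件事 -> True
# remember this for me -> False
```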
@@ -8,7 +8,7 @@ from langchain_core.messages import AIMessage

# Local modules
from app.config import ENABLE_GRAPH_TRACE, MEMORY_SUMMARIZE_INTERVAL
-from app.state import MessagesState
+from app.graph.state import MessagesState
from app.logger import info


@@ -4,15 +4,13 @@
"""

from typing import Any, Dict
from langgraph.runtime import Runtime

# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState
from app.memory.mem0_client import Mem0Client
-from app.utils.logging import log_state_change
from app.logger import debug, info, error, warning


def create_summarize_node(mem0_client: Mem0Client):
    """
    Factory function: create the memory-storage node

@@ -6,15 +6,13 @@
import asyncio
from typing import Any, Dict
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
from langgraph.config import get_stream_writer

# Local modules
-from app.state import MessagesState, GraphContext
+from app.graph.state import MessagesState
-from app.utils.logging import log_state_change
from app.logger import debug, info


def create_tool_call_node(tools_by_name: Dict[str, Any]):
    """
    Factory function: create the tool-execution node

@@ -13,7 +13,7 @@ RAG retrieval and generation module
user query → multi-query rewriting → parallel retrieval → RRF fusion → reranking → return parent documents

Example usage:
-    >>> from app.rag import RAGPipeline, create_rag_tool
+    >>> from app.rag.rag import RAGPipeline, create_rag_tool
    >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
    >>> from langchain_openai import ChatOpenAI
    >>>
@@ -34,16 +34,16 @@ RAG retrieval and generation module
    >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
"""

-from .retriever import (
+from app.rag.retriever import (
    create_base_retriever,
    create_hybrid_retriever,
    create_qdrant_client,
)
-from .reranker import CrossEncoderReranker
-from .query_transform import MultiQueryGenerator
-from .fusion import reciprocal_rank_fusion
-from .pipeline import RAGPipeline
-from .tools import create_rag_tool, create_rag_tool_sync
+from app.rag.reranker import LLaMaCPPReranker
+from app.rag.query_transform import MultiQueryGenerator
+from app.rag.fusion import reciprocal_rank_fusion
+from app.rag.pipeline import RAGPipeline
+from app.rag.tools import create_rag_tool_sync


__all__ = [
@@ -53,7 +53,7 @@ __all__ = [
    "create_qdrant_client",

    # rerankers
-    "CrossEncoderReranker",
+    "LLaMaCPPReranker",

    # query-rewriting generator
    "MultiQueryGenerator",
@@ -65,6 +65,5 @@ __all__ = [
    "RAGPipeline",

    # tool creation (for use by the Agent)
-    "create_rag_tool",
    "create_rag_tool_sync",
]
@@ -1,6 +1,6 @@
# rag/fusion.py

-from typing import List, Dict, Tuple
+from typing import List, Dict
from langchain_core.documents import Document

def reciprocal_rank_fusion(

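For reference, `reciprocal_rank_fusion` follows the standard RRF scheme: each document is scored Σ 1/(k + rank) over the ranked lists it appears in, with k = 60 as the conventional constant. A minimal sketch of the idea (an assumption about the internals, consistent with the signature above):

```python
from collections import defaultdict
from langchain_core.documents import Document

def rrf_sketch(doc_lists: list[list[Document]], k: int = 60) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for docs in doc_lists:
        for rank, doc in enumerate(docs, start=1):
            # keying on page_content is a simplification; real code would use a stable doc id
            scores[doc.page_content] += 1.0 / (k + rank)
    return sorted(scores, key=scores.__getitem__, reverse=True)
```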
@@ -1,15 +1,14 @@
# rag/pipeline.py

import asyncio
-from typing import List, Optional
+import os
+from typing import List
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel

-from .retriever import create_qdrant_client  # may not need to be used directly
-from .reranker import LLaMaCPPReranker
-from .query_transform import MultiQueryGenerator
-from .fusion import reciprocal_rank_fusion
+from app.rag.reranker import LLaMaCPPReranker
+from app.rag.query_transform import MultiQueryGenerator
+from app.rag.fusion import reciprocal_rank_fusion

class RAGPipeline:
    """
@@ -23,7 +22,6 @@ class RAGPipeline:
        llm: BaseLanguageModel,
        num_queries: int = 3,
        rerank_top_n: int = 5,
-        rerank_model: str = "BAAI/bge-reranker-base",
    ):
        """
        Args:
@@ -41,9 +39,9 @@ class RAGPipeline:
        # initialize components
        self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
        self.reranker = LLaMaCPPReranker(
-            base_url="http://127.0.0.1:8083",
+            base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"),
+            api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"),
            top_n=rerank_top_n,
-            api_key="huang1998"
        )

    async def aretrieve(self, query: str) -> List[Document]:
@@ -68,9 +66,9 @@ class RAGPipeline:
        fused_docs = reciprocal_rank_fusion(doc_lists)

        # Step 4: rerank
-        if self.reranker.model is not None:
+        try:
            final_docs = self.reranker.compress_documents(fused_docs, query)
-        else:
+        except Exception:
            # if the reranker is unavailable, just return the top N fused documents
            final_docs = fused_docs[:self.rerank_top_n]

@@ -10,24 +10,24 @@ class LLaMaCPPReranker:
    """Rerank retrieval results via a remote llama.cpp service."""

    def __init__(self,
-                 base_url: str = "http://127.0.0.1:8083",
+                 base_url: str,
+                 api_key: str,
                 top_n: int = 5,
-                 api_key: str = "huang1998",  # the LLAMA_ARG_API_KEY you configured
                 timeout: int = 60):
        """
        Initialize the remote reranker

        Args:
-            base_url: address and port of the llama.cpp service.
+            base_url: address and port of the llama.cpp service; defaults to the LLAMACPP_RERANKER_URL environment variable or "http://127.0.0.1:8083".
            top_n: number of top results to return.
-            api_key: the API key configured in the container.
+            api_key: API key; defaults to the LLAMACPP_API_KEY environment variable or "huang1998".
            timeout: request timeout in seconds.
        """
-        self.base_url = base_url.rstrip('/')
+        self.base_url = base_url
        self.api_key = api_key
        self.top_n = top_n
-        self.api_key = api_key
        self.timeout = timeout
-        self.endpoint = f"{self.base_url}/v1/rerank"
+        self.endpoint = f"{self.base_url}/rerank"

    def compress_documents(
        self, documents: List[Document], query: str

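`compress_documents` is truncated in this hunk; the request it sends presumably looks like the sketch below. The payload shape follows llama.cpp's Jina-style rerank API (`query`, `documents`, `top_n`, returning `results` entries with `index` and `relevance_score`); treat the field names as assumptions.

```python
import requests

def rerank(endpoint: str, api_key: str, query: str, texts: list[str], top_n: int = 5):
    resp = requests.post(
        endpoint,  # e.g. f"{base_url}/rerank", as constructed in __init__ above
        headers={"Authorization": f"Bearer {api_key}"},
        json={"query": query, "documents": texts, "top_n": top_n},
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["results"]  # [{"index": ..., "relevance_score": ...}, ...]
```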
@@ -11,7 +11,6 @@ RAG system usage example (refactored)
import asyncio
import sys
import os
-from pathlib import Path

from dotenv import load_dotenv

@@ -19,12 +18,12 @@ from dotenv import load_dotenv
load_dotenv()

# add the project root to the path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))

-from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
+from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
-from rag.pipeline import RAGPipeline
-from rag.tools import create_rag_tool
+from app.rag.pipeline import RAGPipeline
+from app.rag.tools import create_rag_tool_sync
from pydantic import SecretStr
# use the local LLM (via the OpenAI-compatible interface)
from langchain_openai import ChatOpenAI
@@ -32,7 +31,6 @@ from rag_core.retriever_factory import create_parent_retriever

load_dotenv()


def create_llm():
    """Create the local vLLM-served LLM"""
    vllm_base_url = os.getenv(
@@ -60,8 +58,7 @@ async def demonstrate_full_pipeline():
    print("Demo: fixed-pipeline RAG retrieval (multi-query rewriting + RRF + reranking + parent documents)")
    print("=" * 60)

-    retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5)
+    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)

    if retriever is None:
        print("Error: retriever not initialized; make sure the index has been built.")
@@ -103,7 +100,6 @@ async def demonstrate_full_pipeline():
        import traceback
        traceback.print_exc()


async def demonstrate_tool_creation():
    """
    Demo: create the RAG tool (for use by the Agent)
@@ -119,12 +115,11 @@ async def demonstrate_tool_creation():
    )
    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)

    # 2. create the LLM
    llm = create_llm()

    # 3. create the tool
-    rag_tool = create_rag_tool(
+    rag_tool = create_rag_tool_sync(
        retriever=retriever,
        llm=llm,
        num_queries=3,
@@ -136,18 +131,16 @@ async def demonstrate_tool_creation():
    print(f"Tool description: {rag_tool.description[:100]}...")

    # 4. simulate the Agent invoking the tool
-    query = "请告诉我 RAG 系统的核心组件有哪些?"
+    query = "请告诉我 打虎英雄是谁?"
    print(f"\nSimulated call: {query}")
    print("-" * 40)

    result = await rag_tool.ainvoke({"query": query})
    print(result[:800] + "..." if len(result) > 800 else result)


async def main():
    await demonstrate_full_pipeline()
    await demonstrate_tool_creation()


if __name__ == "__main__":
    asyncio.run(main())
@@ -4,73 +4,11 @@ RAG tool module
Wraps the retrieval capability as a LangChain Tool for the Agent to call.
Fixed pipeline: multi-query rewriting → parallel retrieval → RRF fusion → reranking → return parent documents.
"""

-from typing import Optional, Callable
+from typing import Callable
from langchain_core.tools import tool
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever

-from .pipeline import RAGPipeline
-
-
-def create_rag_tool(
-    retriever: BaseRetriever,
-    llm: BaseLanguageModel,
-    num_queries: int = 3,
-    rerank_top_n: int = 5,
-    collection_name: str = "rag_documents",
-) -> Callable:
-    """
-    Create a configured (async) RAG retrieval tool.
-
-    Args:
-        retriever: base retriever (e.g. a ParentDocumentRetriever instance)
-        llm: language model used for multi-query rewriting
-        num_queries: number of query variants to generate
-        rerank_top_n: number of documents to return in the end
-        collection_name: collection name (used only for logs/descriptions)
-
-    Returns:
-        a LangChain Tool callable (async)
-    """
-    # initialize the pipeline (all components created once and then reused)
-    pipeline = RAGPipeline(
-        retriever=retriever,
-        llm=llm,
-        num_queries=num_queries,
-        rerank_top_n=rerank_top_n,
-    )
-
-    @tool
-    async def search_knowledge_base(query: str) -> str:
-        """Search the knowledge base for document fragments relevant to the query.
-
-        This tool will:
-        1. rewrite the user question into several queries from different angles
-        2. retrieve the relevant parent documents for each query in parallel
-        3. merge the results with reciprocal rank fusion (RRF)
-        4. pick the most relevant fragments with a cross-encoder reranking model
-
-        Suitable for factual questions or background-knowledge lookups that need precise, comprehensive answers.
-
-        Args:
-            query: the user's question or query string
-
-        Returns:
-            the formatted relevant document content, or a notice when nothing is found.
-        """
-        try:
-            documents = await pipeline.aretrieve(query)
-            if not documents:
-                return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
-
-            context = pipeline.format_context(documents)
-            return context
-        except Exception as e:
-            return f"检索过程中发生错误: {str(e)}"
-
-    return search_knowledge_base
-
+from app.rag.pipeline import RAGPipeline

def create_rag_tool_sync(
    retriever: BaseRetriever,

app/test_backend.py (new file; 307 lines)
@@ -0,0 +1,307 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整后端测试 - 验证 Agent 所有功能
|
||||
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from app.agent import AIAgentService
|
||||
from app.agent.history import ThreadHistoryService
|
||||
from app.logger import info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
|
||||
async def print_section(title):
|
||||
"""打印测试区块标题"""
|
||||
print("\n" + "=" * 70)
|
||||
print(f" {title}")
|
||||
print("=" * 70)
|
||||
|
||||
async def test_short_term_memory(agent_service):
|
||||
"""测试短期记忆(同一 thread_id 继续对话)"""
|
||||
await print_section("测试 1: 短期记忆(Short-term Memory)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_memory"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
# 第一轮对话
|
||||
print("\n[第一轮] 发送消息: '我叫张三,今年28岁'")
|
||||
result1 = await agent_service.process_message(
|
||||
"我叫张三,今年28岁", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result1['reply'][:100]}...")
|
||||
|
||||
# 第二轮对话 - 测试记忆
|
||||
print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
|
||||
result2 = await agent_service.process_message(
|
||||
"我叫什么名字?今年多大?", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result2['reply']}")
|
||||
|
||||
# 验证记忆是否存在
|
||||
if "张三" in result2['reply'] or "28" in result2['reply']:
|
||||
print("\n✅ 短期记忆测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ 短期记忆测试失败!")
|
||||
return False
|
||||
|
||||
async def test_tool_calling(agent_service):
|
||||
"""测试工具调用(RAG 搜索)"""
|
||||
await print_section("测试 2: 工具调用(Tool Calling)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_tools"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
# 发送需要 RAG 搜索的问题
|
||||
print("\n发送消息: '请告诉我,打虎英雄是谁?'")
|
||||
result = await agent_service.process_message(
|
||||
"请告诉我,打虎英雄是谁?", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result['reply'][:200]}...")
|
||||
|
||||
# 检查是否调用了 RAG 工具(回复中会有水浒传相关内容)
|
||||
if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']:
|
||||
print("\n✅ 工具调用测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
|
||||
return None
|
||||
|
||||
async def test_streaming(agent_service):
|
||||
"""测试流式对话"""
|
||||
await print_section("测试 3: 流式对话(Streaming)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_stream"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
|
||||
print("流式输出: ", end="", flush=True)
|
||||
|
||||
full_reply = ""
|
||||
chunk_count = 0
|
||||
|
||||
try:
|
||||
async for chunk in agent_service.process_message_stream(
|
||||
"用100字介绍一下AI人工智能", thread_id, "local", user_id
|
||||
):
|
||||
chunk_count += 1
|
||||
if chunk.get("type") == "llm_token":
|
||||
token = chunk.get("token", "")
|
||||
print(token, end="", flush=True)
|
||||
full_reply += token
|
||||
elif chunk.get("type") == "state_update":
|
||||
pass # 状态更新不显示
|
||||
|
||||
print(f"\n\n共收到 {chunk_count} 个 chunk")
|
||||
print(f"完整回复长度: {len(full_reply)} 字")
|
||||
|
||||
if chunk_count > 0 and len(full_reply) > 10:
|
||||
print("\n✅ 流式对话测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ 流式对话测试失败!")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 流式对话异常: {e}")
|
||||
return False
|
||||
|
||||
async def test_history_service(agent_service, history_service):
|
||||
"""测试历史查询服务"""
|
||||
await print_section("测试 4: 历史查询服务(History Service)")
|
||||
|
||||
user_id = "test_user_history"
|
||||
|
||||
# 先创建几个对话
|
||||
print(f"\n为 user_id={user_id} 创建测试对话...")
|
||||
|
||||
thread_ids = []
|
||||
for i in range(3):
|
||||
thread_id = str(uuid.uuid4())
|
||||
thread_ids.append(thread_id)
|
||||
|
||||
await agent_service.process_message(
|
||||
f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
|
||||
)
|
||||
print(f" 创建线程 {i+1}: {thread_id[:8]}...")
|
||||
|
||||
# 1. 测试获取用户线程列表
|
||||
print("\n[4.1] 测试获取用户线程列表...")
|
||||
threads = await history_service.get_user_threads(user_id, limit=10)
|
||||
print(f" 找到 {len(threads)} 个线程")
|
||||
|
||||
if len(threads) >= 3:
|
||||
print(" ✅ 线程列表查询通过")
|
||||
else:
|
||||
print(" ⚠️ 线程数量少于预期")
|
||||
|
||||
# 2. 测试获取单个线程的消息历史
|
||||
if thread_ids:
|
||||
test_thread_id = thread_ids[0]
|
||||
print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
|
||||
messages = await history_service.get_thread_messages(test_thread_id)
|
||||
print(f" 找到 {len(messages)} 条消息")
|
||||
|
||||
if len(messages) >= 2: # 至少有一问一答
|
||||
print(" ✅ 消息历史查询通过")
|
||||
else:
|
||||
print(" ⚠️ 消息数量少于预期")
|
||||
|
||||
# 3. 测试获取线程摘要
|
||||
print(f"\n[4.3] 测试获取线程摘要...")
|
||||
summary = await history_service.get_thread_summary(test_thread_id)
|
||||
print(f" 摘要: {summary.get('summary', '')[:50]}...")
|
||||
print(f" 消息数: {summary.get('message_count', 0)}")
|
||||
|
||||
if summary.get('message_count', 0) > 0:
|
||||
print(" ✅ 线程摘要查询通过")
|
||||
else:
|
||||
print(" ⚠️ 摘要查询结果不确定")
|
||||
|
||||
return len(threads) >= 3
|
||||
|
||||
async def test_long_term_memory(agent_service):
    """Test long-term memory (mem0)"""
    await print_section("Test 5: Long-term memory (mem0)")

    thread_id1 = str(uuid.uuid4())
    thread_id2 = str(uuid.uuid4())  # A different thread
    user_id = "test_user_longterm"

    print(f"\nUsing user_id: {user_id}")
    print(f"Thread 1: {thread_id1[:8]}...")
    print(f"Thread 2: {thread_id2[:8]}...")

    # Store a fact in the first thread
    print("\n[Thread 1] Sending message: 'Remember, my pet is a cat named Xiaobai'")
    result1 = await agent_service.process_message(
        "Remember, my pet is a cat named Xiaobai", thread_id1, "local", user_id
    )
    print(f"Reply: {result1['reply'][:100]}...")

    # Wait briefly so mem0 can persist the memory
    await asyncio.sleep(1)

    # Ask in the second thread (a different thread_id)
    print("\n[Thread 2] Sending message: 'What is my pet called, and what kind of animal is it?'")
    result2 = await agent_service.process_message(
        "What is my pet called, and what kind of animal is it?", thread_id2, "local", user_id
    )
    print(f"Reply: {result2['reply']}")

    # Verify long-term memory
    if "Xiaobai" in result2['reply'] or "cat" in result2['reply']:
        print("\n✅ Long-term memory test passed!")
        return True
    else:
        print("\n⚠️ Long-term memory may be disabled, or needs manual verification")
        return None

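The cross-thread recall above works because mem0 keys memories by user_id, not by thread_id. As a rough illustration of that keying, here is a minimal sketch against the public mem0 API; the default Memory() config is an assumption and is not how this project wires its vector store:

from mem0 import Memory

m = Memory()  # default in-process config, for illustration only

# Stored under the user, independent of any conversation thread
m.add("My pet is a cat named Xiaobai", user_id="test_user_longterm")

# A later query from a different thread still finds it: only user_id matters
hits = m.search("What is my pet called?", user_id="test_user_longterm")
print(hits)
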
async def main():
    """Main test entry point"""
    print("\n" + "=" * 70)
    print(" Full backend functionality test")
    print("=" * 70)

    results = {}

    try:
        # Create the database connection and services
        print("\nInitializing database connection...")
        async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
            await checkpointer.setup()
            print("✅ Database connection established")

            # Create the service instances
            print("\nInitializing Agent service...")
            agent_service = AIAgentService(checkpointer)
            await agent_service.initialize()
            print("✅ Agent service initialized")

            history_service = ThreadHistoryService(checkpointer)
            print("✅ History service initialized")

            print(f"\nAvailable models: {list(agent_service.graphs.keys())}")

            # Run the tests
            results["short-term memory"] = await test_short_term_memory(agent_service)
            await asyncio.sleep(1)

            results["tool calling"] = await test_tool_calling(agent_service)
            await asyncio.sleep(1)

            results["streaming chat"] = await test_streaming(agent_service)
            await asyncio.sleep(1)

            results["history query"] = await test_history_service(agent_service, history_service)
            await asyncio.sleep(1)

            results["long-term memory"] = await test_long_term_memory(agent_service)
            await asyncio.sleep(1)

            # Print the summary
            await print_section("Test summary")
            print("\nTest results:")
            print("-" * 40)

            pass_count = 0
            fail_count = 0
            skip_count = 0

            for test_name, result in results.items():
                if result is True:
                    status = "✅ passed"
                    pass_count += 1
                elif result is False:
                    status = "❌ failed"
                    fail_count += 1
                else:
                    status = "⚠️ needs verification"
                    skip_count += 1
                print(f" {test_name:12s}: {status}")

            print("-" * 40)
            print(f"Total: {len(results)} tests")
            print(f"Passed: {pass_count}, failed: {fail_count}, pending: {skip_count}")

            if fail_count == 0:
                print("\n🎉 All core tests passed!")
            else:
                print(f"\n⚠️ {fail_count} test(s) failed")

    except Exception as e:
        error(f"\n❌ Test run raised an exception: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0 if fail_count == 0 else 1


if __name__ == "__main__":
    exit_code = asyncio.run(main())
    sys.exit(exit_code)

@@ -3,7 +3,7 @@ AI Agent frontend module
Layered architecture: configuration, state, API client, and UI components
"""

from .logger import debug, info, warning, error
from frontend.logger import debug, info, warning, error

__version__ = "2.0.0"
__all__ = ["debug", "info", "warning", "error"]
@@ -9,8 +9,6 @@ from datetime import datetime
# Use absolute imports
from frontend.state import AppState
from frontend.api_client import api_client
from frontend.config import config


def render_sidebar():
    """Render the left sidebar"""
@@ -25,7 +23,6 @@ def render_sidebar():
    st.divider()
    _render_user_section()


def _render_user_section():
    """Render the user login area"""
    # st.header("👤 User")  # Removed the prominent header in favor of a softer caption
@@ -36,7 +33,6 @@ def _render_user_section():
    else:
        _render_user_info()


def _render_login_form():
    """Render the login form"""
    username = st.text_input(
@@ -54,7 +50,6 @@ def _render_login_form():

    # st.info("💡 Logging in is recommended to keep conversation histories separate")  # Removed the extra colored box


def _render_user_info():
    """Render the user info"""
    st.markdown(f"**Current user**: `{AppState.get_user_id()}`")
@@ -64,7 +59,6 @@ def _render_user_info():
    _refresh_threads()
    st.rerun()


def _render_history_section():
    """Render the conversation history list"""
    col1, col2 = st.columns([3, 1])
@@ -76,7 +70,6 @@ def _render_history_section():

    _render_thread_list()


def _render_history_actions():
    """Render the history action buttons"""
    # Dropped type="primary" so this renders as a plain outlined button rather than a big red block
@@ -84,7 +77,6 @@ def _render_history_actions():
    AppState.start_new_thread()
    st.rerun()


def _render_thread_list():
    """Render the thread list"""
    # Fetch only on first load; updates happen via an explicit external call to _refresh_threads()
@@ -101,7 +93,6 @@ def _render_thread_list():
    for thread in threads:
        _render_thread_item(thread)


def _render_thread_item(thread: dict):
    """
    Render a single thread item
@@ -130,7 +121,6 @@ def _render_thread_item(thread: dict):
    ):
        _load_thread(thread_id)


def _format_time(time_str: str) -> str:
    """
    Format a timestamp string
@@ -150,13 +140,11 @@ def _format_time(time_str: str) -> str:
    except Exception:
        return time_str[:10]


def _refresh_threads():
    """Refresh the history thread list"""
    threads = api_client.get_user_threads(AppState.get_user_id())
    AppState.set_threads(threads)


def _load_thread(thread_id: str):
    """
    Load the message history of the given thread

@@ -5,6 +5,7 @@

import os
from dataclasses import dataclass
from typing import Optional
from dotenv import load_dotenv

# Load the .env file
@@ -25,7 +26,7 @@ class FrontendConfig:

    # ==================== Model configuration ====================
    default_model: str = "local"  # Changed the default model to "local"
    model_options: dict = None
    model_options: Optional[dict] = None

    # ==================== User configuration ====================
    default_user_id: str = "default_user"
@@ -53,7 +54,7 @@ class FrontendConfig:
        """Load configuration from environment variables (highest priority)"""
        # API base URL (strip the /chat suffix)
        # Priority: API_URL environment variable > default value
        api_url = os.getenv("API_URL", "http://127.0.0.1:8083")
        api_url = os.getenv("API_URL", "http://127.0.0.1:8079")
        self.api_base = api_url.replace("/chat", "").rstrip("/")


@@ -7,7 +7,7 @@ import uuid
from typing import List, Dict, Any
import streamlit as st

from .config import config
from frontend.config import config


class AppState:

@@ -4,10 +4,10 @@ RAG Core - shared RAG component package
Provides the shared embedding, vector-store, and document-store functionality used by both rag_indexer and app/rag.
"""

from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from rag_core.embedders import LlamaCppEmbedder
from rag_core.vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from rag_core.store import PostgresDocStore, create_docstore
from rag_core.retriever_factory import create_parent_retriever


__all__ = [

@@ -3,22 +3,26 @@ import os
from typing import Optional
from qdrant_client import QdrantClient


QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

def create_qdrant_client(
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    timeout: int = 120,  # index builds need a generous timeout
    timeout: int = 300,  # index builds need a generous timeout
) -> QdrantClient:
    effective_url = url or QDRANT_URL
    effective_api_key = api_key or QDRANT_API_KEY


    if not effective_url:
        raise ValueError("Qdrant URL is not configured")

    client_kwargs = {"url": effective_url, "timeout": timeout}

    client_kwargs = {
        "url": effective_url,
        "timeout": timeout,
    }
    if effective_api_key:
        client_kwargs["api_key"] = effective_api_key


    return QdrantClient(**client_kwargs)
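For reference, a minimal usage sketch of the factory above (the URL is an assumption; omitted arguments fall back to QDRANT_URL / QDRANT_API_KEY and the 300 s default timeout):

client = create_qdrant_client(url="http://127.0.0.1:6333", timeout=60)
# List the collections visible to this client, then release the connection
print([c.name for c in client.get_collections().collections])
client.close()
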
@@ -5,11 +5,9 @@
import os
import httpx
from typing import List, Optional
from urllib.parse import urljoin

from langchain_core.embeddings import Embeddings


class LlamaCppEmbedder:
    """Wraps a llama.cpp embedding service behind its OpenAI-compatible API."""

@@ -17,7 +15,7 @@ class LlamaCppEmbedder:
        self,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        model: str = "embeddinggemma-300M-Q8_0",
        model: str = "Qwen3-Embedding-0.6B-Q8_0",
    ):
        self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
        self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "")
@@ -71,7 +69,6 @@ class LlamaCppEmbedder:
        else:
            raise ValueError(f"Unknown embedding API response format: {data}")


class _LlamaCppLangchainAdapter(Embeddings):
    """Adapts LlamaCppEmbedder to the LangChain Embeddings interface."""

@@ -2,14 +2,7 @@
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from rag_indexer.splitters import SplitterType, get_splitter
import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from httpx import RemoteProtocolError
from langchain_core.documents import Document
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
@@ -17,7 +10,6 @@ from langchain_classic.retrievers import ParentDocumentRetriever

from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore


def create_parent_retriever(
    collection_name: str = "rag_documents",
    embeddings: Optional[Embeddings] = None,

@@ -15,8 +15,8 @@
"""


from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
from rag_core.store.postgres import PostgresDocStore
from rag_core.store.factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI

__version__ = "2.0.0"

@@ -9,9 +9,11 @@ import logging
from typing import Optional, Tuple

from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
from rag_core.store.postgres import PostgresDocStore
from dotenv import load_dotenv
load_dotenv()

logger = logging.getLogger(__name__)
logger = logging.getLogger(__name__)

# Default connection string (read from the environment)
DEFAULT_DB_URI = os.getenv(

@@ -4,12 +4,10 @@
A truly asynchronous PostgreSQL document store built on asyncpg, supporting high-concurrency access.
"""

from __future__ import annotations

import asyncio
import json
import logging
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence, cast
from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence

from langchain_core.documents import Document
from langchain_core.stores import BaseStore
@@ -18,7 +16,6 @@ import asyncpg

logger = logging.getLogger(__name__)


class PostgresDocStore(BaseStore[str, Any]):
    """
    An asynchronous PostgreSQL document store implementation.

@@ -4,13 +4,16 @@ Qdrant vector database wrapper.

import logging
import os
import time
from typing import List, Optional, Dict, Any

from langchain_core.documents import Document
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from .client import create_qdrant_client
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from rag_core.client import create_qdrant_client

logger = logging.getLogger(__name__)

@@ -28,9 +31,11 @@ class QdrantVectorStore:
    ):
        self.collection_name = collection_name
        self._client: Optional[QdrantClient] = None
        self._connection_attempts = 0
        self._last_connection_time: Optional[float] = None

        if embeddings is None:
            from .embedders import LlamaCppEmbedder
            from rag_core.embedders import LlamaCppEmbedder
            embedder = LlamaCppEmbedder()
            self.embeddings = embedder.as_langchain_embeddings()
        else:
@@ -46,38 +51,89 @@ class QdrantVectorStore:

    def get_client(self) -> QdrantClient:
        if self._client is None:
            self._client = create_qdrant_client(timeout=120)
            self._client = create_qdrant_client(timeout=300)
            self._connection_attempts += 1
            self._last_connection_time = time.time()
            logger.debug("Qdrant client created (connection #%d)", self._connection_attempts)
        return self._client

    def refresh_client(self):
        """Close the old connection and create a new one."""
        if self._client is not None:
            self._client.close()
            self._client = None
            try:
                self._client.close()
                logger.debug("Old Qdrant connection closed")
            except Exception as e:
                logger.warning("Exception while closing the Qdrant connection: %s", e)
            finally:
                self._client = None
                self._last_connection_time = None

    def check_connection_health(self) -> bool:
        """Check connection health; rebuild the connection automatically if it has gone stale."""
        if self._client is None:
            logger.info("Qdrant client not initialized; a new connection will be created")
            return False

        try:
            client = self.get_client()
            client.get_collections()
            logger.debug("Qdrant connection health check passed")
            return True
        except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
            logger.warning("Qdrant connection health check failed: %s", e)
            self.refresh_client()
            return False

    def get_connection_stats(self) -> Dict[str, Any]:
        """Return connection statistics."""
        return {
            "connection_attempts": self._connection_attempts,
            "last_connection_time": self._last_connection_time,
            "client_initialized": self._client is not None,
        }

    def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
        """Create the collection with an appropriate vector dimension."""
        if vector_size is None:
            from .embedders import LlamaCppEmbedder
            from rag_core.embedders import LlamaCppEmbedder
            embedder = LlamaCppEmbedder()
            vector_size = embedder.get_embedding_dimension()

        client = self.get_client()
        collections = client.get_collections().collections
        exists = any(c.name == self.collection_name for c in collections)
        max_retries = 3
        base_delay = 2
        for attempt in range(max_retries):
            try:
                client = self.get_client()
                collections = client.get_collections().collections
                exists = any(c.name == self.collection_name for c in collections)

        if exists and force_recreate:
            client.delete_collection(self.collection_name)
            exists = False
                if exists and force_recreate:
                    client.delete_collection(self.collection_name)
                    exists = False

        if not exists:
            client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
        else:
            logger.info("Collection '%s' already exists", self.collection_name)
                if not exists:
                    client.create_collection(
                        collection_name=self.collection_name,
                        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                    )
                    logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
                else:
                    logger.info("Collection '%s' already exists", self.collection_name)
                return
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error("Creating collection '%s' still failing after %d retries: %s", self.collection_name, max_retries, e)
                    raise
                wait_time = base_delay * (2 ** attempt)
                error_type = type(e).__name__
                logger.warning(
                    "Creating collection '%s' hit a network error [%s]; retrying in %d s (%d/%d): %s",
                    self.collection_name, error_type, wait_time, attempt + 1, max_retries, e
                )
                self.refresh_client()
                logger.debug("Qdrant client connection refreshed")
                time.sleep(wait_time)

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Add documents to the vector database."""
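Taken together, these methods give callers a simple probe-then-use pattern for long-lived Qdrant connections. A minimal sketch, assuming the constructor arguments shown in this diff:

store = QdrantVectorStore(collection_name="rag_documents")

# A failed probe refreshes the client; the next get_client() call reconnects lazily
if not store.check_connection_health():
    store.get_client()

print(store.get_connection_stats())
# e.g. {'connection_attempts': 2, 'last_connection_time': 1700000000.0, 'client_initialized': True}
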
@@ -102,9 +158,10 @@ class QdrantVectorStore:
        info = self.get_client().get_collection(self.collection_name)
        vectors_config = info.config.params.vectors
        if isinstance(vectors_config, dict):
            vector_size = next(iter(vectors_config.values())).size
            first_config = next(iter(vectors_config.values()), None)
            vector_size = first_config.size if first_config else 0
        else:
            vector_size = vectors_config.size
            vector_size = vectors_config.size if vectors_config else 0
        return {
            "name": self.collection_name,
            "vectors_count": info.points_count or 0,

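The isinstance guard above exists because qdrant-client reports a collection's vector configuration in two shapes, and the None fallbacks cover configurations that cannot be read. Roughly (shapes assumed from qdrant-client, shown for illustration only):

# Single unnamed vector: config.params.vectors is one VectorParams
VectorParams(size=1024, distance=Distance.COSINE)

# Named vectors: config.params.vectors is a dict of name -> VectorParams
{"text": VectorParams(size=1024, distance=Distance.COSINE)}
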
@@ -23,9 +23,9 @@ Offline RAG Indexer module.
>>> await builder.build_from_file("document.pdf")
"""

from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter

# Re-export commonly used components from rag_core
from rag_core import (

@@ -8,23 +8,21 @@ import asyncio
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Union, Optional, Any, Dict, Tuple
from typing import List, Union, Optional, Any, Dict

from httpx import RemoteProtocolError
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from qdrant_client.http.exceptions import ResponseHandlingException

from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
from rag_indexer.loaders import DocumentLoader
from rag_indexer.splitters import SplitterType, get_splitter
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever


logger = logging.getLogger(__name__)


# ---------- Configuration dataclasses ----------
@dataclass
class DocstoreConfig:
@@ -35,7 +33,6 @@ class DocstoreConfig:
    # To inject an externally created docstore, set this field directly
    instance: Optional[BaseStore] = None


@dataclass
class IndexBuilderConfig:
    """Index builder configuration."""
@@ -59,7 +56,6 @@ class IndexBuilderConfig:
    # Other splitter parameters (used when splitter_type is not parent-child)
    extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict)


# ---------- Index builder ----------
class IndexBuilder:
    """Main RAG index-building pipeline, supporting single-chunk and parent-child splitting."""

@@ -223,18 +219,26 @@ class IndexBuilder:

    async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None:
        """Add one batch, retrying automatically on failure (handles transient network errors)."""
        max_retries = 3
        max_retries = 5
        base_delay = 2
        for attempt in range(max_retries):
            try:
                await self.retriever.aadd_documents(batch)  # type: ignore[union-attr]
                logger.info("Batch %d: added %d documents", batch_no, len(batch))
                return
            except (RemoteProtocolError, ConnectionError, OSError) as e:
            except (RemoteProtocolError, ConnectionError, OSError, ResponseHandlingException) as e:
                if attempt == max_retries - 1:
                    logger.error("Batch %d still failing after %d retries: %s", batch_no, max_retries, e)
                    raise
                logger.warning("Batch %d: connection dropped, retrying (%d/%d): %s",
                               batch_no, attempt + 1, max_retries, e)
                wait_time = base_delay * (2 ** attempt)
                error_type = type(e).__name__
                logger.warning(
                    "Batch %d hit a network error [%s]; retrying in %d s (%d/%d): %s",
                    batch_no, error_type, wait_time, attempt + 1, max_retries, e
                )
                self.vector_store.refresh_client()
                await asyncio.sleep(1)
                logger.debug("Batch %d: Qdrant client connection refreshed", batch_no)
                await asyncio.sleep(wait_time)

    # ---------- Information accessors ----------
    def get_collection_info(self) -> Any:

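With base_delay = 2 and max_retries = 5, the sleeps between attempts grow as base_delay * 2 ** attempt: 2, 4, 8, and 16 seconds; the fifth attempt re-raises instead of sleeping. A tiny sketch of that schedule:

def backoff_delays(base_delay: int = 2, max_retries: int = 5) -> list:
    # Sleeps occur after attempts 0..max_retries-2; the final attempt raises
    return [base_delay * (2 ** attempt) for attempt in range(max_retries - 1)]

assert backoff_delays() == [2, 4, 8, 16]
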
@@ -250,7 +250,7 @@ start_embedding() {
    echo -e "${BLUE}🚀 Starting the llama.cpp embedding container...${NC}"

    # Check that the model file exists
    if [ ! -f "/home/huang/Study/AIModel/GGUF/embeddinggemma-300M-Q8_0.gguf" ]; then
    if [ ! -f "/home/huang/Study/AIModel/GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf" ]; then
        echo -e "${RED}✗ Error: embedding model file not found${NC}"
        exit 1
    fi
@@ -263,13 +263,16 @@ start_embedding() {
        --device=/dev/dri \
        -v /home/huang/Study/AIModel/GGUF:/models \
        -p 8082:8080 \
        -e LLAMA_ARG_CTX_SIZE=16384 \
        -e LLAMA_ARG_N_PARALLEL=1 \
        -e LLAMA_ARG_BATCH=512 \
        -e LLAMA_ARG_N_GPU_LAYERS=99 \
        -e LLAMA_ARG_API_KEY=huang1998 \
        ghcr.io/ggml-org/llama.cpp:server-rocm \
        -m /models/embeddinggemma-300M-Q8_0.gguf \
        -m /models/Qwen3-Embedding-0.6B-Q8_0.gguf \
        --host 0.0.0.0 \
        --port 8080 \
        -ngl 99 \
        --embeddings \
        -c 512
        --embeddings

    echo -e "${GREEN}✓ llama.cpp embedding container started (port 8082)${NC}"
    sleep 5
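Once the container is up, the service can be smoke-tested through llama.cpp's OpenAI-compatible embeddings endpoint. A minimal sketch, with the port, API key, and model name taken from the script above; the printed dimension depends on the loaded model:

import httpx

resp = httpx.post(
    "http://127.0.0.1:8082/v1/embeddings",
    headers={"Authorization": "Bearer huang1998"},
    json={"model": "Qwen3-Embedding-0.6B-Q8_0", "input": "hello"},
    timeout=30.0,
)
resp.raise_for_status()
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # vector dimension reported by the model
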
@@ -288,7 +291,7 @@ start_backend() {
    set +a

    export PYTHONPATH="$PROJECT_DIR"
    export BACKEND_PORT=8083
    export BACKEND_PORT=8079
    python app/backend.py &
    BACKEND_PID=$!
    echo -e "${GREEN}✓ Backend service started (PID: $BACKEND_PID)${NC}"