修改rag,实现混合检索
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s

This commit is contained in:
2026-05-04 04:28:32 +08:00
parent d0590240f9
commit 82dde7113e
15 changed files with 536 additions and 65 deletions

View File

@@ -70,7 +70,7 @@ class AIAgentService:
raise RuntimeError("没有可用的模型")
return self
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
"""处理用户消息返回包含回复、token统计和耗时的字典"""
if model not in self.graphs:
# 回退到第一个可用模型
@@ -175,6 +175,8 @@ class AIAgentService:
try:
info(f"📡 开始调用 graph.astream()...")
chunk_count = 0
full_message_content = "" # 收集完整消息内容
async for chunk in graph.astream(
input_state,
config=config,
@@ -184,21 +186,11 @@ class AIAgentService:
):
chunk_count += 1
chunk_type = chunk["type"]
info(f"📦 收到第 {chunk_count} 个chunk, type: {chunk_type}")
processed_event = {}
if chunk_type == "messages":
message_chunk, metadata = chunk["data"]
node_name = metadata.get("langgraph_node", "unknown")
info(f"📨 处理消息chunk, node: {node_name}")
# 详细记录消息内容,看看这些 chunk 到底是什么
if hasattr(message_chunk, "content"):
content_preview = str(message_chunk.content)[:200]
info(f"📄 消息内容预览 ({len(content_preview)} chars): {repr(content_preview)}")
if hasattr(message_chunk, "type"):
info(f"📋 消息类型: {message_chunk.type}")
if hasattr(message_chunk, "tool_calls"):
info(f"🔧 包含工具调用: {message_chunk.tool_calls}")
# 检测节点变化,发送节点开始事件
if node_name != current_node:
@@ -218,8 +210,6 @@ class AIAgentService:
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
info(f"💬 消息token: token_content='{repr(token_content[:50])}', reasoning_token='{repr(reasoning_token[:50])}', node_name='{node_name}'")
# 处理思考过程
if reasoning_token:
@@ -228,7 +218,6 @@ class AIAgentService:
"node": node_name,
"reasoning_token": reasoning_token
}
info(f"✅ 生成 reasoning_token 事件: {processed_event}")
# 处理工具调用
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
for tool_call in message_chunk.tool_calls:
@@ -248,7 +237,7 @@ class AIAgentService:
"args": tool_args,
"id": tool_call_id
}
# 处理普通 token
# 处理普通 token - 只收集,不打印单个 token
elif token_content:
processed_event = {
"type": "llm_token",
@@ -256,18 +245,13 @@ class AIAgentService:
"token": token_content,
"reasoning_token": reasoning_token
}
info(f"✅ 生成 llm_token 事件: {processed_event}")
else:
info(f"⚠️ 没有生成任何事件token_content='{repr(token_content)}', reasoning_token='{repr(reasoning_token)}'")
if node_name == "llm_call":
full_message_content += token_content
elif chunk_type == "updates":
info(f"🔄 处理updates chunk")
updates_data = chunk["data"]
serialized_data = self._serialize_value(updates_data)
# 关键修复:不再从 updates 中读取 latest_reasoning避免重复
# 因为我们现在直接通过 custom 事件发送推理结果了
# 检查是否有人工审核请求
if "review_pending" in serialized_data and serialized_data["review_pending"]:
review_id = serialized_data.get("review_id", "")
@@ -302,18 +286,12 @@ class AIAgentService:
}
elif chunk_type == "custom":
info(f"🎯 处理custom chunk, 完整数据: {repr(chunk)}")
custom_data = chunk["data"]
info(f"🎯 custom_data 内容: {repr(custom_data)}")
info(f"🎯 custom_data 类型: {type(custom_data)}")
# 关键修复:处理我们从 react_reason_node 发送的自定义推理事件
# LangGraph 的 adispatch_custom_event 发送的事件格式:
# chunk["data"] 是我们传的第二个参数dict
# 处理我们从 react_reason_node 发送的自定义推理事件
if isinstance(custom_data, dict):
# 检查是否是我们的推理事件
if "action" in custom_data and "reasoning" in custom_data:
info(f"[Agent Service] 收到自定义推理事件: {custom_data}")
yield {
"type": "react_reasoning",
"step": custom_data.get("step", 1),
@@ -339,7 +317,10 @@ class AIAgentService:
if processed_event:
yield processed_event
# 完整消息集合完成后,一次性打印
info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
if full_message_content:
info(f"📄 完整消息内容: {repr(full_message_content)}")
except Exception as e:
error(f"❌ 执行 React 图时出错: {e}")

View File

@@ -12,18 +12,21 @@ class ThreadHistoryService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer
async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
async def get_user_threads(self, user_id: str, limit: int = 4) -> List[Dict[str, Any]]:
"""
获取指定用户的所有线程摘要信息
Args:
user_id: 用户 ID
limit: 返回数量限制
limit: 返回数量限制强制最多4条
Returns:
线程列表,每个包含 thread_id, last_updated, summary, message_count
"""
try:
# 强制限制最多4条
actual_limit = min(limit, 4)
# 查询 checkpoints 表获取用户的线程列表
async with self.checkpointer.conn.cursor() as cur:
# 在较新的 LangGraph 版本中AsyncPostgresSaver 创建的 checkpoints 表
@@ -40,7 +43,7 @@ class ThreadHistoryService:
ORDER BY last_updated DESC
LIMIT %s
"""
await cur.execute(query, (user_id, limit))
await cur.execute(query, (user_id, actual_limit))
rows = await cur.fetchall()
threads = []

View File

@@ -98,7 +98,7 @@ async def health_check():
class ChatRequest(BaseModel):
message: str
thread_id: str | None = None
model: str = "zhipu"
model: str = "local"
user_id: str = "default_user"
class ChatResponse(BaseModel):
@@ -212,7 +212,7 @@ async def chat_endpoint(
@app.get("/threads")
async def list_threads(
user_id: str = Query("default_user", description="用户 ID"),
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
limit: int = Query(4, ge=1, le=200, description="返回数量限制"),
history_service: ThreadHistoryService = Depends(get_history_service)
):
"""获取当前用户的对话历史列表"""
@@ -312,7 +312,7 @@ async def websocket_endpoint(
data = await websocket.receive_json()
message = data.get("message")
thread_id = data.get("thread_id", str(uuid.uuid4()))
model = data.get("model", "zhipu")
model = data.get("model", "local")
user_id = data.get("user_id", "default_user")
if not message:
await websocket.send_json({"error": "missing message"})
@@ -435,4 +435,10 @@ if __name__ == "__main__":
import uvicorn
# 使用环境变量或默认端口 8079避免与 llama.cpp 的 8081 端口冲突)
port = int(BACKEND_PORT)
uvicorn.run(app, host="0.0.0.0", port=port)
uvicorn.run(
app,
host="0.0.0.0",
port=port,
log_level="debug",
access_log=True
)

View File

@@ -130,12 +130,16 @@ class ReactIntentReasoner:
retrieved_docs = context.get("retrieved_docs", [])
messages = context.get("messages", [])
# 关键修复 2如果已经有 rag_context 或 web_search_results通过 messages 推断),直接回答
# 检查是否已经执行过 rag_retrieve 或 web_search
if "rag_retrieve" in previous_actions or "web_search" in previous_actions:
# 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次
# 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search
rag_count = previous_actions.count("rag_retrieve")
web_search_count = previous_actions.count("web_search")
# 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答
if rag_count >= 2 or web_search_count >= 1:
result.action = ReasoningAction.DIRECT_RESPONSE
result.confidence = 0.95
result.reasoning = "已获取信息,直接回答"
result.reasoning = "已获取足够信息,直接回答"
return result
# 策略1尝试使用 LLM 推理

View File

@@ -95,10 +95,10 @@ def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphS
return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
# ========== RAG 检索核心逻辑(真正利用已有代码) ==========
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(真正利用 rag/tools.py
RAG 检索核心逻辑(真正利用 rag/tools.py - 异步版本
Args:
state: 主图状态
@@ -119,10 +119,10 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
rag_tool = get_rag_tool_from_state(state)
if rag_tool:
# 使用真正的 RAG 工具(来自 rag/tools.py
# 使用真正的 RAG 工具(来自 rag/tools.py- 异步版本
try:
# 调用 LangChain Tool 的 invoke 方法
rag_context = rag_tool.invoke(retrieval_query)
# 直接 await 异步工具ainvoke 方法
rag_context = await rag_tool.ainvoke(retrieval_query)
state.rag_context = rag_context
state.rag_docs = [
{"source": "rag_retrieval", "content": rag_context}
@@ -134,9 +134,9 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
except Exception as e:
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
elif _GLOBAL_RAG_PIPELINE:
# 使用 RAG Pipeline 直接检索
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
try:
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
@@ -158,7 +158,7 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试)==========
# ========== RAG 检索节点(带超时和重试) ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
@@ -196,8 +196,13 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑
result = _rag_retrieve_core(state)
# 执行核心逻辑 - 异步 await
result = await _rag_retrieve_core(state)
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
if result.rag_docs:
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
# 成功
state.debug_info["rag_retrieval"] = {
@@ -226,6 +231,15 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
except Exception as e:
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "rag_retrieve",
"confidence": 1.0,
"reasoning": "RAG 检索完成",
"timestamp": datetime.now().isoformat()
})
return result
except Exception as e:
@@ -255,7 +269,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
# 指数退避等待
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试都失败,记录结构化错误
error_record = ErrorRecord(

View File

@@ -364,20 +364,27 @@ def route_by_reasoning(state: MainGraphState) -> str:
if "subgraph_completed" in previous_actions or state.final_result:
return "llm_call"
# 检查是否刚刚执行 rag 或 web search应该继续推理一次然后去 llm_call
# 但为了避免死循环,我们设置一个简单的规则
if len(previous_actions) > 3:
# 关键修复:如果已经执行 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
# 这样的流程推理1 → RAG → 推理2判断 RAG 结果) → LLM_CALL
rag_count = previous_actions.count("rag_retrieve")
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
return "llm_call"
# 关键修复:限制最多 3 次推理,避免无限循环
if len(previous_actions) >= 3:
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
return "llm_call"
# 获取推理结果
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result:
return "llm_call"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
route_mapping = {
@@ -391,7 +398,8 @@ def route_by_reasoning(state: MainGraphState) -> str:
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
return route_mapping.get(route, "llm_call")

View File

@@ -1,5 +1,5 @@
# app/rag_initializer.py
from app.rag.tools import create_rag_tool_sync
from app.rag.tools import create_rag_tool_sync, create_rag_tool_async
from rag_core import create_parent_retriever
from app.model_services import get_embedding_service
from app.logger import info, warning
@@ -16,11 +16,11 @@ async def init_rag_tool(local_llm_creator):
embeddings=embeddings
)
rewrite_llm = local_llm_creator()
rag_tool = create_rag_tool_sync(
rag_tool = create_rag_tool_async(
retriever, rewrite_llm,
num_queries=3, rerank_top_n=5
)
info("✅ RAG 检索工具初始化成功")
info("✅ RAG 检索工具初始化成功(异步版本)")
return rag_tool
except Exception as e:
warning(f"⚠️ RAG 检索工具初始化失败: {e}")

View File

@@ -70,6 +70,63 @@ def create_rag_tool_sync(
return search_knowledge_base_sync
def create_rag_tool_async(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(异步版本)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
Async LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
async def search_knowledge_base_async(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段(异步版本)。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_async
def create_rag_tool(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,