This commit is contained in:
@@ -70,7 +70,7 @@ class AIAgentService:
|
||||
raise RuntimeError("没有可用的模型")
|
||||
return self
|
||||
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
|
||||
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
|
||||
"""处理用户消息,返回包含回复、token统计和耗时的字典"""
|
||||
if model not in self.graphs:
|
||||
# 回退到第一个可用模型
|
||||
@@ -175,6 +175,8 @@ class AIAgentService:
|
||||
try:
|
||||
info(f"📡 开始调用 graph.astream()...")
|
||||
chunk_count = 0
|
||||
full_message_content = "" # 收集完整消息内容
|
||||
|
||||
async for chunk in graph.astream(
|
||||
input_state,
|
||||
config=config,
|
||||
@@ -184,21 +186,11 @@ class AIAgentService:
|
||||
):
|
||||
chunk_count += 1
|
||||
chunk_type = chunk["type"]
|
||||
info(f"📦 收到第 {chunk_count} 个chunk, type: {chunk_type}")
|
||||
processed_event = {}
|
||||
|
||||
if chunk_type == "messages":
|
||||
message_chunk, metadata = chunk["data"]
|
||||
node_name = metadata.get("langgraph_node", "unknown")
|
||||
info(f"📨 处理消息chunk, node: {node_name}")
|
||||
# 详细记录消息内容,看看这些 chunk 到底是什么
|
||||
if hasattr(message_chunk, "content"):
|
||||
content_preview = str(message_chunk.content)[:200]
|
||||
info(f"📄 消息内容预览 ({len(content_preview)} chars): {repr(content_preview)}")
|
||||
if hasattr(message_chunk, "type"):
|
||||
info(f"📋 消息类型: {message_chunk.type}")
|
||||
if hasattr(message_chunk, "tool_calls"):
|
||||
info(f"🔧 包含工具调用: {message_chunk.tool_calls}")
|
||||
|
||||
# 检测节点变化,发送节点开始事件
|
||||
if node_name != current_node:
|
||||
@@ -218,8 +210,6 @@ class AIAgentService:
|
||||
reasoning_token = ""
|
||||
if hasattr(message_chunk, 'additional_kwargs'):
|
||||
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
|
||||
|
||||
info(f"💬 消息token: token_content='{repr(token_content[:50])}', reasoning_token='{repr(reasoning_token[:50])}', node_name='{node_name}'")
|
||||
|
||||
# 处理思考过程
|
||||
if reasoning_token:
|
||||
@@ -228,7 +218,6 @@ class AIAgentService:
|
||||
"node": node_name,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
info(f"✅ 生成 reasoning_token 事件: {processed_event}")
|
||||
# 处理工具调用
|
||||
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
|
||||
for tool_call in message_chunk.tool_calls:
|
||||
@@ -248,7 +237,7 @@ class AIAgentService:
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}
|
||||
# 处理普通 token
|
||||
# 处理普通 token - 只收集,不打印单个 token
|
||||
elif token_content:
|
||||
processed_event = {
|
||||
"type": "llm_token",
|
||||
@@ -256,18 +245,13 @@ class AIAgentService:
|
||||
"token": token_content,
|
||||
"reasoning_token": reasoning_token
|
||||
}
|
||||
info(f"✅ 生成 llm_token 事件: {processed_event}")
|
||||
else:
|
||||
info(f"⚠️ 没有生成任何事件,token_content='{repr(token_content)}', reasoning_token='{repr(reasoning_token)}'")
|
||||
if node_name == "llm_call":
|
||||
full_message_content += token_content
|
||||
|
||||
elif chunk_type == "updates":
|
||||
info(f"🔄 处理updates chunk")
|
||||
updates_data = chunk["data"]
|
||||
serialized_data = self._serialize_value(updates_data)
|
||||
|
||||
# 关键修复:不再从 updates 中读取 latest_reasoning,避免重复
|
||||
# 因为我们现在直接通过 custom 事件发送推理结果了
|
||||
|
||||
# 检查是否有人工审核请求
|
||||
if "review_pending" in serialized_data and serialized_data["review_pending"]:
|
||||
review_id = serialized_data.get("review_id", "")
|
||||
@@ -302,18 +286,12 @@ class AIAgentService:
|
||||
}
|
||||
|
||||
elif chunk_type == "custom":
|
||||
info(f"🎯 处理custom chunk, 完整数据: {repr(chunk)}")
|
||||
custom_data = chunk["data"]
|
||||
info(f"🎯 custom_data 内容: {repr(custom_data)}")
|
||||
info(f"🎯 custom_data 类型: {type(custom_data)}")
|
||||
|
||||
# 关键修复:处理我们从 react_reason_node 发送的自定义推理事件
|
||||
# LangGraph 的 adispatch_custom_event 发送的事件格式:
|
||||
# chunk["data"] 是我们传的第二个参数(dict)
|
||||
# 处理我们从 react_reason_node 发送的自定义推理事件
|
||||
if isinstance(custom_data, dict):
|
||||
# 检查是否是我们的推理事件
|
||||
if "action" in custom_data and "reasoning" in custom_data:
|
||||
info(f"[Agent Service] 收到自定义推理事件: {custom_data}")
|
||||
yield {
|
||||
"type": "react_reasoning",
|
||||
"step": custom_data.get("step", 1),
|
||||
@@ -339,7 +317,10 @@ class AIAgentService:
|
||||
if processed_event:
|
||||
yield processed_event
|
||||
|
||||
# 完整消息集合完成后,一次性打印
|
||||
info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
|
||||
if full_message_content:
|
||||
info(f"📄 完整消息内容: {repr(full_message_content)}")
|
||||
|
||||
except Exception as e:
|
||||
error(f"❌ 执行 React 图时出错: {e}")
|
||||
|
||||
@@ -12,18 +12,21 @@ class ThreadHistoryService:
|
||||
def __init__(self, checkpointer):
|
||||
self.checkpointer = checkpointer
|
||||
|
||||
async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
async def get_user_threads(self, user_id: str, limit: int = 4) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的所有线程摘要信息
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
limit: 返回数量限制
|
||||
limit: 返回数量限制(强制最多4条)
|
||||
|
||||
Returns:
|
||||
线程列表,每个包含 thread_id, last_updated, summary, message_count
|
||||
"""
|
||||
try:
|
||||
# 强制限制最多4条
|
||||
actual_limit = min(limit, 4)
|
||||
|
||||
# 查询 checkpoints 表获取用户的线程列表
|
||||
async with self.checkpointer.conn.cursor() as cur:
|
||||
# 在较新的 LangGraph 版本中,AsyncPostgresSaver 创建的 checkpoints 表
|
||||
@@ -40,7 +43,7 @@ class ThreadHistoryService:
|
||||
ORDER BY last_updated DESC
|
||||
LIMIT %s
|
||||
"""
|
||||
await cur.execute(query, (user_id, limit))
|
||||
await cur.execute(query, (user_id, actual_limit))
|
||||
rows = await cur.fetchall()
|
||||
|
||||
threads = []
|
||||
|
||||
@@ -98,7 +98,7 @@ async def health_check():
|
||||
class ChatRequest(BaseModel):
    """Schema of a chat request: the message plus optional routing fields."""

    # Text of the user's message; the only field without a default.
    message: str
    # Conversation thread identifier; None when the caller does not send one.
    thread_id: str | None = None
    # Key of the model that should answer. Declared exactly once so the
    # effective default ("local") is unambiguous.
    model: str = "local"
    # Identifier of the requesting user.
    user_id: str = "default_user"
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
@@ -212,7 +212,7 @@ async def chat_endpoint(
|
||||
@app.get("/threads")
|
||||
async def list_threads(
|
||||
user_id: str = Query("default_user", description="用户 ID"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
|
||||
limit: int = Query(4, ge=1, le=200, description="返回数量限制"),
|
||||
history_service: ThreadHistoryService = Depends(get_history_service)
|
||||
):
|
||||
"""获取当前用户的对话历史列表"""
|
||||
@@ -312,7 +312,7 @@ async def websocket_endpoint(
|
||||
data = await websocket.receive_json()
|
||||
message = data.get("message")
|
||||
thread_id = data.get("thread_id", str(uuid.uuid4()))
|
||||
model = data.get("model", "zhipu")
|
||||
model = data.get("model", "local")
|
||||
user_id = data.get("user_id", "default_user")
|
||||
if not message:
|
||||
await websocket.send_json({"error": "missing message"})
|
||||
@@ -435,4 +435,10 @@ if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突)
|
||||
port = int(BACKEND_PORT)
|
||||
uvicorn.run(app, host="0.0.0.0", port=port)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="0.0.0.0",
|
||||
port=port,
|
||||
log_level="debug",
|
||||
access_log=True
|
||||
)
|
||||
|
||||
@@ -130,12 +130,16 @@ class ReactIntentReasoner:
|
||||
retrieved_docs = context.get("retrieved_docs", [])
|
||||
messages = context.get("messages", [])
|
||||
|
||||
# 关键修复 2:如果已经有 rag_context 或 web_search_results(通过 messages 推断),直接回答
|
||||
# 检查是否已经执行过 rag_retrieve 或 web_search
|
||||
if "rag_retrieve" in previous_actions or "web_search" in previous_actions:
|
||||
# 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次
|
||||
# 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search
|
||||
rag_count = previous_actions.count("rag_retrieve")
|
||||
web_search_count = previous_actions.count("web_search")
|
||||
|
||||
# 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答
|
||||
if rag_count >= 2 or web_search_count >= 1:
|
||||
result.action = ReasoningAction.DIRECT_RESPONSE
|
||||
result.confidence = 0.95
|
||||
result.reasoning = "已获取信息,直接回答"
|
||||
result.reasoning = "已获取足够信息,直接回答"
|
||||
return result
|
||||
|
||||
# 策略1:尝试使用 LLM 推理
|
||||
|
||||
@@ -95,10 +95,10 @@ def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphS
|
||||
return state
|
||||
|
||||
|
||||
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
|
||||
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
# ========== RAG 检索核心逻辑(真正利用已有代码) ==========
|
||||
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
"""
|
||||
RAG 检索核心逻辑(真正利用 rag/tools.py)
|
||||
RAG 检索核心逻辑(真正利用 rag/tools.py) - 异步版本
|
||||
|
||||
Args:
|
||||
state: 主图状态
|
||||
@@ -119,10 +119,10 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
rag_tool = get_rag_tool_from_state(state)
|
||||
|
||||
if rag_tool:
|
||||
# 使用真正的 RAG 工具(来自 rag/tools.py)
|
||||
# 使用真正的 RAG 工具(来自 rag/tools.py)- 异步版本
|
||||
try:
|
||||
# 调用 LangChain Tool 的 invoke 方法
|
||||
rag_context = rag_tool.invoke(retrieval_query)
|
||||
# 直接 await 异步工具的 ainvoke 方法
|
||||
rag_context = await rag_tool.ainvoke(retrieval_query)
|
||||
state.rag_context = rag_context
|
||||
state.rag_docs = [
|
||||
{"source": "rag_retrieval", "content": rag_context}
|
||||
@@ -134,9 +134,9 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
|
||||
elif _GLOBAL_RAG_PIPELINE:
|
||||
# 使用 RAG Pipeline 直接检索
|
||||
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
|
||||
try:
|
||||
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
|
||||
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
|
||||
if documents:
|
||||
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
|
||||
state.rag_context = rag_context
|
||||
@@ -158,7 +158,7 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
|
||||
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
|
||||
|
||||
|
||||
# ========== RAG 检索节点(带超时和重试)==========
|
||||
# ========== RAG 检索节点(带超时和重试) ==========
|
||||
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
|
||||
"""
|
||||
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
|
||||
@@ -196,8 +196,13 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
|
||||
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
|
||||
try:
|
||||
# 执行核心逻辑
|
||||
result = _rag_retrieve_core(state)
|
||||
# 执行核心逻辑 - 异步 await
|
||||
result = await _rag_retrieve_core(state)
|
||||
|
||||
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
|
||||
if result.rag_docs:
|
||||
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
|
||||
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
|
||||
|
||||
# 成功
|
||||
state.debug_info["rag_retrieval"] = {
|
||||
@@ -226,6 +231,15 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
except Exception as e:
|
||||
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
|
||||
|
||||
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
|
||||
state.reasoning_history.append({
|
||||
"step": state.reasoning_step,
|
||||
"action": "rag_retrieve",
|
||||
"confidence": 1.0,
|
||||
"reasoning": "RAG 检索完成",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -255,7 +269,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
|
||||
|
||||
# 指数退避等待
|
||||
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
|
||||
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
|
||||
|
||||
# 所有重试都失败,记录结构化错误
|
||||
error_record = ErrorRecord(
|
||||
|
||||
@@ -364,20 +364,27 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
if "subgraph_completed" in previous_actions or state.final_result:
|
||||
return "llm_call"
|
||||
|
||||
# 检查是否刚刚执行完 rag 或 web search,应该继续推理一次然后去 llm_call
|
||||
# 但为了避免死循环,我们设置一个简单的规则
|
||||
if len(previous_actions) > 3:
|
||||
# 关键修复:如果已经执行过 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
|
||||
# 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL
|
||||
rag_count = previous_actions.count("rag_retrieve")
|
||||
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
|
||||
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
|
||||
return "llm_call"
|
||||
|
||||
|
||||
# 关键修复:限制最多 3 次推理,避免无限循环
|
||||
if len(previous_actions) >= 3:
|
||||
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
|
||||
return "llm_call"
|
||||
|
||||
# 获取推理结果
|
||||
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
|
||||
|
||||
|
||||
if not reasoning_result:
|
||||
return "llm_call"
|
||||
|
||||
|
||||
# 使用 intent.py 提供的路由函数
|
||||
route = get_route_by_reasoning(reasoning_result)
|
||||
|
||||
|
||||
# 映射到我们的节点名称
|
||||
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
|
||||
route_mapping = {
|
||||
@@ -391,7 +398,8 @@ def route_by_reasoning(state: MainGraphState) -> str:
|
||||
"dictionary": "dictionary_subgraph",
|
||||
"news_analysis": "news_analysis_subgraph",
|
||||
}
|
||||
|
||||
|
||||
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
|
||||
return route_mapping.get(route, "llm_call")
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# app/rag_initializer.py
|
||||
from app.rag.tools import create_rag_tool_sync
|
||||
from app.rag.tools import create_rag_tool_sync, create_rag_tool_async
|
||||
from rag_core import create_parent_retriever
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning
|
||||
@@ -16,11 +16,11 @@ async def init_rag_tool(local_llm_creator):
|
||||
embeddings=embeddings
|
||||
)
|
||||
rewrite_llm = local_llm_creator()
|
||||
rag_tool = create_rag_tool_sync(
|
||||
rag_tool = create_rag_tool_async(
|
||||
retriever, rewrite_llm,
|
||||
num_queries=3, rerank_top_n=5
|
||||
)
|
||||
info("✅ RAG 检索工具初始化成功")
|
||||
info("✅ RAG 检索工具初始化成功(异步版本)")
|
||||
return rag_tool
|
||||
except Exception as e:
|
||||
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
|
||||
|
||||
@@ -70,6 +70,63 @@ def create_rag_tool_sync(
|
||||
return search_knowledge_base_sync
|
||||
|
||||
|
||||
def create_rag_tool_async(
    retriever: Optional[BaseRetriever] = None,
    llm: Optional[BaseLanguageModel] = None,
    num_queries: int = 3,
    rerank_top_n: int = 5,
    collection_name: str = "rag_documents",
) -> Callable:
    """Build an asynchronous RAG retrieval tool bound to a single pipeline.

    One ``RAGPipeline`` is constructed from the arguments and captured by the
    returned coroutine tool, so every call of the tool reuses that pipeline.

    NOTE(review): the original notes describe the pipeline as hybrid retrieval
    (dense + BM25 sparse) in parent/child document mode; that behaviour lives
    in ``RAGPipeline`` and is not visible here -- confirm there.

    Args:
        retriever: Base retriever forwarded to ``RAGPipeline``. Optional; the
            original notes say one is created automatically when omitted
            (presumably inside ``RAGPipeline`` -- verify).
        llm: Language model forwarded to ``RAGPipeline``; per the original
            notes it generates the multi-query variants. Optional.
        num_queries: Number of query variants, forwarded to ``RAGPipeline``.
        rerank_top_n: Number of documents finally returned, forwarded to
            ``RAGPipeline``.
        collection_name: Collection name, forwarded to ``RAGPipeline`` and
            echoed in the "nothing found" message.

    Returns:
        The ``@tool``-decorated coroutine ``search_knowledge_base_async``.
    """
    rag_pipeline = RAGPipeline(
        retriever=retriever,
        llm=llm,
        num_queries=num_queries,
        rerank_top_n=rerank_top_n,
        collection_name=collection_name,
    )

    # The inner name and docstring are kept verbatim: the ``@tool`` decorator
    # presumably exposes them as the tool's name and description, which makes
    # them runtime text rather than documentation -- confirm before editing.
    @tool
    async def search_knowledge_base_async(query: str) -> str:
        """
        在知识库中搜索与查询相关的文档片段(异步版本)。

        使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
        检索效果最优。

        Args:
            query: 用户提出的问题或查询字符串

        Returns:
            格式化后的相关文档内容
        """
        try:
            hits = await rag_pipeline.aretrieve(query)
            # Formatting stays inside the try so its failures are also
            # reported as text instead of propagating to the caller.
            return (
                rag_pipeline.format_context(hits)
                if hits
                else f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
            )
        except Exception as exc:
            # Best-effort by design: the error is returned as the tool's
            # string result rather than raised.
            return f"检索过程中发生错误: {exc!s}"

    return search_knowledge_base_async
|
||||
|
||||
|
||||
def create_rag_tool(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
|
||||
Reference in New Issue
Block a user