This commit is contained in:
@@ -18,24 +18,20 @@ from backend.app.logger import debug, info, error
|
||||
def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools: list):
|
||||
"""
|
||||
工厂函数:创建动态 LLM 调用节点(根据 state.current_model 选择模型)
|
||||
|
||||
|
||||
Args:
|
||||
chat_services: 模型名称 -> ChatModel 实例 的字典
|
||||
tools: 工具列表
|
||||
|
||||
tools: 工具列表(llm_call 不使用工具,只负责回答)
|
||||
|
||||
Returns:
|
||||
异步节点函数
|
||||
"""
|
||||
# 预构建所有模型的 tools 绑定(避免每次调用都 bind)
|
||||
bound_models: Dict[str, Any] = {}
|
||||
for name, llm in chat_services.items():
|
||||
if tools:
|
||||
bound_models[name] = llm.bind_tools(tools)
|
||||
else:
|
||||
bound_models[name] = llm
|
||||
|
||||
# 预构建 prompt
|
||||
prompt = create_system_prompt(tools)
|
||||
# llm_call 节点不使用工具,只负责生成回答
|
||||
# 直接使用原始模型,不绑定工具
|
||||
models = chat_services
|
||||
|
||||
# 预构建 prompt(不带工具描述)
|
||||
prompt = create_system_prompt()
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
@@ -70,14 +66,14 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
|
||||
# 动态选择模型
|
||||
model_name = getattr(state, "current_model", "")
|
||||
if not model_name or model_name not in bound_models:
|
||||
if not model_name or model_name not in models:
|
||||
# 回退到第一个可用模型
|
||||
fallback_name = next(iter(bound_models.keys()))
|
||||
fallback_name = next(iter(models.keys()))
|
||||
info(f"[llm_call] 模型 '{model_name}' 不可用,回退到 '{fallback_name}'")
|
||||
model_name = fallback_name
|
||||
|
||||
llm_with_tools = bound_models[model_name]
|
||||
info(f"[llm_call] 使用模型: {model_name}")
|
||||
|
||||
llm = models[model_name]
|
||||
info(f"[llm_call] 使用模型(无工具): {model_name}")
|
||||
|
||||
try:
|
||||
# 添加上下文到消息
|
||||
@@ -103,7 +99,7 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
|
||||
# 恢复为:手动进行 astream,并将所有的 chunk 拼接成最终的 response 返回。
|
||||
# LangGraph 会自动监听这期间产生的所有 token。
|
||||
chain = prompt | llm_with_tools
|
||||
chain = prompt | llm
|
||||
chunks = []
|
||||
info(f"[llm_call] 开始调用 LLM astream...")
|
||||
async for chunk in chain.astream(
|
||||
@@ -115,8 +111,13 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
):
|
||||
chunks.append(chunk)
|
||||
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks}")
|
||||
info(f"[llm_call] LLM astream 完成,共收到 {len(chunks)} 个 chunks,info:{chunks[0].content[:50]}...{chunks[-1].content[:50]}")
|
||||
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0].content
|
||||
for chunk in chunks[1:]:
|
||||
response = response + chunk.content
|
||||
# 将所有 chunk 合并成最终的 AIMessage
|
||||
if chunks:
|
||||
response = chunks[0]
|
||||
@@ -167,9 +168,6 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
debug(f"📋 [LLM统计] 详细用量: {token_usage}")
|
||||
debug("="*80 + "\n")
|
||||
|
||||
# 检查是否有工具调用
|
||||
has_tool_calls = hasattr(response, 'tool_calls') and len(response.tool_calls) > 0
|
||||
|
||||
result = {
|
||||
"messages": [response],
|
||||
"llm_calls": getattr(state, 'llm_calls', 0) + 1,
|
||||
@@ -179,7 +177,6 @@ def create_dynamic_llm_call_node(chat_services: Dict[str, BaseChatModel], tools:
|
||||
"final_result": response.content,
|
||||
"success": True,
|
||||
"current_phase": "done",
|
||||
"has_tool_calls": has_tool_calls,
|
||||
"current_model": model_name # 记录实际使用的模型
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user