本地RAG尝试
This commit is contained in:
82
app/agent.py
82
app/agent.py
@@ -6,14 +6,40 @@ AI Agent 服务类 - 支持多模型动态切换
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
from langchain_openai import ChatOpenAI
|
||||
try:
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
HAS_ZHIPUAI = True
|
||||
except ImportError:
|
||||
HAS_ZHIPUAI = False
|
||||
ChatZhipuAI = None
|
||||
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
||||
HAS_OPENAI = True
|
||||
except ImportError:
|
||||
HAS_OPENAI = False
|
||||
ChatOpenAI = None
|
||||
OpenAIEmbeddings = None
|
||||
|
||||
from pydantic import SecretStr
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
HAS_POSTGRES_CHECKPOINT = True
|
||||
except ImportError:
|
||||
HAS_POSTGRES_CHECKPOINT = False
|
||||
AsyncPostgresSaver = None
|
||||
|
||||
# 本地模块
|
||||
from app.graph_builder import GraphBuilder, GraphContext
|
||||
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
try:
|
||||
from app.rag import RAGPipeline
|
||||
from app.rag.tools import RAGTool
|
||||
HAS_RAG = True
|
||||
except ImportError as e:
|
||||
HAS_RAG = False
|
||||
RAGPipeline = None
|
||||
RAGTool = None
|
||||
from app.logger import debug, info, warning, error
|
||||
|
||||
|
||||
@@ -31,9 +57,13 @@ class AIAgentService:
|
||||
"""
|
||||
self.checkpointer = checkpointer
|
||||
self.graphs = {} # 存储不同模型对应的 graph 实例
|
||||
self.rag = None # RAG 检索实例
|
||||
self.rag_tool = None # RAG 工具实例
|
||||
|
||||
def _create_zhipu_llm(self):
|
||||
"""创建智谱在线 LLM"""
|
||||
if not HAS_ZHIPUAI:
|
||||
raise ImportError("智谱AI支持不可用,请安装langchain-community包")
|
||||
api_key = os.getenv("ZHIPUAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("ZHIPUAI_API_KEY not set in environment")
|
||||
@@ -49,6 +79,8 @@ class AIAgentService:
|
||||
|
||||
def _create_deepseek_llm(self):
|
||||
"""创建 DeepSeek LLM(使用 OpenAI 兼容 API)"""
|
||||
if not HAS_OPENAI:
|
||||
raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包")
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("DEEPSEEK_API_KEY not set in environment")
|
||||
@@ -65,6 +97,8 @@ class AIAgentService:
|
||||
|
||||
def _create_local_llm(self):
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
if not HAS_OPENAI:
|
||||
raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包")
|
||||
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
@@ -80,8 +114,39 @@ class AIAgentService:
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
def _create_embeddings(self):
|
||||
"""创建嵌入模型"""
|
||||
if not HAS_OPENAI:
|
||||
raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包")
|
||||
embedding_url = os.getenv(
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
"http://127.0.0.1:8082/v1"
|
||||
)
|
||||
return OpenAIEmbeddings(
|
||||
openai_api_base=embedding_url,
|
||||
openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"),
|
||||
model="text-embedding-ada-002", # 模型名称不重要,兼容即可
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""预编译所有模型的 graph(使用传入的 checkpointer)"""
|
||||
# 先初始化 RAG 检索系统
|
||||
if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
|
||||
try:
|
||||
info("🔄 正在初始化 RAG 检索系统...")
|
||||
embeddings = self._create_embeddings()
|
||||
self.rag = RAGPipeline(embeddings=embeddings)
|
||||
self.rag_tool = RAGTool(self.rag).get_tool()
|
||||
info("✅ RAG 检索系统初始化成功")
|
||||
except Exception as e:
|
||||
warning(f"⚠️ RAG 检索系统初始化失败: {e}")
|
||||
self.rag = None
|
||||
self.rag_tool = None
|
||||
else:
|
||||
info("⏭️ RAG 检索系统不可用,跳过初始化")
|
||||
self.rag = None
|
||||
self.rag_tool = None
|
||||
|
||||
model_configs = {
|
||||
"local": self._create_local_llm, # 本地模型作为第一个
|
||||
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
|
||||
@@ -92,7 +157,16 @@ class AIAgentService:
|
||||
try:
|
||||
info(f"🔄 正在初始化模型 '{model_name}'...")
|
||||
llm = llm_creator()
|
||||
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
|
||||
|
||||
# 构建工具列表:基础工具 + RAG工具(如果可用)
|
||||
tools = AVAILABLE_TOOLS.copy()
|
||||
tools_by_name = TOOLS_BY_NAME.copy()
|
||||
|
||||
if self.rag_tool is not None:
|
||||
tools.append(self.rag_tool)
|
||||
tools_by_name[self.rag_tool.name] = self.rag_tool
|
||||
|
||||
builder = GraphBuilder(llm, tools, tools_by_name).build()
|
||||
graph = builder.compile(checkpointer=self.checkpointer)
|
||||
self.graphs[model_name] = graph
|
||||
info(f"✅ 模型 '{model_name}' 初始化成功")
|
||||
|
||||
Reference in New Issue
Block a user