本地RAG尝试

This commit is contained in:
2026-04-18 16:31:48 +08:00
parent 6042d4a476
commit 0470afce13
12 changed files with 1587 additions and 4 deletions

View File

@@ -6,14 +6,40 @@ AI Agent 服务类 - 支持多模型动态切换
import os
import json
from dotenv import load_dotenv
from langchain_community.chat_models import ChatZhipuAI
from langchain_openai import ChatOpenAI
try:
from langchain_community.chat_models import ChatZhipuAI
HAS_ZHIPUAI = True
except ImportError:
HAS_ZHIPUAI = False
ChatZhipuAI = None
try:
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
HAS_OPENAI = True
except ImportError:
HAS_OPENAI = False
ChatOpenAI = None
OpenAIEmbeddings = None
from pydantic import SecretStr
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
HAS_POSTGRES_CHECKPOINT = True
except ImportError:
HAS_POSTGRES_CHECKPOINT = False
AsyncPostgresSaver = None
# 本地模块
from app.graph_builder import GraphBuilder, GraphContext
from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
try:
from app.rag import RAGPipeline
from app.rag.tools import RAGTool
HAS_RAG = True
except ImportError as e:
HAS_RAG = False
RAGPipeline = None
RAGTool = None
from app.logger import debug, info, warning, error
@@ -31,9 +57,13 @@ class AIAgentService:
"""
self.checkpointer = checkpointer
self.graphs = {} # 存储不同模型对应的 graph 实例
self.rag = None # RAG 检索实例
self.rag_tool = None # RAG 工具实例
def _create_zhipu_llm(self):
"""创建智谱在线 LLM"""
if not HAS_ZHIPUAI:
raise ImportError("智谱AI支持不可用请安装langchain-community包")
api_key = os.getenv("ZHIPUAI_API_KEY")
if not api_key:
raise ValueError("ZHIPUAI_API_KEY not set in environment")
@@ -49,6 +79,8 @@ class AIAgentService:
def _create_deepseek_llm(self):
"""创建 DeepSeek LLM使用 OpenAI 兼容 API"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
api_key = os.getenv("DEEPSEEK_API_KEY")
if not api_key:
raise ValueError("DEEPSEEK_API_KEY not set in environment")
@@ -65,6 +97,8 @@ class AIAgentService:
def _create_local_llm(self):
"""创建本地 vLLM 服务 LLM"""
if not HAS_OPENAI:
raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
# vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
@@ -80,8 +114,39 @@ class AIAgentService:
streaming=True, # 确保开启流式输出
)
def _create_embeddings(self):
    """Build the embeddings client backed by the local llama.cpp server.

    Returns:
        OpenAIEmbeddings: client pointed at the OpenAI-compatible embedding
        endpoint (URL/key taken from env vars, with local-dev defaults).

    Raises:
        ImportError: when the langchain-openai package is not installed.
    """
    if not HAS_OPENAI:
        raise ImportError("OpenAI兼容支持不可用请安装langchain-openai包")
    # Endpoint and key are overridable via environment; the defaults target
    # a llama.cpp server running locally on port 8082.
    base_url = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
    api_key = os.getenv("LLAMACPP_API_KEY", "token-abc123")
    # The model name is only a placeholder — the local server ignores it as
    # long as the request is OpenAI-API compatible.
    return OpenAIEmbeddings(
        openai_api_base=base_url,
        openai_api_key=api_key,
        model="text-embedding-ada-002",
    )
async def initialize(self):
"""预编译所有模型的 graph使用传入的 checkpointer"""
# 先初始化 RAG 检索系统
if HAS_RAG and RAGPipeline is not None and RAGTool is not None:
try:
info("🔄 正在初始化 RAG 检索系统...")
embeddings = self._create_embeddings()
self.rag = RAGPipeline(embeddings=embeddings)
self.rag_tool = RAGTool(self.rag).get_tool()
info("✅ RAG 检索系统初始化成功")
except Exception as e:
warning(f"⚠️ RAG 检索系统初始化失败: {e}")
self.rag = None
self.rag_tool = None
else:
info("⏭️ RAG 检索系统不可用,跳过初始化")
self.rag = None
self.rag_tool = None
model_configs = {
"local": self._create_local_llm, # 本地模型作为第一个
"deepseek": self._create_deepseek_llm, # DeepSeek 作为中间
@@ -92,7 +157,16 @@ class AIAgentService:
try:
info(f"🔄 正在初始化模型 '{model_name}'...")
llm = llm_creator()
builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build()
# 构建工具列表:基础工具 + RAG工具如果可用
tools = AVAILABLE_TOOLS.copy()
tools_by_name = TOOLS_BY_NAME.copy()
if self.rag_tool is not None:
tools.append(self.rag_tool)
tools_by_name[self.rag_tool.name] = self.rag_tool
builder = GraphBuilder(llm, tools, tools_by_name).build()
graph = builder.compile(checkpointer=self.checkpointer)
self.graphs[model_name] = graph
info(f"✅ 模型 '{model_name}' 初始化成功")