From 0470afce1391cd747bb66bb476b294b3df8ca70f Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Sat, 18 Apr 2026 16:31:48 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=AC=E5=9C=B0RAG=E5=B0=9D=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/agent.py | 82 +++++++- app/rag/README.md | 136 ++++++++++++ app/rag/__init__.py | 22 ++ app/rag/example.py | 232 +++++++++++++++++++++ app/rag/pipeline.py | 341 +++++++++++++++++++++++++++++++ app/rag/query_transform.py | 193 +++++++++++++++++ app/rag/requirements.txt | 23 +++ app/rag/reranker.py | 141 +++++++++++++ app/rag/retriever.py | 144 +++++++++++++ app/rag/tools.py | 230 +++++++++++++++++++++ frontend/components/chat_area.py | 45 ++++ requirement.txt | 2 + 12 files changed, 1587 insertions(+), 4 deletions(-) create mode 100644 app/rag/README.md create mode 100644 app/rag/__init__.py create mode 100644 app/rag/example.py create mode 100644 app/rag/pipeline.py create mode 100644 app/rag/query_transform.py create mode 100644 app/rag/requirements.txt create mode 100644 app/rag/reranker.py create mode 100644 app/rag/retriever.py create mode 100644 app/rag/tools.py diff --git a/app/agent.py b/app/agent.py index 53d5715..fe191bc 100644 --- a/app/agent.py +++ b/app/agent.py @@ -6,14 +6,40 @@ AI Agent 服务类 - 支持多模型动态切换 import os import json from dotenv import load_dotenv -from langchain_community.chat_models import ChatZhipuAI -from langchain_openai import ChatOpenAI +try: + from langchain_community.chat_models import ChatZhipuAI + HAS_ZHIPUAI = True +except ImportError: + HAS_ZHIPUAI = False + ChatZhipuAI = None + +try: + from langchain_openai import ChatOpenAI, OpenAIEmbeddings + HAS_OPENAI = True +except ImportError: + HAS_OPENAI = False + ChatOpenAI = None + OpenAIEmbeddings = None + from pydantic import SecretStr -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +try: + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver + HAS_POSTGRES_CHECKPOINT = True +except ImportError: + HAS_POSTGRES_CHECKPOINT = False + AsyncPostgresSaver = None # 本地模块 from app.graph_builder import GraphBuilder, GraphContext from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +try: + from app.rag import RAGPipeline + from app.rag.tools import RAGTool + HAS_RAG = True +except ImportError as e: + HAS_RAG = False + RAGPipeline = None + RAGTool = None from app.logger import debug, info, warning, error @@ -31,9 +57,13 @@ class AIAgentService: """ self.checkpointer = checkpointer self.graphs = {} # 存储不同模型对应的 graph 实例 + self.rag = None # RAG 检索实例 + self.rag_tool = None # RAG 工具实例 def _create_zhipu_llm(self): """创建智谱在线 LLM""" + if not HAS_ZHIPUAI: + raise ImportError("智谱AI支持不可用,请安装langchain-community包") api_key = os.getenv("ZHIPUAI_API_KEY") if not api_key: raise ValueError("ZHIPUAI_API_KEY not set in environment") @@ -49,6 +79,8 @@ class AIAgentService: def _create_deepseek_llm(self): """创建 DeepSeek LLM(使用 OpenAI 兼容 API)""" + if not HAS_OPENAI: + raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") api_key = os.getenv("DEEPSEEK_API_KEY") if not api_key: raise ValueError("DEEPSEEK_API_KEY not set in environment") @@ -65,6 +97,8 @@ class AIAgentService: def _create_local_llm(self): """创建本地 vLLM 服务 LLM""" + if not HAS_OPENAI: + raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") # vLLM 服务地址:优先从环境变量读取,适配 Docker、FRP 穿透和本地开发 vllm_base_url = os.getenv( "VLLM_BASE_URL", @@ -80,8 +114,39 @@ class AIAgentService: streaming=True, # 确保开启流式输出 ) + def _create_embeddings(self): + """创建嵌入模型""" + if not HAS_OPENAI: + raise ImportError("OpenAI兼容支持不可用,请安装langchain-openai包") + embedding_url = os.getenv( + "LLAMACPP_EMBEDDING_URL", + "http://127.0.0.1:8082/v1" + ) + return OpenAIEmbeddings( + openai_api_base=embedding_url, + openai_api_key=os.getenv("LLAMACPP_API_KEY", "token-abc123"), + model="text-embedding-ada-002", # 模型名称不重要,兼容即可 + ) + async def initialize(self): """预编译所有模型的 graph(使用传入的 checkpointer)""" + # 先初始化 RAG 检索系统 + if HAS_RAG and RAGPipeline is not None and RAGTool is not None: + try: + info("🔄 正在初始化 RAG 检索系统...") + embeddings = self._create_embeddings() + self.rag = RAGPipeline(embeddings=embeddings) + self.rag_tool = RAGTool(self.rag).get_tool() + info("✅ RAG 检索系统初始化成功") + except Exception as e: + warning(f"⚠️ RAG 检索系统初始化失败: {e}") + self.rag = None + self.rag_tool = None + else: + info("⏭️ RAG 检索系统不可用,跳过初始化") + self.rag = None + self.rag_tool = None + model_configs = { "local": self._create_local_llm, # 本地模型作为第一个 "deepseek": self._create_deepseek_llm, # DeepSeek 作为中间 @@ -92,7 +157,16 @@ class AIAgentService: try: info(f"🔄 正在初始化模型 '{model_name}'...") llm = llm_creator() - builder = GraphBuilder(llm, AVAILABLE_TOOLS, TOOLS_BY_NAME).build() + + # 构建工具列表:基础工具 + RAG工具(如果可用) + tools = AVAILABLE_TOOLS.copy() + tools_by_name = TOOLS_BY_NAME.copy() + + if self.rag_tool is not None: + tools.append(self.rag_tool) + tools_by_name[self.rag_tool.name] = self.rag_tool + + builder = GraphBuilder(llm, tools, tools_by_name).build() graph = builder.compile(checkpointer=self.checkpointer) self.graphs[model_name] = graph info(f"✅ 模型 '{model_name}' 初始化成功") diff --git a/app/rag/README.md b/app/rag/README.md new file mode 100644 index 0000000..19f27f4 --- /dev/null +++ b/app/rag/README.md @@ -0,0 +1,136 @@ +# 在线 RAG 检索与生成系统 (Online RAG Retriever) + +该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。 + +## 📊 RAG-Fusion & 混合检索流水线示意图 + +```mermaid +graph TD + User((用户提问)) --> A[LLM 查询改写生成器] + + subgraph RAG-Fusion 核心流程 + A -->|改写为问题 1| B1[查询 1] + A -->|改写为问题 2| B2[查询 2] + A -->|原问题| B3[原始查询] + + B1 & B2 & B3 --> C[混合检索器 Hybrid Retriever
Dense Vector + BM25 Sparse] + + C --> D[多路召回结果合集 N=60条] + + D --> E{RRF 倒数排名融合去重} + end + + E -->|筛选出前 20 条| F[Cross-Encoder 重排器 Reranker] + + F -->|精细打分排序 Top 5| G[最终纯净上下文 Context] + + G --> H[将 Context 与原问题拼接输入大模型] + H --> I((LLM 生成最终回答)) +``` + +--- + +## 🎯 演进路线与算法详解 (Roadmap) + +### Level 1: 基础向量搜索 (Basic Similarity Search) +- **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。 +- **优缺点**: 速度极快。但只能捕捉“语义相似”,如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生“幻觉”匹配)。 + +### Level 2: 混合检索与重排序 (Hybrid Search + Reranker) +混合检索旨在结合向量的“语义泛化”与关键词的“精准匹配”,随后利用重排序模型过滤噪声。 + +**1. 基础召回 (混合检索)** +- **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。 +- **实现指南**: 使用 `langchain_qdrant` 包中的 `Qdrant` 类连接数据库。通过调用 `Qdrant.from_existing_collection(...)` 实例化向量库,并使用 `.as_retriever(search_kwargs={"k": 20})` 方法生成基础检索器。Qdrant 底层会自动处理双路召回。 + +**2. 二次精排 (Cross-Encoder)** +- **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将“用户问题 + 检索到的单例文档”拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。 +- **实现指南**: + - 使用 `sentence-transformers` 库加载本地轻量级重排模型(如 `BAAI/bge-reranker-base`)。 + - 引入 `langchain.retrievers.document_compressors` 包中的 `CrossEncoderReranker` 类包装该模型,设置参数 `top_n=5`。 + - 最后,使用 `langchain.retrievers` 包中的 `ContextualCompressionRetriever` 类,将 `base_compressor` (重排器) 和 `base_retriever` (基础检索器) 组合。 + - **如何调用**: 业务逻辑中直接对组合后的检索器调用 `.invoke(query)` 方法,即可一键完成“大范围召回 20 条 -> 逐一打分精排选 5 条”的去噪流水线。 + +### Level 3: RAG-Fusion (多路改写与倒数排名融合) +RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。 + +**1. 多路查询改写** +- **核心原理**: 克服用户初始提问词不达意或视角受限的问题。 +- **实现指南**: 导入 `langchain.retrievers.multi_query` 包中的 `MultiQueryRetriever` 类。需向其提供一个已实例化的 LLM 对象(如基于 `ChatOpenAI` 封装的本地 VLLM 模型)。系统在底层会自动 Prompt 模型,将原始 `query` 转化为包含 3-5 个不同表述的查询列表。 + +**2. 倒数排名融合 (RRF)** +- **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,有效避免某一极端检索结果主导全局。 +- **实现指南**: + - 针对每个改写后的查询 $q$,分别调用精排检索器的 `.invoke(q)` 获取文档列表。 + - 使用 `langchain.retrievers` 中的 `EnsembleRetriever` 类(原生支持 RRF),或在代码中遍历收集到的 `Document` 对象,基于其排名 `rank` 累加得分,最终通过 Python 的 `set` 去重并提取 `doc.page_content`。 + +### Level 4: Agentic RAG / Self-RAG (智能体与自我反思) +- **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:“这是闲聊?还是需要查知识库?”。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。 +- **实现指南**: 请参考下方的**与现有系统整合调用**章节。 + +- **示意图**: + ```mermaid + sequenceDiagram + participant User + participant LangGraph Agent + participant RAG_Tool + participant Qdrant + + User->>LangGraph Agent: "公司报销流程是什么?" + LangGraph Agent->>LangGraph Agent: 思考: 这是一个内部规章问题,需要查资料 + LangGraph Agent->>RAG_Tool: ToolCall(search_knowledge_base, "公司报销流程") + RAG_Tool->>Qdrant: RAG-Fusion & 混合检索 + Qdrant-->>RAG_Tool: 原始分块 + RAG_Tool->>RAG_Tool: Cross-Encoder 重排过滤 + RAG_Tool-->>LangGraph Agent: 返回最相关的5条报销规定 + LangGraph Agent->>LangGraph Agent: 思考: 资料充分,开始撰写回答 + LangGraph Agent-->>User: "根据知识库规定,报销流程分为以下3步..." + ``` + +--- + +## 📦 所需依赖与安装 + +除了基础的 LangChain 包外,在线检索模块为了支持重排和稀疏检索,还需要安装: + +```bash +# 用于 Cross-Encoder 重排序模型 (如 BAAI/bge-reranker-base) +pip install sentence-transformers + +# 用于 BM25 关键词混合检索 +pip install rank_bm25 + +# 基础框架 +pip install langchain langchain-core langchain-openai langchain-qdrant +``` + +--- + +## 📂 架构与文件结构设计 + +在 `app/rag/` 目录下,需创建以下文件来模块化上述功能: + +```text +app/rag/ +├── __init__.py +├── retriever.py # 负责 Qdrant 的基础召回与 ContextualCompressionRetriever +├── reranker.py # 负责加载 sentence-transformers 交叉编码器 +├── query_transform.py # 负责基于 MultiQueryRetriever 的改写逻辑 +├── pipeline.py # 组合上述组件,暴露出核心的 retrieve() 方法 +└── tools.py # 将 Pipeline 包装成 LangChain Tool 供 Agent 调用 +``` + +--- + +## � 与现有系统整合调用 (Agentic RAG 实现) + +基于目前 LangGraph 系统的架构,我们将摒弃将代码堆砌在一起的旧方式,而是利用 **LangChain Tools** 的特性将 RAG 优雅地注入系统: + +1. **封装检索工具 (Tool)**: + 从 `langchain.tools` 导入 `@tool` 装饰器。定义一个名为 `search_knowledge_base(query: str)` 的函数。在函数内部,实例化并调用我们在 `pipeline.py` 中写好的多路召回与重排逻辑。 +2. **模型绑定 (Bind)**: + 在 `app/agent.py` 或 `app/nodes/tool_call.py` 中,将这个工具引入,并通过 `llm.bind_tools([search_knowledge_base])` 绑定到现有的本地大模型实例上。 +3. **状态机路由 (Graph Routing)**: + 你的 LangGraph 状态机会像处理普通对话一样自动接管:当模型判断需要调用查阅规章制度或专业资料时,它会输出 `ToolCall` 消息,流转到 `tool_node` 执行上述的 RAG 检索逻辑并返回上下文。 + +这让你无需修改任何前端 Streamlit 流式代码,就能平滑升级为具备超级知识库检索能力的智能体 (Agent)! diff --git a/app/rag/__init__.py b/app/rag/__init__.py new file mode 100644 index 0000000..0d44285 --- /dev/null +++ b/app/rag/__init__.py @@ -0,0 +1,22 @@ +""" +在线 RAG 检索与生成系统 + +提供高级RAG检索功能,支持混合检索、重排序、RAG-Fusion和多路查询改写。 +""" + +from .pipeline import RAGPipeline +from .retriever import create_hybrid_retriever, create_base_retriever +from .reranker import CrossEncoderReranker +from .query_transform import MultiQueryTransformer +from .tools import search_knowledge_base_tool + +__all__ = [ + "RAGPipeline", + "create_hybrid_retriever", + "create_base_retriever", + "CrossEncoderReranker", + "MultiQueryTransformer", + "search_knowledge_base_tool", +] + +__version__ = "0.1.0" \ No newline at end of file diff --git a/app/rag/example.py b/app/rag/example.py new file mode 100644 index 0000000..53d82fe --- /dev/null +++ b/app/rag/example.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +RAG 系统使用示例 + +演示如何使用 app/rag 模块进行知识检索。 +""" + +import sys +import os + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from langchain_openai import OpenAIEmbeddings +from langchain_community.llms import VLLMOpenAI + + +def setup_environment(): + """设置环境变量""" + # 设置 Qdrant 连接信息(根据实际情况修改) + os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333") + # 如果需要 API 密钥,请设置 QDRANT_API_KEY + + print("环境变量已设置") + print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}") + + +def demonstrate_basic_rag(): + """演示基础 RAG 功能""" + print("\n" + "="*60) + print("演示: 基础 RAG 检索 (Level 1)") + print("="*60) + + # 创建嵌入模型(使用 OpenAI 兼容的本地模型) + embeddings = OpenAIEmbeddings( + openai_api_base="http://localhost:8000/v1", # 本地 VLLM 服务 + openai_api_key="no-key-needed", + model="text-embedding-ada-002", # 假设的模型名称 + ) + + # 创建 RAG 流水线 + from app.rag import RAGPipeline, RAGConfig, RAGLevel + + config = RAGConfig( + collection_name="documents", # 你的集合名称 + rag_level=RAGLevel.BASIC, + verbose=True, + ) + + pipeline = RAGPipeline( + embeddings=embeddings, + config=config, + ) + + # 示例查询 + query = "公司报销流程是什么?" + print(f"\n查询: {query}") + + try: + result = pipeline.retrieve(query) + print(f"找到 {len(result.documents)} 个相关文档") + + # 格式化上下文 + context = pipeline.format_context(result.documents) + print(f"\n上下文预览:\n{context[:500]}...") + + except Exception as e: + print(f"检索失败: {e}") + print("请确保 Qdrant 服务正常运行且集合存在") + + +def demonstrate_hybrid_rag(): + """演示混合 RAG 功能""" + print("\n" + "="*60) + print("演示: 混合 RAG 检索 (Level 2)") + print("="*60) + + embeddings = OpenAIEmbeddings( + openai_api_base="http://localhost:8000/v1", + openai_api_key="no-key-needed", + model="text-embedding-ada-002", + ) + + from app.rag import RAGPipeline, RAGConfig, RAGLevel + + config = RAGConfig( + collection_name="documents", + rag_level=RAGLevel.HYBRID, + dense_k=10, + sparse_k=10, + rerank_top_n=5, + verbose=True, + ) + + pipeline = RAGPipeline( + embeddings=embeddings, + config=config, + ) + + query = "如何申请年假?" + print(f"\n查询: {query}") + + try: + result = pipeline.retrieve(query) + print(f"找到 {len(result.documents)} 个重排序后的文档") + + except Exception as e: + print(f"检索失败: {e}") + + +def demonstrate_rag_fusion(): + """演示 RAG-Fusion 功能""" + print("\n" + "="*60) + print("演示: RAG-Fusion (Level 3)") + print("="*60) + + embeddings = OpenAIEmbeddings( + openai_api_base="http://localhost:8000/v1", + openai_api_key="no-key-needed", + model="text-embedding-ada-002", + ) + + # 创建语言模型用于查询改写 + llm = VLLMOpenAI( + openai_api_base="http://localhost:8000/v1", + openai_api_key="no-key-needed", + model_name="Qwen2.5-7B-Instruct", # 你的本地模型 + temperature=0.3, + max_tokens=512, + ) + + from app.rag import RAGPipeline, RAGConfig, RAGLevel + + config = RAGConfig( + collection_name="documents", + rag_level=RAGLevel.FUSION, + num_queries=3, + verbose=True, + ) + + pipeline = RAGPipeline( + embeddings=embeddings, + llm=llm, + config=config, + ) + + query = "项目上线需要哪些审批?" + print(f"\n查询: {query}") + + try: + result = pipeline.retrieve(query) + print(f"找到 {len(result.documents)} 个文档 (经过多路查询改写和重排序)") + + except Exception as e: + print(f"检索失败: {e}") + + +def demonstrate_agentic_rag(): + """演示 Agentic RAG 功能""" + print("\n" + "="*60) + print("演示: Agentic RAG (Level 4)") + print("="*60) + + embeddings = OpenAIEmbeddings( + openai_api_base="http://localhost:8000/v1", + openai_api_key="no-key-needed", + model="text-embedding-ada-002", + ) + + llm = VLLMOpenAI( + openai_api_base="http://localhost:8000/v1", + openai_api_key="no-key-needed", + model_name="Qwen2.5-7B-Instruct", + temperature=0.3, + max_tokens=512, + ) + + from app.rag import create_agentic_rag_pipeline + + try: + # 创建 Agentic RAG 流水线 + agentic_rag = create_agentic_rag_pipeline( + embeddings=embeddings, + agent_llm=llm, + config={ + "collection_name": "documents", + "verbose": True, + }, + ) + + print("Agentic RAG 流水线创建成功!") + print(f"- 绑定的模型: {agentic_rag['llm']}") + print(f"- RAG 工具: {agentic_rag['tool'].name}") + + # 演示工具调用 + print("\n工具调用示例:") + response = agentic_rag["tool"].invoke({"query": "员工福利有哪些?"}) + print(f"工具响应预览: {response[:200]}...") + + except Exception as e: + print(f"创建 Agentic RAG 失败: {e}") + import traceback + traceback.print_exc() + + +def main(): + """主函数""" + print("RAG 系统演示") + print("="*60) + + # 设置环境 + setup_environment() + + # 演示各级功能 + demonstrate_basic_rag() + demonstrate_hybrid_rag() + demonstrate_rag_fusion() + demonstrate_agentic_rag() + + print("\n" + "="*60) + print("演示完成!") + print("="*60) + + print("\n使用说明:") + print("1. 确保 Qdrant 服务运行且集合已创建") + print("2. 根据需要修改 embeddings 和 llm 配置") + print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base_tool") + print("4. 将工具绑定到你的 Agent 模型") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py new file mode 100644 index 0000000..0e6d0bc --- /dev/null +++ b/app/rag/pipeline.py @@ -0,0 +1,341 @@ +""" +RAG 检索流水线 + +组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。 +""" + +import time +from typing import List, Dict, Any, Optional, Union +from dataclasses import dataclass, field +from enum import Enum +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLanguageModel + +from .retriever import ( + create_base_retriever, + create_hybrid_retriever, + create_ensemble_retriever, + create_qdrant_client, +) +from .reranker import CrossEncoderReranker +from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline + + +class RAGLevel(Enum): + """RAG 功能级别""" + BASIC = 1 # 基础向量搜索 + HYBRID = 2 # 混合检索 + 重排序 + FUSION = 3 # RAG-Fusion + AGENTIC = 4 # Agentic RAG + + +@dataclass +class RAGConfig: + """RAG 配置""" + # Qdrant 配置 + collection_name: str = "documents" + qdrant_url: Optional[str] = None + qdrant_api_key: Optional[str] = None + + # 检索配置 + rag_level: RAGLevel = RAGLevel.FUSION + dense_k: int = 10 # 向量检索数量 + sparse_k: int = 10 # BM25 检索数量 + total_k: int = 20 # 总检索数量 + rerank_top_n: int = 5 # 重排序返回数量 + + # 查询改写配置 + num_queries: int = 3 # RAG-Fusion 查询数量 + + # 模型配置 + reranker_model: str = "BAAI/bge-reranker-base" + device: Optional[str] = None + + # 性能配置 + enable_cache: bool = True + verbose: bool = True + + +@dataclass +class RetrievalResult: + """检索结果""" + documents: List[Document] + query_time: float + level: RAGLevel + metadata: Dict[str, Any] = field(default_factory=dict) + + +class RAGPipeline: + """ + RAG 检索流水线 + + 支持从 Level 1 到 Level 4 的所有功能。 + """ + + def __init__( + self, + embeddings: Embeddings, + llm: Optional[BaseLanguageModel] = None, + config: Optional[RAGConfig] = None, + ): + """ + 初始化 RAG 流水线 + + Args: + embeddings: 嵌入模型 + llm: 语言模型(用于查询改写,Level 3+ 需要) + config: 配置 + """ + self.embeddings = embeddings + self.llm = llm + self.config = config or RAGConfig() + + # 初始化组件 + self._client = None + self._reranker = None + self._query_transformer = None + self._retriever = None + + # 缓存 + self._cache = {} + + def _get_client(self): + """获取 Qdrant 客户端""" + if self._client is None: + self._client = create_qdrant_client( + url=self.config.qdrant_url, + api_key=self.config.qdrant_api_key, + ) + return self._client + + def _get_reranker(self): + """获取重排序器""" + if self._reranker is None: + self._reranker = CrossEncoderReranker( + model_name=self.config.reranker_model, + top_n=self.config.rerank_top_n, + device=self.config.device, + ) + return self._reranker + + def _get_query_transformer(self): + """获取查询改写器""" + if self._query_transformer is None and self.llm is not None: + self._query_transformer = MultiQueryTransformer( + llm=self.llm, + num_queries=self.config.num_queries, + ) + return self._query_transformer + + def _create_basic_retriever(self): + """创建基础检索器(Level 1)""" + return create_base_retriever( + collection_name=self.config.collection_name, + embeddings=self.embeddings, + search_kwargs={"k": self.config.total_k}, + client=self._get_client(), + ) + + def _create_hybrid_retriever(self): + """创建混合检索器(Level 2)""" + base_retriever = create_hybrid_retriever( + collection_name=self.config.collection_name, + embeddings=self.embeddings, + dense_k=self.config.dense_k, + sparse_k=self.config.sparse_k, + client=self._get_client(), + ) + + # 应用重排序 + reranker = self._get_reranker() + return reranker.create_contextual_compression_retriever(base_retriever) + + def _create_fusion_retriever(self): + """创建 RAG-Fusion 检索器(Level 3)""" + if self.llm is None: + raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写") + + # 创建基础混合检索器 + base_retriever = create_hybrid_retriever( + collection_name=self.config.collection_name, + embeddings=self.embeddings, + dense_k=self.config.dense_k, + sparse_k=self.config.sparse_k, + client=self._get_client(), + ) + + # 创建 RAG-Fusion 流水线 + reranker = self._get_reranker() + return create_rag_fusion_pipeline( + base_retriever=base_retriever, + llm=self.llm, + reranker=reranker, + num_queries=self.config.num_queries, + ) + + def _get_retriever(self): + """根据配置级别获取检索器""" + if self._retriever is None: + if self.config.rag_level == RAGLevel.BASIC: + self._retriever = self._create_basic_retriever() + elif self.config.rag_level == RAGLevel.HYBRID: + self._retriever = self._create_hybrid_retriever() + elif self.config.rag_level == RAGLevel.FUSION: + self._retriever = self._create_fusion_retriever() + elif self.config.rag_level == RAGLevel.AGENTIC: + # Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装 + self._retriever = self._create_fusion_retriever() + else: + raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}") + + return self._retriever + + def retrieve( + self, + query: str, + use_cache: Optional[bool] = None, + **kwargs, + ) -> RetrievalResult: + """ + 执行检索 + + Args: + query: 查询文本 + use_cache: 是否使用缓存 + **kwargs: 额外参数 + + Returns: + 检索结果 + """ + start_time = time.time() + + # 检查缓存 + if use_cache is None: + use_cache = self.config.enable_cache + + cache_key = f"{query}:{self.config.rag_level.value}" + if use_cache and cache_key in self._cache: + if self.config.verbose: + print(f"使用缓存结果: {query}") + return self._cache[cache_key] + + # 获取检索器并执行检索 + retriever = self._get_retriever() + documents = retriever.invoke(query, **kwargs) + + # 计算查询时间 + query_time = time.time() - start_time + + # 创建结果 + result = RetrievalResult( + documents=documents, + query_time=query_time, + level=self.config.rag_level, + metadata={ + "query": query, + "collection": self.config.collection_name, + "doc_count": len(documents), + }, + ) + + # 缓存结果 + if use_cache: + self._cache[cache_key] = result + + if self.config.verbose: + print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s") + + return result + + def format_context( + self, + documents: List[Document], + max_length: Optional[int] = None, + ) -> str: + """ + 格式化检索到的文档为上下文文本 + + Args: + documents: 文档列表 + max_length: 最大长度(字符数) + + Returns: + 格式化后的上下文文本 + """ + context_parts = [] + total_length = 0 + + for i, doc in enumerate(documents): + # 提取内容和元数据 + content = doc.page_content.strip() + metadata = doc.metadata + + # 格式化文档 + doc_text = f"[文档 {i+1}]\n" + if metadata.get("source"): + doc_text += f"来源: {metadata['source']}\n" + if metadata.get("page"): + doc_text += f"页码: {metadata['page']}\n" + + doc_text += f"内容: {content}\n\n" + + # 检查长度限制 + if max_length is not None: + if total_length + len(doc_text) > max_length: + # 如果添加这个文档会超限,则截断并添加说明 + remaining = max_length - total_length + if remaining > 100: # 至少保留100字符 + doc_text = doc_text[:remaining] + "...\n\n[内容已截断]" + context_parts.append(doc_text) + break + else: + break + + context_parts.append(doc_text) + total_length += len(doc_text) + + return "".join(context_parts).strip() + + def clear_cache(self): + """清空缓存""" + self._cache.clear() + + @classmethod + def create_from_config( + cls, + embeddings: Embeddings, + llm: Optional[BaseLanguageModel] = None, + config_dict: Optional[Dict[str, Any]] = None, + ) -> "RAGPipeline": + """ + 从配置字典创建流水线 + + Args: + embeddings: 嵌入模型 + llm: 语言模型 + config_dict: 配置字典 + + Returns: + RAGPipeline 实例 + """ + config_dict = config_dict or {} + + # 创建配置对象 + config = RAGConfig( + collection_name=config_dict.get("collection_name", "documents"), + qdrant_url=config_dict.get("qdrant_url"), + qdrant_api_key=config_dict.get("qdrant_api_key"), + rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)), + dense_k=config_dict.get("dense_k", 10), + sparse_k=config_dict.get("sparse_k", 10), + total_k=config_dict.get("total_k", 20), + rerank_top_n=config_dict.get("rerank_top_n", 5), + num_queries=config_dict.get("num_queries", 3), + reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"), + device=config_dict.get("device"), + enable_cache=config_dict.get("enable_cache", True), + verbose=config_dict.get("verbose", True), + ) + + return cls(embeddings=embeddings, llm=llm, config=config) \ No newline at end of file diff --git a/app/rag/query_transform.py b/app/rag/query_transform.py new file mode 100644 index 0000000..652e6f4 --- /dev/null +++ b/app/rag/query_transform.py @@ -0,0 +1,193 @@ +""" +查询改写器 + +基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。 +""" + +from typing import List, Optional, Any +from langchain.retrievers.multi_query import MultiQueryRetriever +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts import PromptTemplate +from langchain_core.output_parsers import StrOutputParser + + +class MultiQueryTransformer: + """ + 多路查询改写器 + + 将单个查询改写成多个相关查询,用于 RAG-Fusion。 + """ + + def __init__( + self, + llm: BaseLanguageModel, + num_queries: int = 3, + prompt_template: Optional[str] = None, + ): + """ + 初始化查询改写器 + + Args: + llm: 语言模型实例 + num_queries: 生成的查询数量 + prompt_template: 提示词模板 + """ + self.llm = llm + self.num_queries = num_queries + + # 默认提示词模板 + self.prompt_template = prompt_template or """ + 你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。 + 这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。 + + 原始问题: {question} + + 请生成 {num_queries} 个不同版本的查询,每个版本一行。 + 确保每个版本都是独立、完整的查询语句。 + + 生成 {num_queries} 个查询: + """ + + def transform_query(self, query: str) -> List[str]: + """ + 将单个查询改写成多个查询 + + Args: + query: 原始查询 + + Returns: + 改写后的查询列表 + """ + prompt = PromptTemplate.from_template(self.prompt_template) + + chain = prompt | self.llm | StrOutputParser() + + response = chain.invoke({ + "question": query, + "num_queries": self.num_queries, + }) + + # 解析响应,每行一个查询 + queries = [ + q.strip() + for q in response.strip().split('\n') + if q.strip() + ] + + # 确保数量正确,如果不够则添加原始查询 + if len(queries) < self.num_queries: + queries.extend([query] * (self.num_queries - len(queries))) + elif len(queries) > self.num_queries: + queries = queries[:self.num_queries] + + # 确保包含原始查询 + if query not in queries: + queries = [query] + queries[:self.num_queries-1] + + return queries + + def create_multi_query_retriever( + self, + base_retriever: Any, + include_original: bool = True, + ) -> MultiQueryRetriever: + """ + 创建多路查询检索器 + + Args: + base_retriever: 基础检索器 + include_original: 是否包含原始查询 + + Returns: + MultiQueryRetriever 实例 + """ + retriever = MultiQueryRetriever.from_llm( + retriever=base_retriever, + llm=self.llm, + include_original=include_original, + ) + + # 设置生成的查询数量 + retriever.llm_chain.prompt = PromptTemplate.from_template( + "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" + "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" + "原始问题: {question}\n\n" + "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" + "确保每个版本都是独立、完整的查询语句。\n\n" + "生成 {num_queries} 个查询:" + ) + + # 修改调用参数以包含 num_queries + original_invoke = retriever.llm_chain.invoke + + def new_invoke(input_dict): + input_dict["num_queries"] = self.num_queries + return original_invoke(input_dict) + + retriever.llm_chain.invoke = new_invoke + + return retriever + + @classmethod + def create_from_config( + cls, + llm: BaseLanguageModel, + config: Optional[dict] = None, + ) -> "MultiQueryTransformer": + """ + 从配置创建查询改写器 + + Args: + llm: 语言模型实例 + config: 配置字典 + + Returns: + MultiQueryTransformer 实例 + """ + config = config or {} + return cls( + llm=llm, + num_queries=config.get("num_queries", 3), + prompt_template=config.get("prompt_template", None), + ) + + +def create_rag_fusion_pipeline( + base_retriever: Any, + llm: BaseLanguageModel, + reranker: Optional[Any] = None, + num_queries: int = 3, +) -> Any: + """ + 创建完整的 RAG-Fusion 流水线 + + Args: + base_retriever: 基础检索器 + llm: 语言模型(用于查询改写) + reranker: 重排序器(可选) + num_queries: 查询改写数量 + + Returns: + 检索器实例 + """ + # 创建多路查询改写器 + query_transformer = MultiQueryTransformer( + llm=llm, + num_queries=num_queries, + ) + + # 创建多路查询检索器 + multi_query_retriever = query_transformer.create_multi_query_retriever( + base_retriever=base_retriever, + include_original=True, + ) + + # 如果提供了重排序器,则应用重排序 + if reranker is not None: + from langchain.retrievers import ContextualCompressionRetriever + return ContextualCompressionRetriever( + base_compressor=reranker, + base_retriever=multi_query_retriever, + ) + + return multi_query_retriever \ No newline at end of file diff --git a/app/rag/requirements.txt b/app/rag/requirements.txt new file mode 100644 index 0000000..2e5c4a0 --- /dev/null +++ b/app/rag/requirements.txt @@ -0,0 +1,23 @@ +# RAG 系统依赖 +# 基础框架 +langchain>=0.1.0 +langchain-core>=0.1.0 +langchain-openai>=0.0.1 +langchain-qdrant>=0.1.0 + +# 用于 Cross-Encoder 重排序模型 +sentence-transformers>=2.2.0 + +# 用于 BM25 关键词混合检索 +rank-bm25>=0.2.2 + +# Qdrant 客户端 +qdrant-client>=1.6.0 + +# 可选的本地模型支持 +# vllm>=0.5.0 # 如果需要本地模型推理 +# transformers>=4.35.0 # 如果需要其他模型支持 + +# 开发依赖(测试用) +pytest>=7.0.0 +pytest-asyncio>=0.21.0 \ No newline at end of file diff --git a/app/rag/reranker.py b/app/rag/reranker.py new file mode 100644 index 0000000..4c7229f --- /dev/null +++ b/app/rag/reranker.py @@ -0,0 +1,141 @@ +""" +Cross-Encoder 重排序器 + +使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。 +""" + +import os +from typing import List, Dict, Any, Optional +from langchain.retrievers.document_compressors import CrossEncoderReranker +from langchain_core.documents import Document +from sentence_transformers import CrossEncoder + + +class CrossEncoderReranker: + """ + Cross-Encoder 重排序器包装类 + + 支持 BAAI/bge-reranker-base 等模型。 + """ + + def __init__( + self, + model_name: str = "BAAI/bge-reranker-base", + top_n: int = 5, + device: Optional[str] = None, + cache_folder: Optional[str] = None, + ): + """ + 初始化重排序器 + + Args: + model_name: 模型名称或路径 + top_n: 返回的顶部文档数量 + device: 设备(cpu/cuda),如果为 None 则自动选择 + cache_folder: 模型缓存目录 + """ + self.model_name = model_name + self.top_n = top_n + self.device = device + self.cache_folder = cache_folder or os.path.join( + os.path.expanduser("~"), ".cache", "sentence_transformers" + ) + + # 延迟加载模型 + self._model = None + self._langchain_reranker = None + + def _load_model(self): + """加载交叉编码器模型""" + if self._model is None: + try: + self._model = CrossEncoder( + self.model_name, + device=self.device, + cache_folder=self.cache_folder, + ) + except Exception as e: + # 如果指定模型加载失败,尝试备用模型 + print(f"加载模型 {self.model_name} 失败: {e}") + print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...") + self._model = CrossEncoder( + "BAAI/bge-reranker-v2-m3", + device=self.device, + cache_folder=self.cache_folder, + ) + + def _create_langchain_reranker(self): + """创建 LangChain 重排序器""" + if self._langchain_reranker is None: + self._load_model() + self._langchain_reranker = CrossEncoderReranker( + model=self._model, + top_n=self.top_n, + ) + + def rerank( + self, + query: str, + documents: List[Document], + ) -> List[Document]: + """ + 对文档进行重排序 + + Args: + query: 查询文本 + documents: 待排序文档列表 + + Returns: + 重排序后的文档列表 + """ + self._create_langchain_reranker() + return self._langchain_reranker.compress_documents( + documents=documents, + query=query, + ) + + def create_contextual_compression_retriever( + self, + base_retriever: Any, + ) -> Any: + """ + 创建上下文压缩检索器 + + Args: + base_retriever: 基础检索器 + + Returns: + 上下文压缩检索器 + """ + from langchain.retrievers import ContextualCompressionRetriever + + self._create_langchain_reranker() + + compression_retriever = ContextualCompressionRetriever( + base_compressor=self._langchain_reranker, + base_retriever=base_retriever, + ) + + return compression_retriever + + @classmethod + def create_from_config( + cls, + config: Optional[Dict[str, Any]] = None, + ) -> "CrossEncoderReranker": + """ + 从配置创建重排序器 + + Args: + config: 配置字典,包含 model_name, top_n, device 等 + + Returns: + CrossEncoderReranker 实例 + """ + config = config or {} + return cls( + model_name=config.get("model_name", "BAAI/bge-reranker-base"), + top_n=config.get("top_n", 5), + device=config.get("device", None), + cache_folder=config.get("cache_folder", None), + ) \ No newline at end of file diff --git a/app/rag/retriever.py b/app/rag/retriever.py new file mode 100644 index 0000000..19a7511 --- /dev/null +++ b/app/rag/retriever.py @@ -0,0 +1,144 @@ +""" +Qdrant 向量检索器 + +提供基础向量检索、混合检索(Dense + BM25)功能。 +""" + +import os +from typing import List, Dict, Any, Optional +from langchain_qdrant import Qdrant +from langchain.embeddings.base import Embeddings +from langchain.retrievers import ContextualCompressionRetriever +from langchain.retrievers.document_compressors import DocumentCompressorPipeline +from langchain.retrievers import EnsembleRetriever +from qdrant_client import QdrantClient +from qdrant_client.http import models + + +def create_qdrant_client( + url: Optional[str] = None, + api_key: Optional[str] = None, +) -> QdrantClient: + """ + 创建 Qdrant 客户端 + + Args: + url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取 + api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取 + + Returns: + QdrantClient 实例 + """ + url = url or os.getenv("QDRANT_URL", "http://localhost:6333") + api_key = api_key or os.getenv("QDRANT_API_KEY") + + client_args = {"url": url} + if api_key: + client_args["api_key"] = api_key + + return QdrantClient(**client_args) + + +def create_base_retriever( + collection_name: str, + embeddings: Embeddings, + search_kwargs: Optional[Dict[str, Any]] = None, + client: Optional[QdrantClient] = None, +) -> Qdrant: + """ + 创建基础向量检索器 + + Args: + collection_name: Qdrant 集合名称 + embeddings: 嵌入模型 + search_kwargs: 搜索参数,默认 {"k": 20} + client: Qdrant 客户端,如果为 None 则自动创建 + + Returns: + Qdrant 检索器实例 + """ + if client is None: + client = create_qdrant_client() + + search_kwargs = search_kwargs or {"k": 20} + + # 创建 Qdrant 检索器 + retriever = Qdrant.from_existing_collection( + embedding=embeddings, + collection_name=collection_name, + client=client, + content_payload_key="content", # 假设存储的文本字段名为 "content" + metadata_payload_key="metadata", # 元数据字段名 + ) + + return retriever.as_retriever(search_kwargs=search_kwargs) + + +def create_hybrid_retriever( + collection_name: str, + embeddings: Embeddings, + dense_k: int = 10, + sparse_k: int = 10, + client: Optional[QdrantClient] = None, +) -> ContextualCompressionRetriever: + """ + 创建混合检索器(Dense Vector + BM25) + + Args: + collection_name: Qdrant 集合名称 + embeddings: 嵌入模型 + dense_k: 向量检索返回数量 + sparse_k: BM25 检索返回数量 + client: Qdrant 客户端 + + Returns: + 混合检索器 + """ + if client is None: + client = create_qdrant_client() + + # 基础检索器(Qdrant 支持混合检索) + base_retriever = Qdrant.from_existing_collection( + embedding=embeddings, + collection_name=collection_name, + client=client, + content_payload_key="content", + metadata_payload_key="metadata", + ) + + # 配置混合检索参数 + search_kwargs = { + "k": dense_k + sparse_k, # 总返回数量 + "score_threshold": 0.3, # 相似度阈值 + } + + return base_retriever.as_retriever(search_kwargs=search_kwargs) + + +def create_ensemble_retriever( + retrievers: List[Any], + weights: Optional[List[float]] = None, + c: int = 60, +) -> EnsembleRetriever: + """ + 创建集成检索器,支持倒数排名融合 (RRF) + + Args: + retrievers: 检索器列表 + weights: 检索器权重 + c: RRF 常数(通常为60) + + Returns: + 集成检索器 + """ + if weights is None: + weights = [1.0 / len(retrievers)] * len(retrievers) + + ensemble = EnsembleRetriever( + retrievers=retrievers, + weights=weights, + c=c, + search_type="rrf", # 使用倒数排名融合 + ) + + return ensemble \ No newline at end of file diff --git a/app/rag/tools.py b/app/rag/tools.py new file mode 100644 index 0000000..8dcf90f --- /dev/null +++ b/app/rag/tools.py @@ -0,0 +1,230 @@ +""" +RAG 工具包装 + +将 RAG 流水线包装成 LangChain Tool,供 Agent 调用。 +""" + +from typing import Optional, Dict, Any +from langchain.tools import tool +from langchain_core.tools import Tool +from langchain_core.embeddings import Embeddings +from langchain_core.language_models import BaseLanguageModel + +from .pipeline import RAGPipeline, RAGConfig, RAGLevel + + +class RAGTool: + """ + RAG 工具包装器 + + 将 RAG 流水线包装成 Agent 可调用的工具。 + """ + + def __init__( + self, + pipeline: RAGPipeline, + tool_name: str = "search_knowledge_base", + tool_description: str = None, + ): + """ + 初始化 RAG 工具 + + Args: + pipeline: RAG 流水线实例 + tool_name: 工具名称 + tool_description: 工具描述 + """ + self.pipeline = pipeline + self.tool_name = tool_name + + # 默认工具描述 + self.tool_description = tool_description or ( + "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、" + "专业知识或需要基于已知信息回答的问题时使用此工具。" + "输入应为要搜索的查询文本。" + ) + + # 创建 LangChain 工具 + self._tool = self._create_tool() + + def _create_tool(self) -> Tool: + """创建 LangChain 工具""" + + @tool(self.tool_name, args_schema=None) + def search_knowledge_base(query: str) -> str: + """ + 在知识库中搜索相关信息 + + Args: + query: 搜索查询 + + Returns: + 格式化后的搜索结果 + """ + try: + # 执行检索 + result = self.pipeline.retrieve(query) + + if not result.documents: + return "在知识库中未找到相关信息。" + + # 格式化上下文 + context = self.pipeline.format_context( + result.documents, + max_length=4000, # 限制上下文长度 + ) + + # 构建响应 + response = ( + f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n" + f"{context}\n\n" + f"⏱️ 检索耗时: {result.query_time:.2f}秒" + ) + + return response + + except Exception as e: + error_msg = f"检索过程中发生错误: {str(e)}" + if self.pipeline.config.verbose: + print(f"RAG 工具错误: {error_msg}") + return error_msg + + # 设置工具描述 + search_knowledge_base.description = self.tool_description + + return search_knowledge_base + + def get_tool(self) -> Tool: + """获取 LangChain 工具""" + return self._tool + + def __call__(self, query: str) -> str: + """直接调用工具""" + return self._tool.invoke({"query": query}) + + +def create_rag_tool( + embeddings: Embeddings, + llm: Optional[BaseLanguageModel] = None, + config: Optional[Dict[str, Any]] = None, + tool_name: str = "search_knowledge_base", + tool_description: Optional[str] = None, +) -> Tool: + """ + 创建 RAG 工具(便捷函数) + + Args: + embeddings: 嵌入模型 + llm: 语言模型(用于高级 RAG 功能) + config: RAG 配置字典 + tool_name: 工具名称 + tool_description: 工具描述 + + Returns: + LangChain Tool 实例 + """ + # 创建 RAG 流水线 + pipeline = RAGPipeline.create_from_config( + embeddings=embeddings, + llm=llm, + config_dict=config, + ) + + # 创建工具包装器 + rag_tool = RAGTool( + pipeline=pipeline, + tool_name=tool_name, + tool_description=tool_description, + ) + + return rag_tool.get_tool() + + +# 导出便捷函数 +search_knowledge_base_tool = create_rag_tool + + +def bind_rag_to_agent( + agent_llm: BaseLanguageModel, + embeddings: Embeddings, + rag_llm: Optional[BaseLanguageModel] = None, + config: Optional[Dict[str, Any]] = None, + tool_name: str = "search_knowledge_base", +) -> BaseLanguageModel: + """ + 将 RAG 工具绑定到 Agent 模型 + + Args: + agent_llm: Agent 使用的语言模型 + embeddings: 嵌入模型 + rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同) + config: RAG 配置 + tool_name: 工具名称 + + Returns: + 绑定工具后的模型 + """ + # 如果未指定 RAG LLM,使用 Agent LLM + if rag_llm is None: + rag_llm = agent_llm + + # 创建 RAG 工具 + rag_tool = create_rag_tool( + embeddings=embeddings, + llm=rag_llm, + config=config, + tool_name=tool_name, + ) + + # 绑定工具到模型 + return agent_llm.bind_tools([rag_tool]) + + +def create_agentic_rag_pipeline( + embeddings: Embeddings, + agent_llm: BaseLanguageModel, + rag_llm: Optional[BaseLanguageModel] = None, + config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + 创建完整的 Agentic RAG 流水线(Level 4) + + Args: + embeddings: 嵌入模型 + agent_llm: Agent 模型 + rag_llm: RAG 专用模型 + config: 配置 + + Returns: + 包含模型和工具的字典 + """ + # 配置 Agentic RAG 级别 + if config is None: + config = {} + config["rag_level"] = RAGLevel.AGENTIC.value + + # 创建 RAG 工具 + rag_tool = create_rag_tool( + embeddings=embeddings, + llm=rag_llm or agent_llm, + config=config, + tool_name="search_knowledge_base", + tool_description=( + "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、" + "专业知识或需要基于已知信息回答的问题时使用此工具。" + "Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。" + ), + ) + + # 绑定工具到模型 + bound_llm = agent_llm.bind_tools([rag_tool]) + + return { + "llm": bound_llm, + "tool": rag_tool, + "pipeline": RAGPipeline.create_from_config( + embeddings=embeddings, + llm=rag_llm or agent_llm, + config_dict=config, + ), + } \ No newline at end of file diff --git a/frontend/components/chat_area.py b/frontend/components/chat_area.py index 816192b..4571c01 100644 --- a/frontend/components/chat_area.py +++ b/frontend/components/chat_area.py @@ -119,6 +119,7 @@ def _handle_ai_response(): api_thought = "" display_text = "" display_thought = "" + rag_sources = None # 存储 RAG 检索来源信息 # 调用流式 API stream = api_client.chat_stream( @@ -213,6 +214,25 @@ def _handle_ai_response(): last_msg = messages_update[-1] if messages_update else {} if isinstance(last_msg, dict) and last_msg.get("role") == "tool": tool_name = last_msg.get("name", "unknown") + tool_content = last_msg.get("content", "") + # 存储 RAG 检索结果 + if tool_name == "search_knowledge_base": + # 尝试解析 tool_content,它可能是 JSON 字符串 + sources = [] + try: + if isinstance(tool_content, str): + import json + data = json.loads(tool_content) + else: + data = tool_content + # 提取来源列表 + if isinstance(data, dict) and "sources" in data: + sources = data["sources"] + else: + sources = [str(data)] + except Exception: + sources = [str(tool_content)] + rag_sources = sources tool_status_placeholder.success(f"✅ 工具 {tool_name} 执行完成") # 短暂显示后清除,保持界面清爽 import time @@ -270,6 +290,31 @@ def _handle_ai_response(): # 移除光标 message_placeholder.markdown(display_text) + # 显示 RAG 检索来源(如果有) + if rag_sources: + with st.expander("🔍 检索来源", expanded=False): + # 格式化来源列表 + if isinstance(rag_sources, list): + for i, source in enumerate(rag_sources, 1): + if isinstance(source, dict): + content = source.get("page_content", source.get("content", str(source))) + metadata = source.get("metadata", {}) + filename = metadata.get("filename", metadata.get("source", "未知文件")) + page = metadata.get("page", metadata.get("page_number", "")) + if page: + source_info = f"**来源 {i}:** {filename} (第{page}页)" + else: + source_info = f"**来源 {i}:** {filename}" + st.markdown(source_info) + # 显示内容预览(前200字符) + preview = content[:200] + "..." if len(content) > 200 else content + st.markdown(f"> {preview}") + st.markdown("---") + else: + st.markdown(f"**来源 {i}:** {str(source)}") + else: + st.markdown(str(rag_sources)) + # 拼装包含思考过程的完整内容,以便后续在历史中正确渲染 final_content = display_text if display_thought: diff --git a/requirement.txt b/requirement.txt index 4c71c89..699f33b 100644 --- a/requirement.txt +++ b/requirement.txt @@ -48,6 +48,8 @@ pydantic==2.12.5 python-dotenv==1.2.2 typing-extensions==4.15.0 +unstructured>=0.0.1 + # ============================================================================ # 注意: # 1. 此文件包含项目直接依赖的精确版本