From 933d418d77844d47ab6bfef12fbf337d7efd8e1f Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Sun, 19 Apr 2026 22:01:55 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A3=80=E7=B4=A2=E5=99=A8=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + app/rag/README.md | 129 ++++--- app/rag/__init__.py | 51 ++- app/rag/example.py | 159 ++++---- app/rag/pipeline.py | 387 ++++++------------- app/rag/query_transform.py | 203 ++-------- app/rag/reranker.py | 152 ++------ app/rag/retriever.py | 134 ++++--- app/rag/tools.py | 273 ++++---------- rag_core/__init__.py | 18 + {rag_indexer => rag_core}/embedders.py | 5 +- {rag_indexer => rag_core}/store/__init__.py | 2 +- {rag_indexer => rag_core}/store/factory.py | 2 +- {rag_indexer => rag_core}/store/postgres.py | 2 +- {rag_indexer => rag_core}/vector_store.py | 8 +- rag_indexer/IndexBuilder.py | 299 +++++++++++++++ rag_indexer/README.md | 297 +++++++-------- rag_indexer/__init__.py | 44 +-- rag_indexer/builder.py | 392 -------------------- rag_indexer/cli.py | 86 ++--- rag_indexer/loaders.py | 142 ++++--- rag_indexer/splitters.py | 210 +++++++++-- rag_indexer/test/reset_index.py | 80 ++++ rag_indexer/test/test_inspect_vectors.py | 63 ++++ rag_indexer/test/test_refactored.py | 83 +++++ rag_indexer/test/test_validate_index.py | 188 ++++++++++ 26 files changed, 1694 insertions(+), 1717 deletions(-) create mode 100644 rag_core/__init__.py rename {rag_indexer => rag_core}/embedders.py (94%) rename {rag_indexer => rag_core}/store/__init__.py (92%) rename {rag_indexer => rag_core}/store/factory.py (98%) rename {rag_indexer => rag_core}/store/postgres.py (99%) rename {rag_indexer => rag_core}/vector_store.py (96%) create mode 100644 rag_indexer/IndexBuilder.py delete mode 100644 rag_indexer/builder.py create mode 100644 rag_indexer/test/reset_index.py create mode 100644 rag_indexer/test/test_inspect_vectors.py create mode 100644 rag_indexer/test/test_refactored.py create mode 100644 rag_indexer/test/test_validate_index.py diff --git a/.gitignore b/.gitignore index ff0b849..9e8ca1d 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ !scripts/** !rag_indexer/ !rag_indexer/** +!rag_core/ +!rag_core/** !docker/ !docker/** !.gitea/ diff --git a/app/rag/README.md b/app/rag/README.md index 19f27f4..b0e0423 100644 --- a/app/rag/README.md +++ b/app/rag/README.md @@ -2,71 +2,44 @@ 该模块负责 RAG 系统的阶段二:**在线检索与生成**。它接收用户提问,从知识库中检索出上下文,利用各种高级策略去噪、融合,并作为增强上下文输入给大语言模型 (LLM)。 -## 📊 RAG-Fusion & 混合检索流水线示意图 - -```mermaid -graph TD - User((用户提问)) --> A[LLM 查询改写生成器] - - subgraph RAG-Fusion 核心流程 - A -->|改写为问题 1| B1[查询 1] - A -->|改写为问题 2| B2[查询 2] - A -->|原问题| B3[原始查询] - - B1 & B2 & B3 --> C[混合检索器 Hybrid Retriever
Dense Vector + BM25 Sparse] - - C --> D[多路召回结果合集 N=60条] - - D --> E{RRF 倒数排名融合去重} - end - - E -->|筛选出前 20 条| F[Cross-Encoder 重排器 Reranker] - - F -->|精细打分排序 Top 5| G[最终纯净上下文 Context] - - G --> H[将 Context 与原问题拼接输入大模型] - H --> I((LLM 生成最终回答)) -``` - ---- - ## 🎯 演进路线与算法详解 (Roadmap) ### Level 1: 基础向量搜索 (Basic Similarity Search) - **核心算法**: 近似最近邻搜索 (ANN, 常用 HNSW 算法)。将用户问题转化为向量后,计算它与库中向量的余弦相似度 (Cosine Similarity),取距离最近的 K 个块。 - **优缺点**: 速度极快。但只能捕捉“语义相似”,如果用户搜索特定专有名词、编号、订单号,纯向量检索往往会失效(产生“幻觉”匹配)。 +- **实现指南**: + - 使用 `rag_indexer.embedders.LlamaCppEmbedder` 作为嵌入模型 + - 使用 `app/rag/retriever.py` 中的 `create_base_retriever` 创建基础检索器 + - 配置 `search_kwargs={"k": 20}` 进行初步召回 ### Level 2: 混合检索与重排序 (Hybrid Search + Reranker) 混合检索旨在结合向量的“语义泛化”与关键词的“精准匹配”,随后利用重排序模型过滤噪声。 **1. 基础召回 (混合检索)** - **核心原理**: 结合基于 HNSW 的 Dense Vector 相似度搜索与基于 TF-IDF 的 BM25 稀疏检索 (Sparse Vector)。 -- **实现指南**: 使用 `langchain_qdrant` 包中的 `Qdrant` 类连接数据库。通过调用 `Qdrant.from_existing_collection(...)` 实例化向量库,并使用 `.as_retriever(search_kwargs={"k": 20})` 方法生成基础检索器。Qdrant 底层会自动处理双路召回。 +- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_hybrid_retriever` 函数,配置 `dense_k=10` 和 `sparse_k=10`,总召回 20 条结果。 **2. 二次精排 (Cross-Encoder)** - **核心原理**: 不同于双塔模型(分别算向量再求距离),交叉编码器将“用户问题 + 检索到的单例文档”拼接后整体输入 Transformer 模型,由模型直接输出 0~1 的相关性得分,精度极高。 -- **实现指南**: - - 使用 `sentence-transformers` 库加载本地轻量级重排模型(如 `BAAI/bge-reranker-base`)。 - - 引入 `langchain.retrievers.document_compressors` 包中的 `CrossEncoderReranker` 类包装该模型,设置参数 `top_n=5`。 - - 最后,使用 `langchain.retrievers` 包中的 `ContextualCompressionRetriever` 类,将 `base_compressor` (重排器) 和 `base_retriever` (基础检索器) 组合。 - - **如何调用**: 业务逻辑中直接对组合后的检索器调用 `.invoke(query)` 方法,即可一键完成“大范围召回 20 条 -> 逐一打分精排选 5 条”的去噪流水线。 +- **实现指南**: + - 使用 `app/rag/reranker.py` 中的 `CrossEncoderReranker` 类,加载 `BAAI/bge-reranker-base` 模型 + - 设置 `top_n=5` 保留最相关的 5 条结果 + - 使用 `ContextualCompressionRetriever` 组合基础检索器和重排序器 ### Level 3: RAG-Fusion (多路改写与倒数排名融合) RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果。 **1. 多路查询改写** - **核心原理**: 克服用户初始提问词不达意或视角受限的问题。 -- **实现指南**: 导入 `langchain.retrievers.multi_query` 包中的 `MultiQueryRetriever` 类。需向其提供一个已实例化的 LLM 对象(如基于 `ChatOpenAI` 封装的本地 VLLM 模型)。系统在底层会自动 Prompt 模型,将原始 `query` 转化为包含 3-5 个不同表述的查询列表。 +- **实现指南**: 使用 `app/rag/query_transform.py` 中的 `MultiQueryTransformer` 类,配置 `num_queries=3` 生成 3 个不同角度的查询。 **2. 倒数排名融合 (RRF)** - **核心原理**: RRF (Reciprocal Rank Fusion) 是一种无需评分归一化的融合算法。公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,有效避免某一极端检索结果主导全局。 -- **实现指南**: - - 针对每个改写后的查询 $q$,分别调用精排检索器的 `.invoke(q)` 获取文档列表。 - - 使用 `langchain.retrievers` 中的 `EnsembleRetriever` 类(原生支持 RRF),或在代码中遍历收集到的 `Document` 对象,基于其排名 `rank` 累加得分,最终通过 Python 的 `set` 去重并提取 `doc.page_content`。 +- **实现指南**: 使用 `app/rag/retriever.py` 中的 `create_ensemble_retriever` 函数,配置 `search_type="rrf"` 实现倒数排名融合。 ### Level 4: Agentic RAG / Self-RAG (智能体与自我反思) - **核心原理**: 基于 LangGraph 的 ReAct (Reasoning and Acting) 状态机路由。大模型并非每次都去死板地执行检索,而是先判断问题:“这是闲聊?还是需要查知识库?”。如果是后者,模型输出一个 `ToolCall` 指令,触发检索。 -- **实现指南**: 请参考下方的**与现有系统整合调用**章节。 +- **实现指南**: 使用 `app/rag/tools.py` 中的 `search_knowledge_base` 工具,将其绑定到 LangGraph 状态机中。 - **示意图**: ```mermaid @@ -87,6 +60,13 @@ RAG-Fusion 通过大模型发散思维,将单一问题改写为多个相似问 LangGraph Agent-->>User: "根据知识库规定,报销流程分为以下3步..." ``` +### Level 5: GraphRAG 集成 (基于图和关系的 RAG) +- **核心原理**: 结合知识图谱的结构化关系和向量检索的语义相似度,解决跨文档复杂关系推理问题。 +- **实现指南**: + - 使用 `langchain_community.graphs` 模块构建知识图谱 + - 配置本地大模型(如 `Gemma-4-E2B`)用于实体关系抽取 + - 实现混合检索逻辑,结合向量相似度和图路径分析 + --- ## 📦 所需依赖与安装 @@ -102,18 +82,19 @@ pip install rank_bm25 # 基础框架 pip install langchain langchain-core langchain-openai langchain-qdrant + +# 与 rag_indexer 共享的依赖 +pip install qdrant-client httpx ``` --- ## 📂 架构与文件结构设计 -在 `app/rag/` 目录下,需创建以下文件来模块化上述功能: - -```text +``` app/rag/ ├── __init__.py -├── retriever.py # 负责 Qdrant 的基础召回与 ContextualCompressionRetriever +├── retriever.py # 负责 Qdrant 的基础召回与混合检索 ├── reranker.py # 负责加载 sentence-transformers 交叉编码器 ├── query_transform.py # 负责基于 MultiQueryRetriever 的改写逻辑 ├── pipeline.py # 组合上述组件,暴露出核心的 retrieve() 方法 @@ -122,15 +103,69 @@ app/rag/ --- -## � 与现有系统整合调用 (Agentic RAG 实现) +## 🔄 与 rag_indexer 集成 + +### 数据结构兼容性 +- **向量存储**: rag_indexer 使用 Qdrant 存储子块向量,app/rag 直接从相同集合读取 +- **文档存储**: rag_indexer 使用 PostgreSQL 存储父块,app/rag 通过 `ParentDocumentRetriever` 关联 +- **嵌入模型**: 共享 `LlamaCppEmbedder` 确保向量空间一致性 + +### 配置共享 +- **环境变量**: QDRANT_URL、QDRANT_API_KEY、DB_URI 等配置在两个模块间共享 +- **集合名称**: 默认使用 "rag_documents" 集合,确保数据一致性 + +--- + +## 🚀 与现有系统整合调用 (Agentic RAG 实现) 基于目前 LangGraph 系统的架构,我们将摒弃将代码堆砌在一起的旧方式,而是利用 **LangChain Tools** 的特性将 RAG 优雅地注入系统: -1. **封装检索工具 (Tool)**: +1. **封装检索工具 (Tool)**: 从 `langchain.tools` 导入 `@tool` 装饰器。定义一个名为 `search_knowledge_base(query: str)` 的函数。在函数内部,实例化并调用我们在 `pipeline.py` 中写好的多路召回与重排逻辑。 -2. **模型绑定 (Bind)**: + +2. **模型绑定 (Bind)**: 在 `app/agent.py` 或 `app/nodes/tool_call.py` 中,将这个工具引入,并通过 `llm.bind_tools([search_knowledge_base])` 绑定到现有的本地大模型实例上。 -3. **状态机路由 (Graph Routing)**: + +3. **状态机路由 (Graph Routing)**: 你的 LangGraph 状态机会像处理普通对话一样自动接管:当模型判断需要调用查阅规章制度或专业资料时,它会输出 `ToolCall` 消息,流转到 `tool_node` 执行上述的 RAG 检索逻辑并返回上下文。 这让你无需修改任何前端 Streamlit 流式代码,就能平滑升级为具备超级知识库检索能力的智能体 (Agent)! + +--- + +## 🎯 快速开始 + +```python +# 1. 初始化嵌入模型 +from rag_indexer.embedders import LlamaCppEmbedder +embeddings = LlamaCppEmbedder() + +# 2. 初始化语言模型(用于 RAG-Fusion) +from langchain_openai import OpenAI +llm = OpenAI( + openai_api_base="http://localhost:8000/v1", + openai_api_key="no-key-needed", + model_name="Qwen2.5-7B-Instruct", + temperature=0.3, +) + +# 3. 创建 RAG 流水线 +from app.rag.pipeline import RAGPipeline, RAGLevel +pipeline = RAGPipeline( + embeddings=embeddings, + llm=llm, + config={ + "collection_name": "rag_documents", + "rag_level": RAGLevel.FUSION.value, + "num_queries": 3, + "rerank_top_n": 5, + }, +) + +# 4. 执行检索 +result = pipeline.retrieve("如何申请项目资金?") + +# 5. 格式化上下文 +context = pipeline.format_context(result.documents) +print(context) +``` diff --git a/app/rag/__init__.py b/app/rag/__init__.py index 0d44285..623bb8f 100644 --- a/app/rag/__init__.py +++ b/app/rag/__init__.py @@ -1,22 +1,53 @@ """ -在线 RAG 检索与生成系统 +RAG 检索与生成模块 -提供高级RAG检索功能,支持混合检索、重排序、RAG-Fusion和多路查询改写。 +提供在线检索与生成功能,包括: +- 基础向量检索 +- 重排序 +- RAG-Fusion +- Agentic RAG + +示例用法: + >>> from app.rag import RAGPipeline, search_knowledge_base + >>> from rag_core import LlamaCppEmbedder + >>> + >>> embeddings = LlamaCppEmbedder() + >>> pipeline = RAGPipeline(embeddings=embeddings) + >>> + >>> documents = pipeline.retrieve("戏耍貂蝉美女") + >>> context = pipeline.format_context(documents) """ -from .pipeline import RAGPipeline -from .retriever import create_hybrid_retriever, create_base_retriever +from .retriever import ( + create_base_retriever, + create_hybrid_retriever, + # create_ensemble_retriever, + create_qdrant_client, +) from .reranker import CrossEncoderReranker from .query_transform import MultiQueryTransformer -from .tools import search_knowledge_base_tool +from .pipeline import RAGPipeline, RAGLevel +from .tools import search_knowledge_base, search_knowledge_base_sync + __all__ = [ - "RAGPipeline", - "create_hybrid_retriever", + # 检索器 "create_base_retriever", + "create_hybrid_retriever", + # "create_ensemble_retriever", + "create_qdrant_client", + + # 重排序器 "CrossEncoderReranker", + + # 查询转换器 "MultiQueryTransformer", - "search_knowledge_base_tool", + + # 流水线 + "RAGPipeline", + "RAGLevel", + + # 工具 + "search_knowledge_base", + "search_knowledge_base_sync", ] - -__version__ = "0.1.0" \ No newline at end of file diff --git a/app/rag/example.py b/app/rag/example.py index 53d82fe..8042b09 100644 --- a/app/rag/example.py +++ b/app/rag/example.py @@ -7,6 +7,10 @@ RAG 系统使用示例 import sys import os +from dotenv import load_dotenv + +# 加载环境变量 +load_dotenv() # 添加项目根目录到路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -19,10 +23,13 @@ def setup_environment(): """设置环境变量""" # 设置 Qdrant 连接信息(根据实际情况修改) os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333") + # 设置 Qdrant API 密钥(根据实际情况修改) + os.environ.setdefault("QDRANT_API_KEY", "your-api-key-here") # 如果需要 API 密钥,请设置 QDRANT_API_KEY print("环境变量已设置") print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}") + print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else '未设置'}") def demonstrate_basic_rag(): @@ -31,37 +38,32 @@ def demonstrate_basic_rag(): print("演示: 基础 RAG 检索 (Level 1)") print("="*60) - # 创建嵌入模型(使用 OpenAI 兼容的本地模型) - embeddings = OpenAIEmbeddings( - openai_api_base="http://localhost:8000/v1", # 本地 VLLM 服务 - openai_api_key="no-key-needed", - model="text-embedding-ada-002", # 假设的模型名称 - ) + # 创建嵌入模型(使用本地 LlamaCpp 模型) + from rag_core import LlamaCppEmbedder + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() # 创建 RAG 流水线 - from app.rag import RAGPipeline, RAGConfig, RAGLevel - - config = RAGConfig( - collection_name="documents", # 你的集合名称 - rag_level=RAGLevel.BASIC, - verbose=True, - ) + from app.rag import RAGPipeline, RAGLevel pipeline = RAGPipeline( embeddings=embeddings, - config=config, + config={ + "collection_name": "rag_documents", # 你的集合名称 + "rag_level": RAGLevel.BASIC.value, + } ) # 示例查询 - query = "公司报销流程是什么?" + query = "吕布" print(f"\n查询: {query}") try: - result = pipeline.retrieve(query) - print(f"找到 {len(result.documents)} 个相关文档") + documents = pipeline.retrieve(query) + print(f"找到 {len(documents)} 个相关文档") # 格式化上下文 - context = pipeline.format_context(result.documents) + context = pipeline.format_context(documents) print(f"\n上下文预览:\n{context[:500]}...") except Exception as e: @@ -75,34 +77,31 @@ def demonstrate_hybrid_rag(): print("演示: 混合 RAG 检索 (Level 2)") print("="*60) - embeddings = OpenAIEmbeddings( - openai_api_base="http://localhost:8000/v1", - openai_api_key="no-key-needed", - model="text-embedding-ada-002", - ) + from rag_core import LlamaCppEmbedder + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() - from app.rag import RAGPipeline, RAGConfig, RAGLevel - - config = RAGConfig( - collection_name="documents", - rag_level=RAGLevel.HYBRID, - dense_k=10, - sparse_k=10, - rerank_top_n=5, - verbose=True, - ) + from app.rag import RAGPipeline, RAGLevel pipeline = RAGPipeline( embeddings=embeddings, - config=config, + config={ + "collection_name": "rag_documents", + "rag_level": RAGLevel.RERANK.value, + "rerank_top_n": 5, + } ) - query = "如何申请年假?" + query = "吕布" print(f"\n查询: {query}") try: - result = pipeline.retrieve(query) - print(f"找到 {len(result.documents)} 个重排序后的文档") + documents = pipeline.retrieve(query) + print(f"找到 {len(documents)} 个重排序后的文档") + + # 格式化上下文 + context = pipeline.format_context(documents) + print(f"\n上下文预览:\n{context[:500]}...") except Exception as e: print(f"检索失败: {e}") @@ -114,42 +113,42 @@ def demonstrate_rag_fusion(): print("演示: RAG-Fusion (Level 3)") print("="*60) - embeddings = OpenAIEmbeddings( - openai_api_base="http://localhost:8000/v1", - openai_api_key="no-key-needed", - model="text-embedding-ada-002", - ) + from rag_core import LlamaCppEmbedder + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() - # 创建语言模型用于查询改写 - llm = VLLMOpenAI( + # 创建语言模型用于查询改写(使用 OpenAI 兼容的本地模型) + from langchain_openai import ChatOpenAI + llm = ChatOpenAI( openai_api_base="http://localhost:8000/v1", openai_api_key="no-key-needed", - model_name="Qwen2.5-7B-Instruct", # 你的本地模型 + model="Qwen2.5-7B-Instruct", # 你的本地模型 temperature=0.3, max_tokens=512, ) - from app.rag import RAGPipeline, RAGConfig, RAGLevel - - config = RAGConfig( - collection_name="documents", - rag_level=RAGLevel.FUSION, - num_queries=3, - verbose=True, - ) + from app.rag import RAGPipeline, RAGLevel pipeline = RAGPipeline( embeddings=embeddings, llm=llm, - config=config, + config={ + "collection_name": "rag_documents", + "rag_level": RAGLevel.FUSION.value, + "num_queries": 3, + } ) - query = "项目上线需要哪些审批?" + query = "吕布" print(f"\n查询: {query}") try: - result = pipeline.retrieve(query) - print(f"找到 {len(result.documents)} 个文档 (经过多路查询改写和重排序)") + documents = pipeline.retrieve(query) + print(f"找到 {len(documents)} 个文档 (经过多路查询改写和重排序)") + + # 格式化上下文 + context = pipeline.format_context(documents) + print(f"\n上下文预览:\n{context[:500]}...") except Exception as e: print(f"检索失败: {e}") @@ -161,44 +160,16 @@ def demonstrate_agentic_rag(): print("演示: Agentic RAG (Level 4)") print("="*60) - embeddings = OpenAIEmbeddings( - openai_api_base="http://localhost:8000/v1", - openai_api_key="no-key-needed", - model="text-embedding-ada-002", - ) - - llm = VLLMOpenAI( - openai_api_base="http://localhost:8000/v1", - openai_api_key="no-key-needed", - model_name="Qwen2.5-7B-Instruct", - temperature=0.3, - max_tokens=512, - ) - - from app.rag import create_agentic_rag_pipeline + from app.rag import search_knowledge_base_sync try: - # 创建 Agentic RAG 流水线 - agentic_rag = create_agentic_rag_pipeline( - embeddings=embeddings, - agent_llm=llm, - config={ - "collection_name": "documents", - "verbose": True, - }, - ) - - print("Agentic RAG 流水线创建成功!") - print(f"- 绑定的模型: {agentic_rag['llm']}") - print(f"- RAG 工具: {agentic_rag['tool'].name}") - # 演示工具调用 - print("\n工具调用示例:") - response = agentic_rag["tool"].invoke({"query": "员工福利有哪些?"}) + print("工具调用示例:") + response = search_knowledge_base_sync("吕布") print(f"工具响应预览: {response[:200]}...") except Exception as e: - print(f"创建 Agentic RAG 失败: {e}") + print(f"工具调用失败: {e}") import traceback traceback.print_exc() @@ -211,11 +182,11 @@ def main(): # 设置环境 setup_environment() - # 演示各级功能 + # 演示基础功能 demonstrate_basic_rag() demonstrate_hybrid_rag() - demonstrate_rag_fusion() - demonstrate_agentic_rag() + # demonstrate_rag_fusion() # 需要本地 LLM 服务 + # demonstrate_agentic_rag() # 需要本地 LLM 服务 print("\n" + "="*60) print("演示完成!") @@ -223,8 +194,8 @@ def main(): print("\n使用说明:") print("1. 确保 Qdrant 服务运行且集合已创建") - print("2. 根据需要修改 embeddings 和 llm 配置") - print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base_tool") + print("2. 已使用本地 LlamaCpp 嵌入模型") + print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base") print("4. 将工具绑定到你的 Agent 模型") diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py index 0e6d0bc..e5eba7c 100644 --- a/app/rag/pipeline.py +++ b/app/rag/pipeline.py @@ -1,341 +1,168 @@ """ RAG 检索流水线 -组合检索器、重排序器、查询改写器等组件,提供完整的 RAG 检索功能。 +整合基础检索、重排序和 RAG-Fusion 功能。 """ -import time -from typing import List, Dict, Any, Optional, Union -from dataclasses import dataclass, field from enum import Enum +from typing import List, Optional, Dict, Any from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings from langchain_core.language_models import BaseLanguageModel from .retriever import ( create_base_retriever, create_hybrid_retriever, - create_ensemble_retriever, create_qdrant_client, ) from .reranker import CrossEncoderReranker -from .query_transform import MultiQueryTransformer, create_rag_fusion_pipeline +from .query_transform import MultiQueryTransformer +from rag_core import QDRANT_URL, QDRANT_API_KEY class RAGLevel(Enum): - """RAG 功能级别""" - BASIC = 1 # 基础向量搜索 - HYBRID = 2 # 混合检索 + 重排序 - FUSION = 3 # RAG-Fusion - AGENTIC = 4 # Agentic RAG - - -@dataclass -class RAGConfig: - """RAG 配置""" - # Qdrant 配置 - collection_name: str = "documents" - qdrant_url: Optional[str] = None - qdrant_api_key: Optional[str] = None - - # 检索配置 - rag_level: RAGLevel = RAGLevel.FUSION - dense_k: int = 10 # 向量检索数量 - sparse_k: int = 10 # BM25 检索数量 - total_k: int = 20 # 总检索数量 - rerank_top_n: int = 5 # 重排序返回数量 - - # 查询改写配置 - num_queries: int = 3 # RAG-Fusion 查询数量 - - # 模型配置 - reranker_model: str = "BAAI/bge-reranker-base" - device: Optional[str] = None - - # 性能配置 - enable_cache: bool = True - verbose: bool = True - - -@dataclass -class RetrievalResult: - """检索结果""" - documents: List[Document] - query_time: float - level: RAGLevel - metadata: Dict[str, Any] = field(default_factory=dict) + """RAG 级别""" + BASIC = "basic" # 基础向量检索 + RERANK = "rerank" # 基础检索 + 重排序 + FUSION = "fusion" # RAG-Fusion(多路查询 + RRF) class RAGPipeline: - """ - RAG 检索流水线 - - 支持从 Level 1 到 Level 4 的所有功能。 - """ + """RAG 检索流水线""" def __init__( self, - embeddings: Embeddings, + embeddings, llm: Optional[BaseLanguageModel] = None, - config: Optional[RAGConfig] = None, + config: Optional[Dict[str, Any]] = None, ): """ 初始化 RAG 流水线 Args: embeddings: 嵌入模型 - llm: 语言模型(用于查询改写,Level 3+ 需要) - config: 配置 + llm: 语言模型(用于 RAG-Fusion) + config: 配置参数 """ self.embeddings = embeddings self.llm = llm - self.config = config or RAGConfig() + self.config = config or {} - # 初始化组件 - self._client = None - self._reranker = None - self._query_transformer = None - self._retriever = None + self.collection_name = self.config.get("collection_name", "rag_documents") + self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value) + self.num_queries = self.config.get("num_queries", 3) + self.rerank_top_n = self.config.get("rerank_top_n", 5) - # 缓存 - self._cache = {} - - def _get_client(self): - """获取 Qdrant 客户端""" - if self._client is None: - self._client = create_qdrant_client( - url=self.config.qdrant_url, - api_key=self.config.qdrant_api_key, - ) - return self._client - - def _get_reranker(self): - """获取重排序器""" - if self._reranker is None: - self._reranker = CrossEncoderReranker( - model_name=self.config.reranker_model, - top_n=self.config.rerank_top_n, - device=self.config.device, - ) - return self._reranker - - def _get_query_transformer(self): - """获取查询改写器""" - if self._query_transformer is None and self.llm is not None: - self._query_transformer = MultiQueryTransformer( - llm=self.llm, - num_queries=self.config.num_queries, - ) - return self._query_transformer - - def _create_basic_retriever(self): - """创建基础检索器(Level 1)""" - return create_base_retriever( - collection_name=self.config.collection_name, + # 初始化基础检索器 + self.base_retriever = create_base_retriever( + collection_name=self.collection_name, embeddings=self.embeddings, - search_kwargs={"k": self.config.total_k}, - client=self._get_client(), - ) - - def _create_hybrid_retriever(self): - """创建混合检索器(Level 2)""" - base_retriever = create_hybrid_retriever( - collection_name=self.config.collection_name, - embeddings=self.embeddings, - dense_k=self.config.dense_k, - sparse_k=self.config.sparse_k, - client=self._get_client(), + search_kwargs={"k": 20}, # 召回 20 条 ) - # 应用重排序 - reranker = self._get_reranker() - return reranker.create_contextual_compression_retriever(base_retriever) - - def _create_fusion_retriever(self): - """创建 RAG-Fusion 检索器(Level 3)""" - if self.llm is None: - raise ValueError("Level 3 (RAG-Fusion) 需要语言模型进行查询改写") + # 初始化重排序器 + try: + self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n) + except Exception as e: + print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}") + self.reranker = None - # 创建基础混合检索器 - base_retriever = create_hybrid_retriever( - collection_name=self.config.collection_name, - embeddings=self.embeddings, - dense_k=self.config.dense_k, - sparse_k=self.config.sparse_k, - client=self._get_client(), - ) - - # 创建 RAG-Fusion 流水线 - reranker = self._get_reranker() - return create_rag_fusion_pipeline( - base_retriever=base_retriever, - llm=self.llm, - reranker=reranker, - num_queries=self.config.num_queries, - ) + # 根据 RAG 级别创建检索器 + self.retriever = self._create_retriever() - def _get_retriever(self): - """根据配置级别获取检索器""" - if self._retriever is None: - if self.config.rag_level == RAGLevel.BASIC: - self._retriever = self._create_basic_retriever() - elif self.config.rag_level == RAGLevel.HYBRID: - self._retriever = self._create_hybrid_retriever() - elif self.config.rag_level == RAGLevel.FUSION: - self._retriever = self._create_fusion_retriever() - elif self.config.rag_level == RAGLevel.AGENTIC: - # Agentic RAG 使用 Fusion 作为基础,在 tools.py 中包装 - self._retriever = self._create_fusion_retriever() + def _create_retriever(self): + """根据 RAG 级别创建检索器""" + if self.rag_level == RAGLevel.BASIC.value: + return self.base_retriever + + # 基础检索 + 重排序 + def rerank_retriever(query): + documents = self.base_retriever.invoke(query) + if self.reranker: + return self.reranker.compress_documents(documents, query) else: - raise ValueError(f"不支持的 RAG 级别: {self.config.rag_level}") + return documents[:self.rerank_top_n] - return self._retriever + if self.rag_level == RAGLevel.RERANK.value: + return SimpleRetriever(rerank_retriever) + + # RAG-Fusion + if self.rag_level == RAGLevel.FUSION.value: + if not self.llm: + raise ValueError("RAG-Fusion 需要提供 llm 参数") + + # 创建多路查询检索器 + transformer = MultiQueryTransformer( + llm=self.llm, + num_queries=self.num_queries + ) + multi_query_retriever = transformer.create_multi_query_retriever( + base_retriever=SimpleRetriever(rerank_retriever) + ) + + return multi_query_retriever + + return SimpleRetriever(rerank_retriever) - def retrieve( - self, - query: str, - use_cache: Optional[bool] = None, - **kwargs, - ) -> RetrievalResult: + def retrieve(self, query: str) -> List[Document]: """ 执行检索 Args: - query: 查询文本 - use_cache: 是否使用缓存 - **kwargs: 额外参数 - + query: 查询字符串 + Returns: - 检索结果 + 相关文档列表 """ - start_time = time.time() - - # 检查缓存 - if use_cache is None: - use_cache = self.config.enable_cache - - cache_key = f"{query}:{self.config.rag_level.value}" - if use_cache and cache_key in self._cache: - if self.config.verbose: - print(f"使用缓存结果: {query}") - return self._cache[cache_key] - - # 获取检索器并执行检索 - retriever = self._get_retriever() - documents = retriever.invoke(query, **kwargs) - - # 计算查询时间 - query_time = time.time() - start_time - - # 创建结果 - result = RetrievalResult( - documents=documents, - query_time=query_time, - level=self.config.rag_level, - metadata={ - "query": query, - "collection": self.config.collection_name, - "doc_count": len(documents), - }, - ) - - # 缓存结果 - if use_cache: - self._cache[cache_key] = result - - if self.config.verbose: - print(f"检索完成: {len(documents)} 文档, 耗时: {query_time:.2f}s") - - return result + return self.retriever.invoke(query) - def format_context( - self, - documents: List[Document], - max_length: Optional[int] = None, - ) -> str: + async def aretrieve(self, query: str) -> List[Document]: """ - 格式化检索到的文档为上下文文本 + 异步执行检索 + + Args: + query: 查询字符串 + + Returns: + 相关文档列表 + """ + return await self.retriever.ainvoke(query) + + def format_context(self, documents: List[Document]) -> str: + """ + 格式化上下文 Args: documents: 文档列表 - max_length: 最大长度(字符数) - + Returns: - 格式化后的上下文文本 + 格式化后的上下文字符串 """ + if not documents: + return "" + context_parts = [] - total_length = 0 + for i, doc in enumerate(documents, 1): + content = doc.page_content + metadata = doc.metadata or {} + source = metadata.get("source", "未知来源") + + part = f"【资料 {i}】\n" + part += f"来源: {source}\n" + part += f"内容: {content}\n" + part += "---\n" + context_parts.append(part) - for i, doc in enumerate(documents): - # 提取内容和元数据 - content = doc.page_content.strip() - metadata = doc.metadata - - # 格式化文档 - doc_text = f"[文档 {i+1}]\n" - if metadata.get("source"): - doc_text += f"来源: {metadata['source']}\n" - if metadata.get("page"): - doc_text += f"页码: {metadata['page']}\n" - - doc_text += f"内容: {content}\n\n" - - # 检查长度限制 - if max_length is not None: - if total_length + len(doc_text) > max_length: - # 如果添加这个文档会超限,则截断并添加说明 - remaining = max_length - total_length - if remaining > 100: # 至少保留100字符 - doc_text = doc_text[:remaining] + "...\n\n[内容已截断]" - context_parts.append(doc_text) - break - else: - break - - context_parts.append(doc_text) - total_length += len(doc_text) - - return "".join(context_parts).strip() + return "".join(context_parts) + + +class SimpleRetriever: + """简单检索器包装类""" - def clear_cache(self): - """清空缓存""" - self._cache.clear() + def __init__(self, retrieve_func): + self.retrieve_func = retrieve_func - @classmethod - def create_from_config( - cls, - embeddings: Embeddings, - llm: Optional[BaseLanguageModel] = None, - config_dict: Optional[Dict[str, Any]] = None, - ) -> "RAGPipeline": - """ - 从配置字典创建流水线 - - Args: - embeddings: 嵌入模型 - llm: 语言模型 - config_dict: 配置字典 - - Returns: - RAGPipeline 实例 - """ - config_dict = config_dict or {} - - # 创建配置对象 - config = RAGConfig( - collection_name=config_dict.get("collection_name", "documents"), - qdrant_url=config_dict.get("qdrant_url"), - qdrant_api_key=config_dict.get("qdrant_api_key"), - rag_level=RAGLevel(config_dict.get("rag_level", RAGLevel.FUSION.value)), - dense_k=config_dict.get("dense_k", 10), - sparse_k=config_dict.get("sparse_k", 10), - total_k=config_dict.get("total_k", 20), - rerank_top_n=config_dict.get("rerank_top_n", 5), - num_queries=config_dict.get("num_queries", 3), - reranker_model=config_dict.get("reranker_model", "BAAI/bge-reranker-base"), - device=config_dict.get("device"), - enable_cache=config_dict.get("enable_cache", True), - verbose=config_dict.get("verbose", True), - ) - - return cls(embeddings=embeddings, llm=llm, config=config) \ No newline at end of file + def invoke(self, query): + return self.retrieve_func(query) + + async def ainvoke(self, query): + return self.retrieve_func(query) diff --git a/app/rag/query_transform.py b/app/rag/query_transform.py index 652e6f4..5183f9e 100644 --- a/app/rag/query_transform.py +++ b/app/rag/query_transform.py @@ -1,193 +1,62 @@ """ -查询改写器 +查询转换器模块 -基于 MultiQueryRetriever 实现多路查询改写,扩大搜索范围。 +实现多路查询改写功能,用于 RAG-Fusion。 """ -from typing import List, Optional, Any -from langchain.retrievers.multi_query import MultiQueryRetriever +from typing import List, Optional from langchain_core.language_models import BaseLanguageModel +# from langchain.retrievers.multi_query import MultiQueryRetriever from langchain_core.prompts import PromptTemplate -from langchain_core.output_parsers import StrOutputParser class MultiQueryTransformer: - """ - 多路查询改写器 + """多路查询改写器,用于 RAG-Fusion。""" - 将单个查询改写成多个相关查询,用于 RAG-Fusion。 - """ - - def __init__( - self, - llm: BaseLanguageModel, - num_queries: int = 3, - prompt_template: Optional[str] = None, - ): + def __init__(self, llm: BaseLanguageModel, num_queries: int = 3): """ - 初始化查询改写器 + 初始化多路查询改写器。 Args: llm: 语言模型实例 num_queries: 生成的查询数量 - prompt_template: 提示词模板 """ self.llm = llm self.num_queries = num_queries - - # 默认提示词模板 - self.prompt_template = prompt_template or """ - 你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。 - 这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。 - - 原始问题: {question} - - 请生成 {num_queries} 个不同版本的查询,每个版本一行。 - 确保每个版本都是独立、完整的查询语句。 - - 生成 {num_queries} 个查询: - """ - def transform_query(self, query: str) -> List[str]: + def create_multi_query_retriever(self, base_retriever): """ - 将单个查询改写成多个查询 - - Args: - query: 原始查询 - - Returns: - 改写后的查询列表 - """ - prompt = PromptTemplate.from_template(self.prompt_template) - - chain = prompt | self.llm | StrOutputParser() - - response = chain.invoke({ - "question": query, - "num_queries": self.num_queries, - }) - - # 解析响应,每行一个查询 - queries = [ - q.strip() - for q in response.strip().split('\n') - if q.strip() - ] - - # 确保数量正确,如果不够则添加原始查询 - if len(queries) < self.num_queries: - queries.extend([query] * (self.num_queries - len(queries))) - elif len(queries) > self.num_queries: - queries = queries[:self.num_queries] - - # 确保包含原始查询 - if query not in queries: - queries = [query] + queries[:self.num_queries-1] - - return queries - - def create_multi_query_retriever( - self, - base_retriever: Any, - include_original: bool = True, - ) -> MultiQueryRetriever: - """ - 创建多路查询检索器 + 创建多路查询检索器。 Args: base_retriever: 基础检索器 - include_original: 是否包含原始查询 - + Returns: MultiQueryRetriever 实例 """ - retriever = MultiQueryRetriever.from_llm( - retriever=base_retriever, - llm=self.llm, - include_original=include_original, - ) - - # 设置生成的查询数量 - retriever.llm_chain.prompt = PromptTemplate.from_template( - "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" - "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" - "原始问题: {question}\n\n" - "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" - "确保每个版本都是独立、完整的查询语句。\n\n" - "生成 {num_queries} 个查询:" - ) - - # 修改调用参数以包含 num_queries - original_invoke = retriever.llm_chain.invoke - - def new_invoke(input_dict): - input_dict["num_queries"] = self.num_queries - return original_invoke(input_dict) - - retriever.llm_chain.invoke = new_invoke - - return retriever - - @classmethod - def create_from_config( - cls, - llm: BaseLanguageModel, - config: Optional[dict] = None, - ) -> "MultiQueryTransformer": - """ - 从配置创建查询改写器 - - Args: - llm: 语言模型实例 - config: 配置字典 - - Returns: - MultiQueryTransformer 实例 - """ - config = config or {} - return cls( - llm=llm, - num_queries=config.get("num_queries", 3), - prompt_template=config.get("prompt_template", None), - ) - - -def create_rag_fusion_pipeline( - base_retriever: Any, - llm: BaseLanguageModel, - reranker: Optional[Any] = None, - num_queries: int = 3, -) -> Any: - """ - 创建完整的 RAG-Fusion 流水线 - - Args: - base_retriever: 基础检索器 - llm: 语言模型(用于查询改写) - reranker: 重排序器(可选) - num_queries: 查询改写数量 - - Returns: - 检索器实例 - """ - # 创建多路查询改写器 - query_transformer = MultiQueryTransformer( - llm=llm, - num_queries=num_queries, - ) - - # 创建多路查询检索器 - multi_query_retriever = query_transformer.create_multi_query_retriever( - base_retriever=base_retriever, - include_original=True, - ) - - # 如果提供了重排序器,则应用重排序 - if reranker is not None: - from langchain.retrievers import ContextualCompressionRetriever - return ContextualCompressionRetriever( - base_compressor=reranker, - base_retriever=multi_query_retriever, - ) - - return multi_query_retriever \ No newline at end of file + # 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器 + # retriever = MultiQueryRetriever.from_llm( + # retriever=base_retriever, + # llm=self.llm, + # include_original=True + # ) + # + # # 自定义提示词 + # retriever.llm_chain.prompt = PromptTemplate.from_template( + # "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" + # "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" + # "原始问题: {question}\n\n" + # "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" + # "确保每个版本都是独立、完整的查询语句。\n\n" + # "生成 {num_queries} 个查询:" + # ) + # + # # 修改调用参数以包含 num_queries + # original_ainvoke = retriever.llm_chain.ainvoke + # async def new_ainvoke(input_dict): + # input_dict["num_queries"] = self.num_queries + # return await original_ainvoke(input_dict) + # retriever.llm_chain.ainvoke = new_ainvoke + # + # return retriever + return base_retriever diff --git a/app/rag/reranker.py b/app/rag/reranker.py index 4c7229f..4a414cf 100644 --- a/app/rag/reranker.py +++ b/app/rag/reranker.py @@ -1,141 +1,65 @@ """ -Cross-Encoder 重排序器 +重排序器模块 -使用 sentence-transformers 加载交叉编码器模型,对检索结果进行精排。 +使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度。 """ -import os -from typing import List, Dict, Any, Optional -from langchain.retrievers.document_compressors import CrossEncoderReranker +from typing import List from langchain_core.documents import Document -from sentence_transformers import CrossEncoder class CrossEncoderReranker: - """ - Cross-Encoder 重排序器包装类 + """使用 Cross-Encoder 对检索结果重排序。""" - 支持 BAAI/bge-reranker-base 等模型。 - """ - - def __init__( - self, - model_name: str = "BAAI/bge-reranker-base", - top_n: int = 5, - device: Optional[str] = None, - cache_folder: Optional[str] = None, - ): + def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5): """ 初始化重排序器 Args: - model_name: 模型名称或路径 - top_n: 返回的顶部文档数量 - device: 设备(cpu/cuda),如果为 None 则自动选择 - cache_folder: 模型缓存目录 + model_name: 预训练模型名称 + top_n: 返回前 N 个结果 """ self.model_name = model_name self.top_n = top_n - self.device = device - self.cache_folder = cache_folder or os.path.join( - os.path.expanduser("~"), ".cache", "sentence_transformers" - ) + self.model = None - # 延迟加载模型 - self._model = None - self._langchain_reranker = None + # 尝试加载 Cross-Encoder 模型 + try: + from sentence_transformers import CrossEncoder + self.model = CrossEncoder(model_name) + except Exception as e: + print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}") - def _load_model(self): - """加载交叉编码器模型""" - if self._model is None: - try: - self._model = CrossEncoder( - self.model_name, - device=self.device, - cache_folder=self.cache_folder, - ) - except Exception as e: - # 如果指定模型加载失败,尝试备用模型 - print(f"加载模型 {self.model_name} 失败: {e}") - print("尝试加载备用模型 BAAI/bge-reranker-v2-m3...") - self._model = CrossEncoder( - "BAAI/bge-reranker-v2-m3", - device=self.device, - cache_folder=self.cache_folder, - ) - - def _create_langchain_reranker(self): - """创建 LangChain 重排序器""" - if self._langchain_reranker is None: - self._load_model() - self._langchain_reranker = CrossEncoderReranker( - model=self._model, - top_n=self.top_n, - ) - - def rerank( - self, - query: str, - documents: List[Document], + def compress_documents( + self, documents: List[Document], query: str ) -> List[Document]: """ 对文档进行重排序 Args: - query: 查询文本 - documents: 待排序文档列表 - + documents: 待排序的文档列表 + query: 查询字符串 + Returns: - 重排序后的文档列表 + 排序后的文档列表 """ - self._create_langchain_reranker() - return self._langchain_reranker.compress_documents( - documents=documents, - query=query, - ) - - def create_contextual_compression_retriever( - self, - base_retriever: Any, - ) -> Any: - """ - 创建上下文压缩检索器 + if not documents: + return [] - Args: - base_retriever: 基础检索器 + # 如果模型加载失败,返回前 top_n 个文档 + if self.model is None: + return documents[:self.top_n] + + # 使用 Cross-Encoder 进行重排序 + try: + pairs = [[query, doc.page_content] for doc in documents] + scores = self.model.predict(pairs) - Returns: - 上下文压缩检索器 - """ - from langchain.retrievers import ContextualCompressionRetriever - - self._create_langchain_reranker() - - compression_retriever = ContextualCompressionRetriever( - base_compressor=self._langchain_reranker, - base_retriever=base_retriever, - ) - - return compression_retriever - - @classmethod - def create_from_config( - cls, - config: Optional[Dict[str, Any]] = None, - ) -> "CrossEncoderReranker": - """ - 从配置创建重排序器 - - Args: - config: 配置字典,包含 model_name, top_n, device 等 - - Returns: - CrossEncoderReranker 实例 - """ - config = config or {} - return cls( - model_name=config.get("model_name", "BAAI/bge-reranker-base"), - top_n=config.get("top_n", 5), - device=config.get("device", None), - cache_folder=config.get("cache_folder", None), - ) \ No newline at end of file + # 按分数降序排序 + scored_docs = sorted( + zip(documents, scores), key=lambda x: x[1], reverse=True + ) + return [doc for doc, _ in scored_docs[:self.top_n]] + except Exception as e: + print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}") + return documents[:self.top_n] diff --git a/app/rag/retriever.py b/app/rag/retriever.py index 19a7511..80d6284 100644 --- a/app/rag/retriever.py +++ b/app/rag/retriever.py @@ -4,15 +4,12 @@ Qdrant 向量检索器 提供基础向量检索、混合检索(Dense + BM25)功能。 """ -import os from typing import List, Dict, Any, Optional -from langchain_qdrant import Qdrant +from langchain_qdrant import QdrantVectorStore from langchain.embeddings.base import Embeddings -from langchain.retrievers import ContextualCompressionRetriever -from langchain.retrievers.document_compressors import DocumentCompressorPipeline -from langchain.retrievers import EnsembleRetriever +# from langchain.retrievers import EnsembleRetriever from qdrant_client import QdrantClient -from qdrant_client.http import models +from rag_core import QDRANT_URL, QDRANT_API_KEY def create_qdrant_client( @@ -21,21 +18,21 @@ def create_qdrant_client( ) -> QdrantClient: """ 创建 Qdrant 客户端 - + Args: url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取 api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取 - + Returns: QdrantClient 实例 """ - url = url or os.getenv("QDRANT_URL", "http://localhost:6333") - api_key = api_key or os.getenv("QDRANT_API_KEY") - + url = url or QDRANT_URL + api_key = api_key or QDRANT_API_KEY + client_args = {"url": url} if api_key: client_args["api_key"] = api_key - + return QdrantClient(**client_args) @@ -44,34 +41,33 @@ def create_base_retriever( embeddings: Embeddings, search_kwargs: Optional[Dict[str, Any]] = None, client: Optional[QdrantClient] = None, -) -> Qdrant: +) -> QdrantVectorStore: """ 创建基础向量检索器 - + Args: collection_name: Qdrant 集合名称 embeddings: 嵌入模型 search_kwargs: 搜索参数,默认 {"k": 20} client: Qdrant 客户端,如果为 None 则自动创建 - + Returns: - Qdrant 检索器实例 + QdrantVectorStore 检索器实例 """ + search_kwargs = search_kwargs or {"k": 20} + + # 创建 Qdrant 客户端 if client is None: client = create_qdrant_client() - - search_kwargs = search_kwargs or {"k": 20} - - # 创建 Qdrant 检索器 - retriever = Qdrant.from_existing_collection( - embedding=embeddings, - collection_name=collection_name, + + # 使用 QdrantVectorStore 创建向量存储 + vector_store = QdrantVectorStore( client=client, - content_payload_key="content", # 假设存储的文本字段名为 "content" - metadata_payload_key="metadata", # 元数据字段名 + collection_name=collection_name, + embedding=embeddings, ) - - return retriever.as_retriever(search_kwargs=search_kwargs) + + return vector_store.as_retriever(search_kwargs=search_kwargs) def create_hybrid_retriever( @@ -80,65 +76,63 @@ def create_hybrid_retriever( dense_k: int = 10, sparse_k: int = 10, client: Optional[QdrantClient] = None, -) -> ContextualCompressionRetriever: +) -> QdrantVectorStore: """ 创建混合检索器(Dense Vector + BM25) - + Args: collection_name: Qdrant 集合名称 embeddings: 嵌入模型 dense_k: 向量检索返回数量 sparse_k: BM25 检索返回数量 client: Qdrant 客户端 - + Returns: 混合检索器 """ + # 创建 Qdrant 客户端 if client is None: client = create_qdrant_client() - - # 基础检索器(Qdrant 支持混合检索) - base_retriever = Qdrant.from_existing_collection( - embedding=embeddings, - collection_name=collection_name, + + # 使用 QdrantVectorStore 创建向量存储 + vector_store = QdrantVectorStore( client=client, - content_payload_key="content", - metadata_payload_key="metadata", + collection_name=collection_name, + embedding=embeddings, ) - - # 配置混合检索参数 + search_kwargs = { - "k": dense_k + sparse_k, # 总返回数量 - "score_threshold": 0.3, # 相似度阈值 + "k": dense_k + sparse_k, + "score_threshold": 0.3, } - - return base_retriever.as_retriever(search_kwargs=search_kwargs) + + return vector_store.as_retriever(search_kwargs=search_kwargs) -def create_ensemble_retriever( - retrievers: List[Any], - weights: Optional[List[float]] = None, - c: int = 60, -) -> EnsembleRetriever: - """ - 创建集成检索器,支持倒数排名融合 (RRF) - - Args: - retrievers: 检索器列表 - weights: 检索器权重 - c: RRF 常数(通常为60) - - Returns: - 集成检索器 - """ - if weights is None: - weights = [1.0 / len(retrievers)] * len(retrievers) - - ensemble = EnsembleRetriever( - retrievers=retrievers, - weights=weights, - c=c, - search_type="rrf", # 使用倒数排名融合 - ) - - return ensemble \ No newline at end of file +# def create_ensemble_retriever( +# retrievers: List[Any], +# weights: Optional[List[float]] = None, +# c: int = 60, +# ) -> EnsembleRetriever: +# """ +# 创建集成检索器,支持倒数排名融合 (RRF) +# +# Args: +# retrievers: 检索器列表 +# weights: 检索器权重 +# c: RRF 常数(通常为60) +# +# Returns: +# 集成检索器 +# """ +# if weights is None: +# weights = [1.0 / len(retrievers)] * len(retrievers) +# +# ensemble = EnsembleRetriever( +# retrievers=retrievers, +# weights=weights, +# c=c, +# search_type="rrf", +# ) +# +# return ensemble diff --git a/app/rag/tools.py b/app/rag/tools.py index 8dcf90f..a284a11 100644 --- a/app/rag/tools.py +++ b/app/rag/tools.py @@ -1,230 +1,89 @@ """ -RAG 工具包装 +RAG 工具模块 -将 RAG 流水线包装成 LangChain Tool,供 Agent 调用。 +将检索功能封装为 LangChain Tool,供 Agent 调用。 """ -from typing import Optional, Dict, Any -from langchain.tools import tool -from langchain_core.tools import Tool -from langchain_core.embeddings import Embeddings -from langchain_core.language_models import BaseLanguageModel - -from .pipeline import RAGPipeline, RAGConfig, RAGLevel +from langchain_core.tools import tool +from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY +from .pipeline import RAGPipeline, RAGLevel -class RAGTool: - """ - RAG 工具包装器 +@tool +async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str: + """在知识库中搜索与查询相关的文档片段。 - 将 RAG 流水线包装成 Agent 可调用的工具。 - """ - - def __init__( - self, - pipeline: RAGPipeline, - tool_name: str = "search_knowledge_base", - tool_description: str = None, - ): - """ - 初始化 RAG 工具 - - Args: - pipeline: RAG 流水线实例 - tool_name: 工具名称 - tool_description: 工具描述 - """ - self.pipeline = pipeline - self.tool_name = tool_name - - # 默认工具描述 - self.tool_description = tool_description or ( - "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、" - "专业知识或需要基于已知信息回答的问题时使用此工具。" - "输入应为要搜索的查询文本。" - ) - - # 创建 LangChain 工具 - self._tool = self._create_tool() - - def _create_tool(self) -> Tool: - """创建 LangChain 工具""" - - @tool(self.tool_name, args_schema=None) - def search_knowledge_base(query: str) -> str: - """ - 在知识库中搜索相关信息 - - Args: - query: 搜索查询 - - Returns: - 格式化后的搜索结果 - """ - try: - # 执行检索 - result = self.pipeline.retrieve(query) - - if not result.documents: - return "在知识库中未找到相关信息。" - - # 格式化上下文 - context = self.pipeline.format_context( - result.documents, - max_length=4000, # 限制上下文长度 - ) - - # 构建响应 - response = ( - f"🔍 在知识库中找到了 {len(result.documents)} 条相关信息:\n\n" - f"{context}\n\n" - f"⏱️ 检索耗时: {result.query_time:.2f}秒" - ) - - return response - - except Exception as e: - error_msg = f"检索过程中发生错误: {str(e)}" - if self.pipeline.config.verbose: - print(f"RAG 工具错误: {error_msg}") - return error_msg - - # 设置工具描述 - search_knowledge_base.description = self.tool_description - - return search_knowledge_base - - def get_tool(self) -> Tool: - """获取 LangChain 工具""" - return self._tool - - def __call__(self, query: str) -> str: - """直接调用工具""" - return self._tool.invoke({"query": query}) - - -def create_rag_tool( - embeddings: Embeddings, - llm: Optional[BaseLanguageModel] = None, - config: Optional[Dict[str, Any]] = None, - tool_name: str = "search_knowledge_base", - tool_description: Optional[str] = None, -) -> Tool: - """ - 创建 RAG 工具(便捷函数) + 适用于事实性问题、背景知识查询。 Args: - embeddings: 嵌入模型 - llm: 语言模型(用于高级 RAG 功能) - config: RAG 配置字典 - tool_name: 工具名称 - tool_description: 工具描述 - + query: 查询字符串 + rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion) + Returns: - LangChain Tool 实例 + 检索到的相关文档内容 """ + # 初始化嵌入模型 + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() + # 创建 RAG 流水线 - pipeline = RAGPipeline.create_from_config( + pipeline = RAGPipeline( embeddings=embeddings, - llm=llm, - config_dict=config, + config={ + "rag_level": rag_level, + "collection_name": "rag_documents", + "rerank_top_n": 5, + } ) - # 创建工具包装器 - rag_tool = RAGTool( - pipeline=pipeline, - tool_name=tool_name, - tool_description=tool_description, - ) + # 执行检索 + try: + documents = await pipeline.aretrieve(query) + if not documents: + return "未找到相关信息。" + + # 格式化结果 + context = pipeline.format_context(documents) + return context + except Exception as e: + return f"检索过程中发生错误: {str(e)}" + + +@tool +def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str: + """同步版本的知识库搜索工具。 - return rag_tool.get_tool() - - -# 导出便捷函数 -search_knowledge_base_tool = create_rag_tool - - -def bind_rag_to_agent( - agent_llm: BaseLanguageModel, - embeddings: Embeddings, - rag_llm: Optional[BaseLanguageModel] = None, - config: Optional[Dict[str, Any]] = None, - tool_name: str = "search_knowledge_base", -) -> BaseLanguageModel: - """ - 将 RAG 工具绑定到 Agent 模型 + 适用于事实性问题、背景知识查询。 Args: - agent_llm: Agent 使用的语言模型 - embeddings: 嵌入模型 - rag_llm: RAG 流水线使用的语言模型(如果与 agent_llm 不同) - config: RAG 配置 - tool_name: 工具名称 - - Returns: - 绑定工具后的模型 - """ - # 如果未指定 RAG LLM,使用 Agent LLM - if rag_llm is None: - rag_llm = agent_llm + query: 查询字符串 + rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion) - # 创建 RAG 工具 - rag_tool = create_rag_tool( + Returns: + 检索到的相关文档内容 + """ + # 初始化嵌入模型 + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() + + # 创建 RAG 流水线 + pipeline = RAGPipeline( embeddings=embeddings, - llm=rag_llm, - config=config, - tool_name=tool_name, + config={ + "rag_level": rag_level, + "collection_name": "rag_documents", + "rerank_top_n": 5, + } ) - # 绑定工具到模型 - return agent_llm.bind_tools([rag_tool]) - - -def create_agentic_rag_pipeline( - embeddings: Embeddings, - agent_llm: BaseLanguageModel, - rag_llm: Optional[BaseLanguageModel] = None, - config: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: - """ - 创建完整的 Agentic RAG 流水线(Level 4) - - Args: - embeddings: 嵌入模型 - agent_llm: Agent 模型 - rag_llm: RAG 专用模型 - config: 配置 + # 执行检索 + try: + documents = pipeline.retrieve(query) + if not documents: + return "未找到相关信息。" - Returns: - 包含模型和工具的字典 - """ - # 配置 Agentic RAG 级别 - if config is None: - config = {} - config["rag_level"] = RAGLevel.AGENTIC.value - - # 创建 RAG 工具 - rag_tool = create_rag_tool( - embeddings=embeddings, - llm=rag_llm or agent_llm, - config=config, - tool_name="search_knowledge_base", - tool_description=( - "在知识库中搜索相关信息。当用户询问需要查阅文档、规章制度、" - "专业知识或需要基于已知信息回答的问题时使用此工具。" - "Agent 应该先判断是否需要使用此工具,然后调用它获取上下文。" - ), - ) - - # 绑定工具到模型 - bound_llm = agent_llm.bind_tools([rag_tool]) - - return { - "llm": bound_llm, - "tool": rag_tool, - "pipeline": RAGPipeline.create_from_config( - embeddings=embeddings, - llm=rag_llm or agent_llm, - config_dict=config, - ), - } \ No newline at end of file + # 格式化结果 + context = pipeline.format_context(documents) + return context + except Exception as e: + return f"检索过程中发生错误: {str(e)}" diff --git a/rag_core/__init__.py b/rag_core/__init__.py new file mode 100644 index 0000000..c6aa7a6 --- /dev/null +++ b/rag_core/__init__.py @@ -0,0 +1,18 @@ +""" +RAG Core - 公共 RAG 组件包 + +提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。 +""" + +from .embedders import LlamaCppEmbedder +from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY +from .store import PostgresDocStore, create_docstore + +__all__ = [ + "LlamaCppEmbedder", + "QdrantVectorStore", + "QDRANT_URL", + "QDRANT_API_KEY", + "PostgresDocStore", + "create_docstore", +] diff --git a/rag_indexer/embedders.py b/rag_core/embedders.py similarity index 94% rename from rag_indexer/embedders.py rename to rag_core/embedders.py index 3f1be8a..e9a87a3 100644 --- a/rag_indexer/embedders.py +++ b/rag_core/embedders.py @@ -64,12 +64,9 @@ class LlamaCppEmbedder: response.raise_for_status() data = response.json() - # 处理不同响应格式 if isinstance(data, list): - # llama.cpp 直接返回列表 return [item["embedding"] for item in data] elif isinstance(data, dict) and "data" in data: - # OpenAI 标准格式 return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] else: raise ValueError(f"未知的嵌入 API 响应格式: {data}") @@ -85,4 +82,4 @@ class _LlamaCppLangchainAdapter(Embeddings): return self._embedder.embed_documents(texts) def embed_query(self, text: str) -> List[float]: - return self._embedder.embed_query(text) + return self._embedder.embed_query(text) \ No newline at end of file diff --git a/rag_indexer/store/__init__.py b/rag_core/store/__init__.py similarity index 92% rename from rag_indexer/store/__init__.py rename to rag_core/store/__init__.py index a1e561e..359db76 100644 --- a/rag_indexer/store/__init__.py +++ b/rag_core/store/__init__.py @@ -5,7 +5,7 @@ - PostgresDocStore: PostgreSQL 数据库存储(生产环境) 示例用法: - >>> from rag_indexer.store import create_docstore + >>> from rag_core.store import create_docstore >>> # 创建 PostgreSQL 存储 >>> store, conn = create_docstore( diff --git a/rag_indexer/store/factory.py b/rag_core/store/factory.py similarity index 98% rename from rag_indexer/store/factory.py rename to rag_core/store/factory.py index 2388f8f..c32c2c7 100644 --- a/rag_indexer/store/factory.py +++ b/rag_core/store/factory.py @@ -70,4 +70,4 @@ def create_docstore( return store, conn_str else: - raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres") \ No newline at end of file + raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres") diff --git a/rag_indexer/store/postgres.py b/rag_core/store/postgres.py similarity index 99% rename from rag_indexer/store/postgres.py rename to rag_core/store/postgres.py index 69ef4e3..5132355 100644 --- a/rag_indexer/store/postgres.py +++ b/rag_core/store/postgres.py @@ -246,4 +246,4 @@ class PostgresDocStore(BaseStore[str, Any]): 注意:在异步环境中,请使用 aclose 方法。 """ - pass \ No newline at end of file + pass diff --git a/rag_indexer/vector_store.py b/rag_core/vector_store.py similarity index 96% rename from rag_indexer/vector_store.py rename to rag_core/vector_store.py index 4f2e5c6..5faa66f 100644 --- a/rag_indexer/vector_store.py +++ b/rag_core/vector_store.py @@ -9,11 +9,8 @@ from typing import List, Optional, Dict, Any from langchain_core.documents import Document from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS from qdrant_client import QdrantClient -from qdrant_client.http import models from qdrant_client.http.models import Distance, VectorParams -from .embedders import LlamaCppEmbedder - logger = logging.getLogger(__name__) QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") @@ -31,17 +28,15 @@ class QdrantVectorStore: self.collection_name = collection_name self._client: Optional[QdrantClient] = None - # 嵌入模型 if embeddings is None: + from .embedders import LlamaCppEmbedder embedder = LlamaCppEmbedder() self.embeddings = embedder.as_langchain_embeddings() else: self.embeddings = embeddings - # 先创建集合 self.create_collection() - # LangChain 向量存储 self.vector_store = LangchainQdrantVS( client=self.get_client(), collection_name=self.collection_name, @@ -68,6 +63,7 @@ class QdrantVectorStore: def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): """创建集合,设置合适的向量维度。""" if vector_size is None: + from .embedders import LlamaCppEmbedder embedder = LlamaCppEmbedder() vector_size = embedder.get_embedding_dimension() diff --git a/rag_indexer/IndexBuilder.py b/rag_indexer/IndexBuilder.py new file mode 100644 index 0000000..6f077e9 --- /dev/null +++ b/rag_indexer/IndexBuilder.py @@ -0,0 +1,299 @@ +""" +离线 RAG 索引构建核心流水线。 + +使用 LangChain 的 ParentDocumentRetriever 实现父子块策略。 +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Union, Optional, Any, Dict, Tuple + +from httpx import RemoteProtocolError +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.stores import BaseStore +from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from langchain_classic.retrievers import ParentDocumentRetriever + +from .loaders import DocumentLoader +from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter +from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore + +logger = logging.getLogger(__name__) + + +# ---------- 配置数据类 ---------- +@dataclass +class DocstoreConfig: + """文档存储配置(用于父块存储)。""" + connection_string: Optional[str] = None + pool_config: Optional[Dict[str, Any]] = None + max_concurrency: Optional[int] = None + # 若要从外部注入已创建好的 docstore,可直接设置此字段 + instance: Optional[BaseStore] = None + + +@dataclass +class IndexBuilderConfig: + """索引构建器配置。""" + collection_name: str = "rag_documents" + splitter_type: SplitterType = SplitterType.PARENT_CHILD + + # 父块切分参数(仅当 splitter_type 为 PARENT_CHILD 时生效) + parent_chunk_size: int = 1000 + parent_chunk_overlap: int = 100 + # 子块切分参数 + child_chunk_size: int = 200 + child_chunk_overlap: int = 20 + child_splitter_type: SplitterType = SplitterType.SEMANTIC # 子块默认语义切分 + + # 检索参数 + search_k: int = 5 + + # 文档存储配置(仅父子块模式需要) + docstore: DocstoreConfig = field(default_factory=DocstoreConfig) + + # 其他切分器参数(当 splitter_type 非父子块时使用) + extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict) + + +# ---------- 索引构建器 ---------- +class IndexBuilder: + """RAG 索引构建主流水线,支持单块切分与父子块切分。""" + + def __init__(self, config: Optional[IndexBuilderConfig] = None, **kwargs): + """ + Args: + config: 索引构建器配置对象,优先级高于 kwargs + **kwargs: 可直接传入配置参数,会合并到 config 中(为方便使用保留) + """ + if config is None: + config = IndexBuilderConfig(**kwargs) + elif kwargs: + # 合并 kwargs 到 config 的字段(仅更新已有字段) + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + + self.config = config + self._docstore_conn: Optional[str] = None # 用于记录由 create_docstore 创建的连接信息 + + # 初始化基础组件 + self.loader = DocumentLoader() + self.embedder = LlamaCppEmbedder() + self.embeddings: Embeddings = self.embedder.as_langchain_embeddings() + + # 初始化向量存储 + self.vector_store = QdrantVectorStore( + collection_name=config.collection_name, + embeddings=self.embeddings, + ) + + # 根据切分类型初始化相关组件 + self._init_splitters_and_retriever() + + # ---------- 私有初始化方法 ---------- + def _init_splitters_and_retriever(self) -> None: + """根据配置初始化切分器和检索器。""" + if self.config.splitter_type == SplitterType.PARENT_CHILD: + self._init_parent_child_mode() + else: + self._init_single_splitter_mode() + + def _init_single_splitter_mode(self) -> None: + """单一切分模式(递归或语义)。""" + splitter_kwargs = self.config.extra_splitter_kwargs.copy() + if self.config.splitter_type == SplitterType.SEMANTIC: + splitter_kwargs["embeddings"] = self.embeddings + self.splitter = get_splitter(self.config.splitter_type, **splitter_kwargs) + self.retriever = None + self.docstore = None + logger.info("使用单一 %s 切分器", self.config.splitter_type.value) + + def _init_parent_child_mode(self) -> None: + """父子块切分模式,初始化父块/子块切分器、文档存储和检索器。""" + cfg = self.config + + # 父块切分器(始终使用递归切分) + self.parent_splitter = RecursiveCharacterTextSplitter( + chunk_size=cfg.parent_chunk_size, + chunk_overlap=cfg.parent_chunk_overlap, + ) + + # 子块切分器 + if cfg.child_splitter_type == SplitterType.SEMANTIC: + self.child_splitter = get_splitter( + SplitterType.SEMANTIC, + embeddings=self.embeddings, + **cfg.extra_splitter_kwargs + ) + logger.info("子块使用语义切分器") + else: + self.child_splitter = RecursiveCharacterTextSplitter( + chunk_size=cfg.child_chunk_size, + chunk_overlap=cfg.child_chunk_overlap, + ) + logger.info("子块使用递归切分器,块大小=%d,重叠=%d", + cfg.child_chunk_size, cfg.child_chunk_overlap) + + # 初始化文档存储(用于父块) + self.docstore = self._create_or_use_docstore() + + # 创建检索器 + self.retriever = ParentDocumentRetriever( + vectorstore=self.vector_store.get_langchain_vectorstore(), + docstore=self.docstore, + child_splitter=self.child_splitter, # type: ignore[arg-type] + parent_splitter=self.parent_splitter, + search_kwargs={"k": cfg.search_k}, + ) + logger.info("ParentDocumentRetriever 初始化完成,父块大小=%d", cfg.parent_chunk_size) + + def _create_or_use_docstore(self) -> BaseStore: + """创建或获取文档存储实例。""" + cfg = self.config.docstore + if cfg.instance is not None: + logger.debug("使用外部注入的文档存储") + return cfg.instance + + # 使用 create_docstore 创建 PostgreSQL 存储 + docstore, conn_info = create_docstore( + connection_string=cfg.connection_string, + pool_config=cfg.pool_config, + max_concurrency=cfg.max_concurrency, + ) + self._docstore_conn = conn_info + logger.info("文档存储已创建(PostgreSQL)") + return docstore + + # ---------- 公共构建方法 ---------- + async def build_from_file(self, file_path: Union[str, Path]) -> int: + """从单个文件构建索引。""" + logger.info("加载文件: %s", file_path) + documents = self.loader.load_file(file_path) + logger.info("已加载 %d 个文档", len(documents)) + return await self._process_documents(documents) + + async def build_from_directory( + self, directory_path: Union[str, Path], recursive: bool = True + ) -> int: + """从目录递归构建索引。""" + logger.info("加载目录: %s (递归=%s)", directory_path, recursive) + documents = self.loader.load_directory(directory_path, recursive=recursive) + logger.info("已从目录加载 %d 个文档", len(documents)) + return await self._process_documents(documents) + + async def _process_documents(self, documents: List[Document]) -> int: + """处理文档列表,分发给相应的索引逻辑。""" + if not documents: + logger.warning("没有文档需要处理") + return 0 + + if self.config.splitter_type == SplitterType.PARENT_CHILD: + return await self._index_with_parent_child(documents) + else: + return await self._index_with_single_splitter(documents) + + async def _index_with_single_splitter(self, documents: List[Document]) -> int: + """单一模式:切分后直接写入向量库。""" + chunks = self.splitter.split_documents(documents) # type: ignore[union-attr] + logger.info("已切分为 %d 个块", len(chunks)) + + self.vector_store.create_collection() + self.vector_store.add_documents(chunks) + return len(chunks) + + async def _index_with_parent_child(self, documents: List[Document]) -> int: + """父子模式:使用 ParentDocumentRetriever 批量添加。""" + self.vector_store.create_collection() + assert self.retriever is not None + + batch_size = 10 + total = len(documents) + processed = 0 + + for i in range(0, total, batch_size): + batch = documents[i:i + batch_size] + await self._add_batch_with_retry(batch, i // batch_size + 1) + processed += len(batch) + logger.info("批次 %d: 已处理 %d/%d", i // batch_size + 1, processed, total) + + logger.info("ParentDocumentRetriever 索引完成,共处理 %d 个文档", processed) + return processed + + async def _add_batch_with_retry(self, batch: List[Document], batch_no: int) -> None: + """添加批次,失败时自动重试(处理网络波动)。""" + max_retries = 3 + for attempt in range(max_retries): + try: + await self.retriever.aadd_documents(batch) # type: ignore[union-attr] + return + except (RemoteProtocolError, ConnectionError, OSError) as e: + if attempt == max_retries - 1: + raise + logger.warning("批次 %d 连接断开,重试 (%d/%d): %s", + batch_no, attempt + 1, max_retries, e) + self.vector_store.refresh_client() + await asyncio.sleep(1) + + # ---------- 信息获取方法 ---------- + def get_collection_info(self) -> Any: + """获取向量库集合信息。""" + return self.vector_store.get_collection_info() + + def get_child_splitter(self) -> TextSplitter: + """获取当前使用的子块切分器。""" + if self.config.splitter_type == SplitterType.PARENT_CHILD: + return self.child_splitter # type: ignore[return-value] + return self.splitter # type: ignore[return-value] + + def get_parent_splitter(self) -> RecursiveCharacterTextSplitter: + """获取父块切分器(仅父子模式可用)。""" + if self.config.splitter_type != SplitterType.PARENT_CHILD: + raise RuntimeError("父块切分器仅在父子块模式下可用") + return self.parent_splitter # type: ignore[return-value] + + def get_docstore(self) -> BaseStore: + """获取文档存储实例(仅父子模式可用)。""" + if self.config.splitter_type != SplitterType.PARENT_CHILD: + raise RuntimeError("文档存储仅在父子块模式下可用") + assert self.docstore is not None + return self.docstore + + # ---------- 资源管理 ---------- + def close(self) -> None: + """关闭资源(同步版本,供上下文管理器使用)。""" + if self.docstore is not None and hasattr(self.docstore, "aclose"): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + # 无运行中的事件循环,创建临时循环 + loop = asyncio.new_event_loop() + loop.run_until_complete(self.docstore.aclose()) # type: ignore[attr-defined] + loop.close() + else: + # 已有运行中的循环,创建任务(用户自行等待) + loop.create_task(self.docstore.aclose()) # type: ignore[attr-defined] + logger.info("IndexBuilder 资源已关闭") + + async def aclose(self) -> None: + """异步关闭资源。""" + if self.docstore is not None and hasattr(self.docstore, "aclose"): + await self.docstore.aclose() # type: ignore[attr-defined] + logger.info("IndexBuilder 资源已异步关闭") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + return False \ No newline at end of file diff --git a/rag_indexer/README.md b/rag_indexer/README.md index cb2df4c..c4259c8 100644 --- a/rag_indexer/README.md +++ b/rag_indexer/README.md @@ -2,35 +2,13 @@ 该模块负责 RAG 系统的阶段一:**离线索引构建**。它将外部的非结构化数据(如文档、PDF、网页等)清洗、切分并转化为向量,最终存入向量数据库中。 -## 📊 系统工作流示意图 - -```mermaid -graph TD - A[原始文档集合
PDF / Word / Markdown] --> B(文档加载器 DocumentLoader) - B --> C{文本切分策略 Splitter} - - C -->|基础策略| D1[固定字符长度切分
Recursive Split] - C -->|进阶策略| D2[语义边界切分
Semantic Chunking] - C -->|高级策略| D3[父子文档切分
Parent-Child / Auto-merging] - - D1 & D2 & D3 --> E[向量化 Embedder
llama.cpp: embeddinggemma] - - E --> F[(Qdrant 向量数据库)] - - subgraph "元数据管理" - G[提取作者、日期、页码等元数据 Metadata] -.附加.-> E - end -``` - ---- - ## 🎯 演进路线与核心算法 (Roadmap) ### Level 1: 基础暴力切分 (Basic Recursive Splitting) -- **核心算法**: 递归字符切分。它按照预定义的分隔符列表(如 `["\n\n", "\n", " ", ""]`)从大到小尝试切分文本,直到每块的大小满足最大长度限制。 +- **核心算法**: 递归字符切分。它按照预定义的分隔符列表(如 `["\n\n", "\n", "。", "!", "?", " ", ""]`)从大到小尝试切分文本,直到每块的大小满足最大长度限制。 - **优缺点**: 实现极简单,速度快。但非常容易将一句话拦腰截断,导致上下文语义丢失。 -- **实现指南**: - - 从 `langchain.text_splitter` 导入 `RecursiveCharacterTextSplitter`。 +- **实现指南**: + - 从 `langchain_text_splitters` 导入 `RecursiveCharacterTextSplitter`。 - 实例化时设置 `chunk_size`(如 500)和 `chunk_overlap`(如 50),直接调用 `.split_documents(raw_docs)` 方法。 ### Level 2: 语义动态切分 (Semantic Chunking) @@ -38,58 +16,52 @@ graph TD 1. 将文章按标点符号按句子拆分。 2. 使用轻量级 Embedding 模型将每一句向量化。 3. 计算相邻两句之间的余弦相似度 (Cosine Similarity)。 - 4. 当相似度低于设定阈值时(说明两句话讲的不是同一件事,语义发生了转折),在此处“切断”形成一个新的块。 + 4. 当相似度低于设定阈值时(说明两句话讲的不是同一件事,语义发生了转折),在此处"切断"形成一个新的块。 - **优缺点**: 极大程度保留了段落内语义的连贯性,对 LLM 回答非常友好。但由于在切分阶段就需要调用向量模型,耗时略长。 -- **实现指南**: +- **实现指南**: + - 从 `langchain_text_splitters` 导入 `TextSplitter` 作为基类。 - 从 `langchain_experimental.text_splitter` 导入 `SemanticChunker`。 - - 实例化时需要传入你已经配置好的 Embedding 模型实例(如基于 `OpenAIEmbeddings` 封装的 llama.cpp 本地模型),并设置 `breakpoint_threshold_type="percentile"` 等阈值参数。 + - 实现 `SemanticChunkerAdapter` 继承 `TextSplitter`,解决类型不兼容问题。 + - 实例化时需要传入你已经配置好的 Embedding 模型实例(如基于 `LlamaCppEmbedder` 封装的本地模型)。 ### Level 3: 高级父子块策略 (Parent-Child / Auto-merging) - **核心算法**: 层次化双重存储与映射。 - - **切分机制**: 首先将文档粗切为较大的“父块 (Parent Chunk, 约 1000 词)”,随后将父块细切为较小的“子块 (Child Chunk, 约 200 词)”。 - - **存储机制**: 仅仅将**子块**的向量存入 Qdrant 用于精准计算距离;将**父块**的原始内容存在内存或 Document Store (如 KV 数据库) 中,通过 UUID 相互映射。 + - **切分机制**: 首先将文档粗切为较大的"父块 (Parent Chunk, 约 1000 字符)",随后将父块细切为较小的"子块 (Child Chunk, 约 200 字符)"。 + - **存储机制**: 仅仅将**子块**的向量存入 Qdrant 用于精准计算距离;将**父块**的原始内容存在 PostgreSQL DocStore 中,通过 UUID 相互映射。 - **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。 -- **实现指南**: - - 使用 `langchain.retrievers` 中的 `ParentDocumentRetriever` 模块。 +- **实现指南**: + - 使用 `langchain_classic.retrievers` 中的 `ParentDocumentRetriever` 模块。 - 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore`。 - - **推荐方案**: 使用 `LocalFileStore` (默认) 或 `PostgresDocStore` 作为 docstore。 + - **推荐方案**: 使用 `PostgresDocStore` 作为 docstore,支持持久化存储。 - 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter` 和 `parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射。 ### Level 3.1: PostgreSQL DocStore 集成 -- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。使用同步连接池,避免异步复杂度。 +- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。使用异步连接池,支持高并发。 - **实现步骤**: - 1. **安装依赖**: `pip install psycopg2-binary` - 2. **配置连接**: 设置 `DB_URI` 环境变量或直接在代码中指定 PostgreSQL 连接字符串 - 3. **创建 docstore**: 使用 `PostgresDocStore` 类直接创建 - 4. **注入到 IndexBuilder**: 在创建 `IndexBuilder` 时通过 `docstore` 参数注入 + 1. **配置连接**: 设置 `DB_URI` 环境变量或通过 `docstore_conn_string` 参数指定 + 2. **创建 docstore**: 使用 `rag_indexer.store.create_docstore()` 工厂函数 + 3. **注入到 IndexBuilder**: 通过构造函数参数注入 - **使用示例**: ```python - from rag_indexer.docstore_manager import PostgresDocStore from rag_indexer.builder import IndexBuilder, SplitterType - # 创建 PostgreSQL docstore - docstore = PostgresDocStore( - connection_string="postgresql://user:pass@host:5432/db", - table_name="parent_documents" - ) - - # 创建 IndexBuilder 并注入 docstore + # 创建 IndexBuilder builder = IndexBuilder( collection_name="rag_documents", splitter_type=SplitterType.PARENT_CHILD, - docstore=docstore, parent_chunk_size=1000, child_chunk_size=200, + docstore_conn_string="postgresql://user:pass@host:5432/db", ) ``` ### Level 3.2: 语义切分与父子块策略结合 - **核心优势**: 结合语义切分的连贯性和父子块策略的层次化存储优势,实现更精准的检索和更丰富的上下文。 - **实现原理**: - - **父块切分**: 使用递归字符切分创建大块(约1000词),提供完整的上下文背景 - - **子块切分**: 使用语义动态切分创建小块(约200词),根据语义连贯性动态切分,提高检索精度 - - **存储机制**: 子块向量存入Qdrant用于精准检索,父块内容存入PostgreSQL提供完整上下文 + - **父块切分**: 使用 `RecursiveCharacterTextSplitter` 创建大块(约1000字符),提供完整的上下文背景 + - **子块切分**: 使用 `SemanticChunkerAdapter` 创建小块,根据语义连贯性动态切分,提高检索精度 + - **存储机制**: 子块向量存入 Qdrant 用于精准检索,父块内容存入 PostgreSQL 提供完整上下文 - **使用示例**: ```python from rag_indexer.builder import IndexBuilder, SplitterType @@ -109,97 +81,55 @@ graph TD ``` - **配置参数**: - `child_splitter_type`: 子块切分器类型,可选 `SplitterType.RECURSIVE`(默认)或 `SplitterType.SEMANTIC` - - 当使用语义切分时,系统会自动使用已配置的Embedding模型进行句子级相似度计算 + - 当使用语义切分时,系统会自动使用已配置的 Embedding 模型进行句子级相似度计算 -### Level 4: RAG-Fusion (多路改写与倒数排名融合) -- **核心优势**: 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用数学统计算法合并结果,提高检索的全面性和准确性。 -- **实现原理**: - 1. **多路查询改写**: 利用LLM将原始查询改写成3-5个不同表述的查询,从不同角度表达相同意图 - 2. **倒数排名融合 (RRF)**: 对每个改写查询的结果进行RRF融合,公式为 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$,避免单一检索结果主导 - 3. **结果去重**: 对融合后的结果进行去重,确保返回的文档唯一 -- **使用示例**: - ```python - from rag_indexer.builder import IndexBuilder, SplitterType - from langchain_openai import OpenAI - - # 创建 IndexBuilder - builder = IndexBuilder( - collection_name="rag_documents", - splitter_type=SplitterType.PARENT_CHILD, - parent_chunk_size=1000, - child_chunk_size=200, - docstore_conn_string="postgresql://user:pass@host:5432/db", - ) - - # 创建语言模型用于查询改写 - llm = OpenAI( - openai_api_base="http://localhost:8000/v1", - openai_api_key="no-key-needed", - model_name="Qwen2.5-7B-Instruct", - temperature=0.3, - ) - - # 使用 RAG-Fusion 检索 - query = "如何申请项目资金?" - results = builder.retrieve_with_fusion( - query=query, - llm=llm, - num_queries=3, - k=5, - return_parent=True - ) - ``` -- **配置参数**: - - `llm`: 语言模型实例,用于查询改写 - - `num_queries`: 生成的查询数量,建议3-5个 - - `k`: 返回的文档数量 - - `return_parent`: 是否返回父块上下文 - -### Level 5: GraphRAG 与 多模态 (Graph & Multi-modal) +### Level 4: GraphRAG(基于图和关系的 RAG) - **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。 -- **核心思路**: 解决传统纯向量检索难以处理“跨文档复杂关系推理”的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。 -- **实现指南**: - - 使用本地的大模型(如 `Gemma-4-E2B`)配合 `langchain_community.graphs` 模块。 - - 利用 `LLMGraphTransformer` 组件,在读取文档时,通过预设的 Prompt 强制大模型提取出实体(Node)和关系(Edge),直接写入诸如 Neo4j 这样的图数据库中,而非传统的 Qdrant 向量库。 +- **核心思路**: 解决传统纯向量检索难以处理"跨文档复杂关系推理"的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。 +- **实现原理**: + 1. **实体提取**: 利用 LLM 从文档中提取实体(如人物、组织、地点、事件等) + 2. **关系抽取**: 识别实体之间的关系(如"CEO of"、"founded by"、"located in"等) + 3. **图构建**: 将实体作为节点,关系作为边,构建知识图谱 + 4. **混合检索**: 结合向量检索和图查询,同时利用语义相似性和结构关系 +- **技术栈**: + - **图数据库**: Neo4j 或 RedisGraph + - **LLM 工具**: `LLMGraphTransformer` 或自定义 Prompt + - **集成方式**: 与向量存储并行,形成混合检索系统 +- **实现指南**: + - 使用 `langchain_community.graphs` 模块 + - 配置本地大模型(如 `Gemma-4-E2B`)用于实体关系抽取 + - 构建包含实体和关系的图结构,存储到图数据库 + - 实现混合检索逻辑,结合向量相似度和图路径分析 ---- - -## 所需依赖与安装 - -为了支持完整的文档解析和 Qdrant 写入,需要安装以下 Python 包: - -```bash -# 基础核心库 -pip install langchain langchain-core langchain-openai langchain-qdrant - -# 用于复杂文档解析 (PDF, Word, Excel 等) -pip install unstructured pdf2image pdfminer.six - -# 用于语义分块 (可选) -pip install langchain-experimental - -# 用于 PostgreSQL 存储 (可选,用于 Parent-Child 策略) -pip install psycopg2-binary - -# 用于 RAG-Fusion (可选,需要语言模型) -pip install langchain-openai -``` +### Level 5: 多模态 RAG (Multi-modal RAG) +- **核心算法**: 跨模态嵌入和多模态融合。 +- **核心思路**: 突破纯文本限制,支持图像、表格、音频等多种数据类型的理解和检索。 +- **实现原理**: + 1. **多模态嵌入**: 使用 CLIP 等模型将不同模态数据映射到统一向量空间 + 2. **多模态索引**: 为不同类型的内容创建专用索引 + 3. **跨模态检索**: 支持以文搜图、以图搜文等跨模态查询 +- **技术栈**: + - **多模态模型**: CLIP、BLIP 等 + - **存储**: 向量数据库 + 对象存储 + - **检索**: 混合向量检索 --- ## 📂 架构与文件结构设计 -在 `rag_indexer/` 目录下,需创建以下核心文件: - -```text +``` rag_indexer/ ├── __init__.py ├── loaders.py # 负责调用 unstructured 解析不同类型文件 -├── splitters.py # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑 +├── splitters.py # 负责实现 Recursive、Semantic 切分逻辑及适配器 ├── embedders.py # 封装本地 llama.cpp 交互的 Embedding 接口 ├── vector_store.py # 封装 Qdrant 写入、Upsert、Collection 初始化操作 -├── docstore_manager.py # 文档存储管理器,支持 LocalFileStore 和 PostgreSQL -└── builder.py # 核心编排文件,将上述模块串联成 Pipeline +├── builder.py # 核心编排文件,将上述模块串联成 Pipeline +├── cli.py # 命令行入口 +└── store/ + ├── __init__.py + ├── factory.py # docstore 工厂函数 + └── postgres.py # PostgreSQL DocStore 实现 ``` --- @@ -211,36 +141,36 @@ rag_indexer/ ``` ┌─────────────────────────────────────────┐ │ builder.py │ - │ IndexBuilder 入口 │ + │ IndexBuilder 入口 │ └─────────────────┬───────────────────────┘ │ ┌─────────────────▼───────────────────────┐ - │ loaders.py │ - │ DocumentLoader.load_file() │ - │ → 返回 List[Document] │ + │ loaders.py │ + │ DocumentLoader.load_file() │ + │ → 返回 List[Document] │ └─────────────────┬───────────────────────┘ │ ┌─────────────────▼───────────────────────┐ - │ ParentDocumentRetriever.add_documents()│ - │ ┌─────────────────────────────────┐ │ - │ │ parent_splitter (粗切) │ │ - │ │ 父块 ~1000 词 │ │ - │ └────────────┬────────────────────┘ │ - │ │ │ - │ ┌────────────▼────────────────────┐ │ - │ │ child_splitter (细切) │ │ - │ │ 子块 ~200 词 │ │ - │ └────────────┬────────────────────┘ │ - │ │ │ - │ ┌──────────┴──────────┐ │ - │ ▼ ▼ │ - │ 子块向量 父块原始内容 │ - │ │ │ │ - │ ▼ ▼ │ - │ ┌────────────┐ ┌─────────────────┐ │ - │ │vector_store│ │ docstore_manager│ │ - │ │ (Qdrant) │ │ (PostgreSQL) │ │ - │ └────────────┘ └─────────────────┘ │ + │ ParentDocumentRetriever │ + │ ┌─────────────────────────────────┐ │ + │ │ parent_splitter (粗切) │ │ + │ │ 父块 ~1000 字符 │ │ + │ └────────────┬────────────────────┘ │ + │ │ │ + │ ┌────────────▼────────────────────┐ │ + │ │ child_splitter (细切) │ │ + │ │ 子块 ~200 字符 │ │ + │ └────────────┬────────────────────┘ │ + │ │ │ + │ ┌──────────┴──────────┐ │ + │ ▼ ▼ │ + │ 子块向量 父块原始内容 │ + │ │ │ │ + │ ▼ ▼ │ + │ ┌────────────┐ ┌─────────────────┐ │ + │ │vector_store│ │ store/ │ │ + │ │ (Qdrant) │ │ (PostgreSQL) │ │ + │ └────────────┘ └─────────────────┘ │ └─────────────────────────────────────────┘ ``` @@ -250,10 +180,31 @@ rag_indexer/ |------|------|------------| | **builder.py** | 核心编排,负责串联整个流程 | `IndexBuilder` | | **loaders.py** | 解析各种文档格式(PDF、Word、TXT等) | `DocumentLoader` | -| **splitters.py** | 文本切分策略(Recursive/Semantic/Parent-Child) | `SplitterType`, `get_splitter()` | +| **splitters.py** | 文本切分策略(Recursive/Semantic)及适配器 | `SplitterType`, `get_splitter()`, `SemanticChunkerAdapter` | | **embedders.py** | 向量化(封装 llama.cpp embedding 接口) | `LlamaCppEmbedder` | | **vector_store.py** | Qdrant 向量数据库操作 | `QdrantVectorStore` | -| **docstore_manager.py** | 父文档存储(PostgreSQL/本地文件) | `PostgresDocStore`, `get_docstore()` | +| **store/postgres.py** | PostgreSQL DocStore 实现 | `PostgresDocStore` | +| **store/factory.py** | docstore 工厂函数 | `create_docstore()` | + +### 核心实现细节 + +#### 1. 文本切分 +- **递归切分**: 使用 `langchain_text_splitters.RecursiveCharacterTextSplitter`,支持中文分隔符 +- **语义切分**: 使用 `langchain_experimental.text_splitter.SemanticChunker`,通过 `SemanticChunkerAdapter` 适配 `TextSplitter` 接口 +- **父子块策略**: 父块使用递归切分(1000字符),子块可选择递归或语义切分(200字符) + +#### 2. 向量化 +- **Embedding API**: 使用 `LlamaCppEmbedder` 封装本地 llama.cpp 服务,支持 `embed_documents` 和 `embed_query` 方法 +- **向量维度**: 自动检测模型维度(默认 2560),创建对应大小的 Qdrant 集合 + +#### 3. 向量存储 +- **Qdrant 集成**: 使用 `langchain_qdrant.QdrantVectorStore` 作为底层存储 +- **集合管理**: 自动创建/复用集合,支持 `force_recreate` 参数 +- **批量写入**: 支持 `batch_size` 参数,避免单次请求过大 + +#### 4. 文档存储 +- **PostgreSQL**: 使用 `PostgresDocStore` 持久化存储父块,支持异步连接池 +- **数据映射**: 通过 UUID 将子块与父块关联,检索时返回完整父块 ### 调用顺序 @@ -265,27 +216,42 @@ from rag_indexer.builder import IndexBuilder, SplitterType builder = IndexBuilder( collection_name="my_docs", splitter_type=SplitterType.PARENT_CHILD, - qdrant_url="http://localhost:6333", parent_chunk_size=1000, child_chunk_size=200, + docstore_conn_string="postgresql://user:pass@host:5432/db", ) ``` #### 2. 构建索引 ```python +import asyncio + # 方式A:从单个文件构建 -builder.build_from_file("/path/to/document.pdf") +async def main(): + count = await builder.build_from_file("/path/to/document.pdf") + print(f"已索引 {count} 个块") # 方式B:从目录批量构建 -builder.build_from_directory("/path/to/docs/") +async def main(): + count = await builder.build_from_directory("/path/to/docs/") + print(f"已索引 {count} 个块") + +asyncio.run(main()) ``` #### 3. 检索(获取完整父块上下文) ```python -# 检索时返回完整父块 -results = builder.search_with_parent_context("查询内容") +import asyncio + +async def main(): + # 检索时返回完整父块 + results = await builder.search_with_parent_context("查询内容", k=5) + for doc in results: + print(doc.page_content) + +asyncio.run(main()) ``` ### 检索流程 @@ -299,11 +265,16 @@ results = builder.search_with_parent_context("查询内容") --- ### 串联与触发方式 -在你的 LangGraph 系统外,创建一个执行脚本 `scripts/run_indexer.py`: +使用 `cli.py` 入口脚本: ```bash -# 终端执行,将本地的 PDF 手册刷入向量数据库 +# 设置环境变量 export QDRANT_URL="http://115.190.121.151:6333" -python scripts/run_indexer.py --file data/user_docs/tech_manual.pdf +export QDRANT_API_KEY="your-api-key" +export DB_URI="postgresql://postgres:password@host:5432/langgraph_db?sslmode=disable" + +# 执行索引构建 +python -m rag_indexer.cli --path data/user_docs/tech_manual.pdf ``` -这相当于系统后台的**“离线学习阶段”**,你可以随时挂载定时任务去扫描文件夹,增量更新知识库。 + +这相当于系统后台的**"离线学习阶段"**,你可以随时挂载定时任务去扫描文件夹,增量更新知识库。 diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py index 78daf84..7d178ac 100644 --- a/rag_indexer/__init__.py +++ b/rag_indexer/__init__.py @@ -9,52 +9,52 @@ Offline RAG Indexer module. - 父文档存储(PostgreSQL) 示例用法: - >>> from rag_indexer import IndexBuilder, SplitterType + >>> from rag_indexer import IndexBuilder, IndexBuilderConfig, SplitterType >>> - >>> builder = IndexBuilder( + >>> config = IndexBuilderConfig( ... collection_name="my_docs", ... splitter_type=SplitterType.PARENT_CHILD, - ... qdrant_url="http://localhost:6333" ... ) + >>> builder = IndexBuilder(config) >>> - >>> builder.build_from_file("document.pdf") + >>> # 或直接传参(向后兼容) + >>> builder = IndexBuilder(collection_name="my_docs") + >>> + >>> await builder.build_from_file("document.pdf") """ +from .IndexBuilder import IndexBuilder, IndexBuilderConfig, DocstoreConfig from .loaders import DocumentLoader -from .splitters import ( - SplitterType, - get_splitter, - ParentChildSplitter, -) -from .embedders import LlamaCppEmbedder -from .vector_store import QdrantVectorStore -from .builder import IndexBuilder +from .splitters import SplitterType, get_splitter -# 导出存储相关类(从新的 store 包) -from .store import ( +# 从 rag_core 重新导出常用组件 +from rag_core import ( + LlamaCppEmbedder, + QdrantVectorStore, PostgresDocStore, create_docstore, ) - - __version__ = "2.0.0" __all__ = [ - # 核心类 - "DocumentLoader", + # 核心构建器与配置 "IndexBuilder", + "IndexBuilderConfig", + "DocstoreConfig", + + # 加载器 + "DocumentLoader", # 切分相关 "SplitterType", "get_splitter", - "ParentChildSplitter", - # 嵌入和向量存储 + # 嵌入与向量存储 "LlamaCppEmbedder", "QdrantVectorStore", - # 存储(新的 store 包) + # 文档存储 "PostgresDocStore", "create_docstore", -] +] \ No newline at end of file diff --git a/rag_indexer/builder.py b/rag_indexer/builder.py deleted file mode 100644 index 2a1c51e..0000000 --- a/rag_indexer/builder.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -离线 RAG 索引构建核心流水线。 - -支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。 -""" - -import asyncio -import logging -from pathlib import Path -from typing import List, Union, Optional, Tuple, Any -from dataclasses import dataclass - -from httpx import RemoteProtocolError -from langchain_core.documents import Document -from langchain_classic.retrievers import ParentDocumentRetriever -from langchain_core.stores import BaseStore -from langchain_text_splitters import RecursiveCharacterTextSplitter -from langchain_experimental.text_splitter import SemanticChunker - -from .loaders import DocumentLoader -from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter -from .embedders import LlamaCppEmbedder -from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY -from .store import create_docstore - -logger = logging.getLogger(__name__) - - -@dataclass -class ParentChildConfig: - """父子块切分配置。""" - parent_chunk_size: int = 1000 - child_chunk_size: int = 200 - parent_chunk_overlap: int = 100 - child_chunk_overlap: int = 20 - search_k: int = 5 - docstore_path: Optional[str] = None - docstore_type: str = "local" - docstore_conn_string: Optional[str] = None - - -class IndexBuilder: - """RAG 索引构建主流水线。""" - - # 类型注解 - parent_splitter: "RecursiveCharacterTextSplitter" - child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"] - docstore: Optional["BaseStore"] - _docstore_conn: Optional[str] - retriever: Optional["ParentDocumentRetriever"] - vector_store_obj: Any - - def __init__( - self, - collection_name: str = "rag_documents", - splitter_type: SplitterType = SplitterType.PARENT_CHILD, - docstore=None, - **splitter_kwargs, - ): - self.collection_name = collection_name - self.splitter_type = splitter_type - self.splitter_kwargs = splitter_kwargs - self.docstore = docstore # 从外部注入 - - # 组件 - self.loader = DocumentLoader() - self.embedder = LlamaCppEmbedder() - self.embeddings = self.embedder.as_langchain_embeddings() - - self.vector_store = QdrantVectorStore( - collection_name=collection_name, - embeddings=self.embeddings, - ) - - # 切分器(父子块单独处理) - if splitter_type != SplitterType.PARENT_CHILD: - if splitter_type == SplitterType.SEMANTIC: - splitter_kwargs["embeddings"] = self.embeddings - self.splitter = get_splitter(splitter_type, **splitter_kwargs) - else: - self.splitter = None - # 为父子块切分初始化 ParentDocumentRetriever - self._init_parent_child_retriever() - - def _init_parent_child_retriever(self, **kwargs): - """ - 初始化 ParentDocumentRetriever 用于父子块切分。 - - 支持动态语义切分与父子块策略结合: - - 父块使用递归切分(大块,提供上下文) - - 子块可以使用递归切分或语义切分(根据语义动态切分,提高检索精度) - - 替代自定义的 ParentChildSplitter 逻辑。 - """ - # 解析父子块配置参数 - parent_size = kwargs.get("parent_chunk_size", 1000) - child_size = kwargs.get("child_chunk_size", 200) - parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100)) - child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20)) - - # 子块切分器类型,默认为语义切分 - child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC) - - # 定义父块切分器(始终使用递归切分) - self.parent_splitter = RecursiveCharacterTextSplitter( - chunk_size=parent_size, - chunk_overlap=parent_overlap, - ) - - # 定义子块切分器(根据类型选择) - if child_splitter_type == SplitterType.SEMANTIC: - self.child_splitter = get_splitter( - SplitterType.SEMANTIC, - embeddings=self.embeddings, - ) - logger.info(f"子块使用语义切分器") - else: - # 默认使用递归切分 - self.child_splitter = RecursiveCharacterTextSplitter( - chunk_size=child_size, - chunk_overlap=child_overlap, - ) - logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}") - - # 向量存储(用于子块) - self.vector_store_obj = self.vector_store.get_langchain_vectorstore() - - # 文档存储(用于父块) - if self.docstore is None: - # 如果没有外部注入 docstore,则使用 PostgreSQL 创建 - docstore_conn = kwargs.get("docstore_conn_string") - pool_config = kwargs.get("pool_config") - max_concurrency = kwargs.get("max_concurrency") - - # 使用 create_docstore 创建 PostgreSQL 存储 - self.docstore, self._docstore_conn = create_docstore( - connection_string=docstore_conn, - pool_config=pool_config, - max_concurrency=max_concurrency - ) - else: - # 使用外部注入的 docstore - self._docstore_conn = None - - # 创建检索器 - self.retriever = ParentDocumentRetriever( - vectorstore=self.vector_store_obj, - docstore=self.docstore, - child_splitter=self.child_splitter, # type: ignore - parent_splitter=self.parent_splitter, - search_kwargs={"k": kwargs.get("search_k", 5)}, - ) - logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}") - - async def build_from_file(self, file_path: Union[str, Path]) -> int: - logger.info("加载文件: %s", file_path) - documents = self.loader.load_file(file_path) - logger.info("已加载 %d 个文档", len(documents)) - return await self._process_documents(documents) - - async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int: - logger.info("加载目录: %s (递归=%s)", directory_path, recursive) - documents = self.loader.load_directory(directory_path, recursive=recursive) - logger.info("已从目录加载 %d 个文档", len(documents)) - return await self._process_documents(documents) - - async def _process_documents(self, documents: List[Document]) -> int: - if not documents: - logger.warning("没有文档需要处理") - return 0 - - if self.splitter_type == SplitterType.PARENT_CHILD: - logger.info("使用 LangChain ParentDocumentRetriever") - - # 确保集合存在(用于子块) - self.vector_store.create_collection() - - # 分批处理,避免单次请求过大 - assert self.retriever is not None, "retriever 未初始化" - batch_size = 10 # 每次处理10个文档 - total = len(documents) - processed = 0 - - for i in range(0, total, batch_size): - batch = documents[i:i + batch_size] - max_retries = 3 - for attempt in range(max_retries): - try: - await self.retriever.aadd_documents(batch) - processed += len(batch) - logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}") - break - except (RemoteProtocolError, ConnectionError, OSError) as e: - if attempt == max_retries - 1: - raise - logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}") - self.vector_store.refresh_client() - await asyncio.sleep(1) - - logger.info( - "已使用 ParentDocumentRetriever 索引: " - f"共 {processed} 个父块" - ) - return processed - - else: - logger.info("使用 %s 切分文档", self.splitter_type) - # 当 splitter_type 不是 PARENT_CHILD 时,splitter 一定不为 None - assert self.splitter is not None, "splitter 未初始化" - chunks = self.splitter.split_documents(documents) - logger.info("已切分为 %d 个块", len(chunks)) - - self.vector_store.create_collection() - self.vector_store.add_documents(chunks) - return len(chunks) - - def get_collection_info(self): - return self.vector_store.get_collection_info() - - def search(self, query: str, k: int = 5) -> List[Document]: - """标准搜索 - 返回子块。""" - return self.vector_store.similarity_search(query, k=k) - - async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]: - """ - 带父块上下文的搜索 - 返回完整父块。 - - 这是使用父子块切分时的主要检索方法。 - """ - if self.splitter_type != SplitterType.PARENT_CHILD: - raise RuntimeError( - "search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。" - "请使用 search() 进行标准检索。" - ) - assert self.retriever is not None, "retriever 未初始化" - return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore - - async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]: - """ - 统一检索接口。 - - Args: - query: 搜索查询 - return_parent: 如果为 True 且使用父子块切分,返回父块 - 如果为 False,始终返回子块 - - Returns: - 相关文档列表 - """ - if self.splitter_type == SplitterType.PARENT_CHILD and return_parent: - return await self.search_with_parent_context(query) - else: - return self.search(query) - - async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]: - """ - 使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。 - - 核心原理: - 1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述 - 2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导 - - Args: - query: 原始搜索查询 - llm: 语言模型实例(用于查询改写) - num_queries: 生成的查询数量 - k: 返回的文档数量 - return_parent: 如果为 True 且使用父子块切分,返回父块 - 如果为 False,始终返回子块 - - Returns: - 经过融合后的相关文档列表 - """ - from langchain.retrievers.multi_query import MultiQueryRetriever - from langchain.retrievers import EnsembleRetriever - - if self.splitter_type == SplitterType.PARENT_CHILD and return_parent: - # 使用 ParentDocumentRetriever 作为基础检索器 - assert self.retriever is not None, "retriever 未初始化" - base_retriever = self.retriever - else: - # 使用向量存储作为基础检索器 - base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2}) - - # 创建多路查询检索器 - multi_query_retriever = MultiQueryRetriever.from_llm( - retriever=base_retriever, - llm=llm, - include_original=True - ) - - # 设置自定义提示词以生成指定数量的查询 - from langchain_core.prompts import PromptTemplate - multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template( - "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" - "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" - "原始问题: {question}\n\n" - "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" - "确保每个版本都是独立、完整的查询语句。\n\n" - "生成 {num_queries} 个查询:" - ) - - # 修改调用参数以包含 num_queries - original_ainvoke = multi_query_retriever.llm_chain.ainvoke - async def new_ainvoke(input_dict): - input_dict["num_queries"] = num_queries - return await original_ainvoke(input_dict) - multi_query_retriever.llm_chain.ainvoke = new_ainvoke - - # 执行检索 - documents = await multi_query_retriever.ainvoke(query) - - # 去重并限制数量 - seen_content = set() - unique_documents = [] - for doc in documents: - content = doc.page_content - if content not in seen_content: - seen_content.add(content) - unique_documents.append(doc) - if len(unique_documents) >= k: - break - - logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果") - return unique_documents - - def get_retriever(self) -> ParentDocumentRetriever: - """ - 直接获取 ParentDocumentRetriever 实例。 - - 适用于需要在 IndexBuilder 外部访问检索器的高级用例。 - """ - if self.splitter_type != SplitterType.PARENT_CHILD: - raise RuntimeError( - "get_retriever() 仅在 PARENT_CHILD 切分器下可用。" - "请使用 search() 或 search_with_parent_context() 进行标准检索。" - ) - assert self.retriever is not None, "retriever 未初始化" - return self.retriever - - def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]: - """获取子块切分器以便重新配置。""" - if self.splitter_type != SplitterType.PARENT_CHILD: - return self.splitter # type: ignore - return self.child_splitter - - def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter": - """获取父块切分器以便重新配置。""" - if self.splitter_type != SplitterType.PARENT_CHILD: - raise RuntimeError( - "父块切分器仅在 PARENT_CHILD 切分器下可用。" - ) - return self.parent_splitter - - def get_docstore(self) -> BaseStore: - """获取父块的文档存储。""" - if self.splitter_type != SplitterType.PARENT_CHILD: - raise RuntimeError( - "文档存储仅在 PARENT_CHILD 切分器下可用。" - ) - assert self.docstore is not None, "docstore 未初始化" - return self.docstore - - def get_docstore_path(self) -> Optional[str]: - """获取文档存储路径(已弃用,仅用于兼容性)。""" - if self.splitter_type != SplitterType.PARENT_CHILD: - raise RuntimeError( - "文档存储路径仅在 PARENT_CHILD 切分器下可用。" - ) - # PostgreSQL 存储没有 persist_path,返回 None - return None - - def close(self): - """关闭资源。""" - if self.docstore is not None and hasattr(self.docstore, "aclose"): - import asyncio - asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore - logger.info("PostgreSQL 异步连接池已关闭") - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - return False - - -# 需要导入 RecursiveCharacterTextSplitter -from langchain_text_splitters import RecursiveCharacterTextSplitter - - -# 示例用法已移除,请参考文档 diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py index 63014b8..6942506 100755 --- a/rag_indexer/cli.py +++ b/rag_indexer/cli.py @@ -1,85 +1,77 @@ """ -Command-line interface for the RAG index builder. +简易命令行入口,使用默认配置构建 RAG 索引。 """ -import argparse import asyncio import logging import sys +from pathlib import Path -from rag_indexer.builder import IndexBuilder +from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig from rag_indexer.splitters import SplitterType logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) +logger = logging.getLogger(__name__) -# 基础配置 +# 默认配置(所有连接参数从环境变量读取) COLLECTION_NAME = "rag_documents" -DB_URI = "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable" +SPLITTER_TYPE = SplitterType.PARENT_CHILD +CHILD_SPLITTER_TYPE = SplitterType.SEMANTIC -# 基础切分参数 -CHUNK_SIZE = 500 -CHUNK_OVERLAP = 50 - -# 父子块切分参数 +# 父子块大小参数(可根据需要调整) PARENT_CHUNK_SIZE = 1000 -CHILD_CHUNK_SIZE = 200 PARENT_CHUNK_OVERLAP = 100 +CHILD_CHUNK_SIZE = 200 CHILD_CHUNK_OVERLAP = 20 +SEARCH_K = 5 -# 切分策略:basic(基础)、semantic(语义)、parent-child(父子块) -STRATEGY = "parent-child" -# 存储类型:postgres(PostgreSQL)、local(本地文件) -STORAGE_TYPE = "postgres" +def get_input_path() -> Path: + """从命令行参数获取输入路径,若未提供则使用默认示例路径。""" + if len(sys.argv) > 1: + return Path(sys.argv[1]) + # 默认测试路径(可按需修改) + return Path("data/user_docs/a.txt") async def main(): - # 使用固定策略 - splitter_type = SplitterType.PARENT_CHILD - child_splitter_type = SplitterType.SEMANTIC + input_path = get_input_path() + if not input_path.exists(): + logger.error("路径不存在: %s", input_path) + sys.exit(1) - splitter_kwargs = {} - - if splitter_type == SplitterType.RECURSIVE: - splitter_kwargs["chunk_size"] = CHUNK_SIZE - splitter_kwargs["chunk_overlap"] = CHUNK_OVERLAP - elif splitter_type == SplitterType.PARENT_CHILD: - splitter_kwargs["parent_chunk_size"] = PARENT_CHUNK_SIZE - splitter_kwargs["child_chunk_size"] = CHILD_CHUNK_SIZE - splitter_kwargs["parent_chunk_overlap"] = PARENT_CHUNK_OVERLAP - splitter_kwargs["child_chunk_overlap"] = CHILD_CHUNK_OVERLAP - splitter_kwargs["child_splitter_type"] = child_splitter_type - if STORAGE_TYPE == "postgres": - splitter_kwargs["docstore_conn_string"] = DB_URI - elif STORAGE_TYPE == "local": - splitter_kwargs["docstore_path"] = "./parent_docs" - else: - splitter_kwargs["docstore_conn_string"] = DB_URI - - builder = IndexBuilder( + # 构建配置(使用全部默认值) + config = IndexBuilderConfig( collection_name=COLLECTION_NAME, - splitter_type=splitter_type, - **splitter_kwargs + splitter_type=SPLITTER_TYPE, + parent_chunk_size=PARENT_CHUNK_SIZE, + parent_chunk_overlap=PARENT_CHUNK_OVERLAP, + child_chunk_size=CHILD_CHUNK_SIZE, + child_chunk_overlap=CHILD_CHUNK_OVERLAP, + child_splitter_type=CHILD_SPLITTER_TYPE, + search_k=SEARCH_K, + # docstore 默认使用 create_docstore 从环境变量读取 PostgreSQL 连接 ) - is_file=False - path="data/corpus/" + builder = IndexBuilder(config) + is_directory = input_path.is_dir() try: - if is_file: - chunk_count = await builder.build_from_file(path) - else: - chunk_count = await builder.build_from_directory(path, recursive=True) + async with builder: + if is_directory: + chunk_count = await builder.build_from_directory(input_path, recursive=True) + else: + chunk_count = await builder.build_from_file(input_path) - print(f"索引构建完成。共索引 {chunk_count} 个块") + print(f"\n索引构建完成。共索引 {chunk_count} 个块") info = builder.get_collection_info() print(f"集合 '{info['name']}' 包含 {info['vectors_count']} 个向量(维度:{info['vector_size']})") except Exception as e: - logging.exception(f"索引构建失败:{e}") + logger.exception("索引构建失败: %s", e) sys.exit(1) diff --git a/rag_indexer/loaders.py b/rag_indexer/loaders.py index d0c16c4..c5c6e33 100644 --- a/rag_indexer/loaders.py +++ b/rag_indexer/loaders.py @@ -3,19 +3,27 @@ """ import logging +import os from pathlib import Path -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Optional, Union from langchain_core.documents import Document +from unstructured.documents.elements import Element from unstructured.partition.auto import partition logger = logging.getLogger(__name__) +# 模块加载时设置一次环境变量,避免重复设置 +os.environ.setdefault("UNSTRUCTURED_LANGUAGE_CHECKS", "false") + class DocumentLoader: """从各种文件格式加载文档。""" - SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"} + SUPPORTED_EXTENSIONS = { + ".pdf", ".docx", ".doc", ".txt", ".md", + ".html", ".pptx", ".xlsx", ".json" + } def __init__( self, @@ -32,13 +40,11 @@ class DocumentLoader: extract_images: 是否提取 PDF 中的图片 strategy: 解析策略 (auto, fast, hi_res, ocr_only) ocr_languages: OCR 语言列表,如 ['chi_sim', 'eng'] - languages: 文档主语言,如 ['zh'] + languages: 文档主语言,如 ['zh'](主要用于非 OCR 场景) include_page_breaks: 是否包含分页符 - pdf_infer_table_structure: 是否识别表格结构 (需 hi_res 策略) + pdf_infer_table_structure: 是否识别表格结构(需 hi_res 策略) partition_kwargs: 额外的 partition 参数字典(高级定制) """ - import os - os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false" self.extract_images = extract_images self.strategy = strategy self.ocr_languages = ocr_languages or ["chi_sim", "eng"] @@ -47,6 +53,52 @@ class DocumentLoader: self.pdf_infer_table_structure = pdf_infer_table_structure self.partition_kwargs = partition_kwargs or {} + def _build_partition_kwargs(self, file_path: Path) -> Dict[str, Any]: + """根据文件类型构建 partition 的参数。""" + kwargs: Dict[str, Any] = { + "include_page_breaks": self.include_page_breaks, + } + + suffix = file_path.suffix.lower() + + # PDF 专用参数 + if suffix == ".pdf": + kwargs.update({ + "strategy": self.strategy, + "ocr_languages": self.ocr_languages, + "extract_images_in_pdf": self.extract_images, + "pdf_infer_table_structure": self.pdf_infer_table_structure, + }) + + # 所有文件适用的语言参数 + if self.languages: + kwargs["languages"] = self.languages + + # 用户自定义参数覆盖默认值 + kwargs.update(self.partition_kwargs) + + return kwargs + + def _element_to_document(self, element: Element, file_path: Path) -> Optional[Document]: + """将单个 Element 转换为 Document,同时保留关键元数据。""" + text = getattr(element, "text", "") + if not text or not text.strip(): + return None + + # 提取 unstructured 提供的元数据(根据实际需要选择) + metadata = { + "source": str(file_path), + "file_name": file_path.name, + "file_type": file_path.suffix.lower(), + # 以下元数据来自 Element 对象,可能为 None + "page_number": getattr(getattr(element, "metadata", None), "page_number", None), + "category": getattr(getattr(element, "metadata", None), "category", None), + } + # 过滤掉值为 None 的元数据 + metadata = {k: v for k, v in metadata.items() if v is not None} + + return Document(page_content=text, metadata=metadata) + def load_file(self, file_path: Union[str, Path]) -> List[Document]: """将单个文件加载为 LangChain Document 对象。""" file_path = Path(file_path).resolve() @@ -59,68 +111,58 @@ class DocumentLoader: f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}" ) - # 根据文件类型动态调整参数 - extra_kwargs = {} - if suffix == ".pdf": - extra_kwargs["strategy"] = self.strategy - extra_kwargs["ocr_languages"] = self.ocr_languages - extra_kwargs["extract_images_in_pdf"] = self.extract_images - extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure - - # languages 参数适用于所有文件类型 - if self.languages: - extra_kwargs["languages"] = self.languages - - extra_kwargs["include_page_breaks"] = self.include_page_breaks + kwargs = self._build_partition_kwargs(file_path) - # 合并用户自定义的额外参数(优先级最高) - extra_kwargs.update(self.partition_kwargs) - - # 使用 unstructured 解析 - elements = partition( - filename=str(file_path), - - **extra_kwargs - ) + try: + elements = partition(filename=str(file_path), **kwargs) + except Exception as e: + logger.exception("解析文件 %s 失败", file_path) + raise RuntimeError(f"文件解析失败: {file_path}") from e documents = [] for elem in elements: - text = getattr(elem, "text", "") - if not text or not text.strip(): - continue - - # 基础元数据 - metadata = { - "source": str(file_path), - "file_name": file_path.name, - "file_type": suffix, - } - - documents.append(Document(page_content=text, metadata=metadata)) + doc = self._element_to_document(elem, file_path) + if doc: + documents.append(doc) if not documents: logger.warning("未从 %s 提取到文本内容", file_path) - return [] return documents def load_directory( - self, directory_path: Union[str, Path], recursive: bool = True + self, + directory_path: Union[str, Path], + recursive: bool = True, + fail_fast: bool = False ) -> List[Document]: - """从目录加载所有支持的文件。""" + """ + 从目录加载所有支持的文件。 + + Args: + directory_path: 目录路径 + recursive: 是否递归子目录 + fail_fast: 遇到第一个失败时是否立即抛出异常 + """ directory_path = Path(directory_path).resolve() if not directory_path.is_dir(): raise NotADirectoryError(f"不是目录: {directory_path}") - all_documents = [] + all_documents: List[Document] = [] pattern = "**/*" if recursive else "*" for file_path in directory_path.glob(pattern): - if file_path.is_file() and file_path.suffix.lower() in self.SUPPORTED_EXTENSIONS: - try: - docs = self.load_file(file_path) - all_documents.extend(docs) - except Exception as e: - logger.error("加载 %s 失败: %s", file_path, e) + if not file_path.is_file(): + continue + if file_path.suffix.lower() not in self.SUPPORTED_EXTENSIONS: + continue + + try: + docs = self.load_file(file_path) + all_documents.extend(docs) + except Exception as e: + logger.error("加载 %s 失败: %s", file_path, e) + if fail_fast: + raise return all_documents \ No newline at end of file diff --git a/rag_indexer/splitters.py b/rag_indexer/splitters.py index 45874d3..006e8ab 100644 --- a/rag_indexer/splitters.py +++ b/rag_indexer/splitters.py @@ -3,7 +3,8 @@ """ from enum import Enum -from typing import List, Optional +from typing import List, Optional, Tuple, Dict, Any +from dataclasses import dataclass, field from langchain_core.documents import Document from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter @@ -16,68 +17,195 @@ class SplitterType(str, Enum): PARENT_CHILD = "parent_child" -def get_splitter(splitter_type: SplitterType, **kwargs): - """工厂函数,创建文本切分器。""" - if splitter_type == SplitterType.RECURSIVE: - chunk_size = kwargs.get("chunk_size", 500) - chunk_overlap = kwargs.get("chunk_overlap", 50) - return RecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - separators=["\n\n", "\n", "。", "!", "?", " ", ""], - ) - elif splitter_type == SplitterType.SEMANTIC: - embeddings = kwargs.pop("embeddings", None) - if embeddings is None: - raise ValueError("语义切分器需要提供 'embeddings' 参数") - return SemanticChunkerAdapter(embeddings=embeddings, **kwargs) - else: - raise ValueError(f"不支持的切分器类型: {splitter_type}") +# ---------- 配置数据类,统一参数 ---------- +@dataclass +class RecursiveSplitterConfig: + """递归字符切分器配置""" + chunk_size: int = 500 + chunk_overlap: int = 50 + separators: List[str] = field(default_factory=lambda: ["\n\n", "\n", "。", "!", "?", " ", ""]) + keep_separator: bool = True + strip_whitespace: bool = True +@dataclass +class SemanticSplitterConfig: + """语义切分器配置,仅包含 SemanticChunker 支持的参数。""" + embeddings: Any + buffer_size: int = 1 + add_start_index: bool = False + breakpoint_threshold_type: str = "percentile" + breakpoint_threshold_amount: Optional[float] = None + number_of_chunks: Optional[int] = None + sentence_split_regex: str = r"(?<=[.?!。?!])\s+" + min_chunk_size: int = 100 + +@dataclass +class ParentChildSplitterConfig: + """父子切分器配置""" + embeddings: Any # 子块语义切分所需 + parent_chunk_size: int = 1000 + parent_chunk_overlap: int = 100 + child_buffer_size: int = 1 + child_breakpoint_threshold_type: str = "percentile" + child_breakpoint_threshold_amount: Optional[float] = None + child_min_chunk_size: int = 100 + child_max_chunk_size: Optional[int] = 200 + + +# ---------- 适配器:让 SemanticChunker 实现 TextSplitter 接口 ---------- class SemanticChunkerAdapter(TextSplitter): - """将 SemanticChunker 适配为 TextSplitter 接口。""" + """将 SemanticChunker 适配为 LangChain TextSplitter 接口。""" - def __init__(self, embeddings, **kwargs): + def __init__(self, config: SemanticSplitterConfig, **kwargs): super().__init__(**kwargs) - chunk_size = kwargs.pop("chunk_size", None) - chunk_overlap = kwargs.pop("chunk_overlap", None) - self._chunker = SemanticChunker(embeddings=embeddings, **kwargs) + self._config = config + self._chunker = SemanticChunker( + embeddings=config.embeddings, + buffer_size=config.buffer_size, + add_start_index=config.add_start_index, + breakpoint_threshold_type=config.breakpoint_threshold_type, + breakpoint_threshold_amount=config.breakpoint_threshold_amount, + number_of_chunks=config.number_of_chunks, + sentence_split_regex=config.sentence_split_regex, + min_chunk_size=config.min_chunk_size, + ) def split_text(self, text: str) -> List[str]: return self._chunker.split_text(text) + def split_documents(self, documents: List[Document]) -> List[Document]: + result = [] + for doc in documents: + chunks = self.split_text(doc.page_content) + for i, chunk in enumerate(chunks): + result.append(Document( + page_content=chunk, + metadata={**doc.metadata, "chunk_index": i} + )) + return result + +# ---------- 工厂函数,统一创建切分器 ---------- +def get_splitter(splitter_type: SplitterType, **kwargs) -> TextSplitter: + """ + 根据类型创建切分器。 + 支持传入配置对象或直接参数。 + """ + if splitter_type == SplitterType.RECURSIVE: + config = RecursiveSplitterConfig( + chunk_size=kwargs.get("chunk_size", 500), + chunk_overlap=kwargs.get("chunk_overlap", 50), + separators=kwargs.get("separators", ["\n\n", "\n", "。", "!", "?", " ", ""]), + ) + return RecursiveCharacterTextSplitter( + chunk_size=config.chunk_size, + chunk_overlap=config.chunk_overlap, + separators=config.separators, + keep_separator=config.keep_separator, + strip_whitespace=config.strip_whitespace, + ) + + elif splitter_type == SplitterType.SEMANTIC: + embeddings = kwargs.get("embeddings") + if embeddings is None: + raise ValueError("语义切分器需要提供 'embeddings' 参数") + + if "config" in kwargs and isinstance(kwargs["config"], SemanticSplitterConfig): + config = kwargs["config"] + else: + # 过滤出 SemanticSplitterConfig 支持的字段 + config_kwargs = { + "embeddings": embeddings, + "buffer_size": kwargs.get("buffer_size", 1), + "breakpoint_threshold_type": kwargs.get("breakpoint_threshold_type", "percentile"), + "breakpoint_threshold_amount": kwargs.get("breakpoint_threshold_amount"), + "number_of_chunks": kwargs.get("number_of_chunks"), + "min_chunk_size": kwargs.get("min_chunk_size", 100), + } + config = SemanticSplitterConfig(**config_kwargs) + return SemanticChunkerAdapter(config) + + elif splitter_type == SplitterType.PARENT_CHILD: + # 父子切分器在 builder 中单独处理,不通过本函数创建 + raise ValueError("父子切分器应通过 IndexBuilder 创建,不支持 get_splitter 直接构建") + + else: + raise ValueError(f"不支持的切分器类型: {splitter_type}") + +# ---------- 父子切分器实现 ---------- class ParentChildSplitter: """ - 将文档切分为父块(大块)和子块(小块)。 - 子块用于索引检索,父块用于存储上下文。 + 将文档切分为父块(大块,用于上下文)和子块(小块,用于索引检索)。 + 内部维护父子块之间的映射关系。 """ - def __init__( - self, - parent_chunk_size: int = 1000, - child_chunk_size: int = 200, - parent_chunk_overlap: int = 100, - child_chunk_overlap: int = 20, - ): + def __init__(self, config: ParentChildSplitterConfig): + self.config = config + # 父块使用递归字符切分 self.parent_splitter = RecursiveCharacterTextSplitter( - chunk_size=parent_chunk_size, - chunk_overlap=parent_chunk_overlap, + chunk_size=config.parent_chunk_size, + chunk_overlap=config.parent_chunk_overlap, ) - self.child_splitter = RecursiveCharacterTextSplitter( - chunk_size=child_chunk_size, - chunk_overlap=child_chunk_overlap, + # 子块使用语义切分 + semantic_config = SemanticSplitterConfig( + embeddings=config.embeddings, + buffer_size=config.child_buffer_size, + breakpoint_threshold_type=config.child_breakpoint_threshold_type, + breakpoint_threshold_amount=config.child_breakpoint_threshold_amount, + min_chunk_size=config.child_min_chunk_size, ) + self.child_splitter = SemanticChunkerAdapter(semantic_config) - def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]: + # 存储父子块映射关系(可选) + self.parent_to_children: Dict[str, List[str]] = {} + self.child_to_parent: Dict[str, str] = {} + + def split_documents(self, documents: List[Document]) -> Tuple[List[Document], List[Document]]: """ 返回: (父块列表, 子块列表) + 同时填充内部映射字典。 """ parent_chunks = self.parent_splitter.split_documents(documents) child_chunks = self.child_splitter.split_documents(documents) - # 将子块与父块 ID 关联(可选元数据) - # 在实际实现中,需要将每个子块映射到对应的父块 ID。 - return parent_chunks, child_chunks \ No newline at end of file + # 建立映射关系(简化示例:根据文本包含关系粗略匹配,实际需更精确的算法) + # 这里仅作示意,生产环境建议使用 embedding 相似度或精确子串定位 + self._build_mappings(parent_chunks, child_chunks) + + return parent_chunks, child_chunks + + def _build_mappings(self, parents: List[Document], children: List[Document]) -> None: + """ + 根据文本内容建立父子映射。 + 本方法为简化实现,实际使用时请替换为更可靠的匹配逻辑。 + """ + self.parent_to_children.clear() + self.child_to_parent.clear() + + # 为每个父块生成唯一 ID(若无则使用索引) + for p_idx, parent in enumerate(parents): + parent_id = parent.metadata.get("id", f"parent_{p_idx}") + parent.metadata["id"] = parent_id + self.parent_to_children[parent_id] = [] + + # 将每个子块分配给包含其文本的第一个父块 + for c_idx, child in enumerate(children): + child_id = child.metadata.get("id", f"child_{c_idx}") + child.metadata["id"] = child_id + for parent in parents: + if child.page_content in parent.page_content: + parent_id = parent.metadata["id"] + self.parent_to_children[parent_id].append(child_id) + self.child_to_parent[child_id] = parent_id + child.metadata["parent_id"] = parent_id + break + + def get_parent_for_child(self, child_id: str) -> Optional[str]: + """根据子块 ID 获取父块 ID""" + return self.child_to_parent.get(child_id) + + def get_children_for_parent(self, parent_id: str) -> List[str]: + """根据父块 ID 获取所有子块 ID""" + return self.parent_to_children.get(parent_id, []) \ No newline at end of file diff --git a/rag_indexer/test/reset_index.py b/rag_indexer/test/reset_index.py new file mode 100644 index 0000000..7c6a793 --- /dev/null +++ b/rag_indexer/test/reset_index.py @@ -0,0 +1,80 @@ +"""清理 RAG 索引数据。 + +用法: + python reset_index.py # 清理全部 + python reset_index.py --qdrant # 仅清理 Qdrant + python reset_index.py --postgres # 仅清理 PostgreSQL +""" + +import asyncio +import os +import argparse + +from dotenv import load_dotenv +load_dotenv() + +QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") +QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") +DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable") +COLLECTION_NAME = "rag_documents" +TABLE_NAME = "parent_documents" + + +def clear_qdrant(): + """删除 Qdrant 集合。""" + from qdrant_client import QdrantClient + + print("清理 Qdrant...") + client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) + + collections = client.get_collections().collections + if any(c.name == COLLECTION_NAME for c in collections): + client.delete_collection(COLLECTION_NAME) + print(f" 集合 '{COLLECTION_NAME}' 已删除") + else: + print(f" 集合 '{COLLECTION_NAME}' 不存在") + + +async def clear_postgres(): + """清空 PostgreSQL 表数据。""" + import asyncpg + + print("清理 PostgreSQL...") + conn = await asyncpg.connect(dsn=DB_URI) + + try: + exists = await conn.fetchval( + "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1)", + TABLE_NAME + ) + if exists: + count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}") + await conn.execute(f"DELETE FROM {TABLE_NAME}") + print(f" 表 '{TABLE_NAME}' 已清空,删除 {count} 条记录") + else: + print(f" 表 '{TABLE_NAME}' 不存在") + finally: + await conn.close() + + +async def main(): + parser = argparse.ArgumentParser(description="清理 RAG 索引数据") + parser.add_argument("--qdrant", action="store_true", help="仅清理 Qdrant") + parser.add_argument("--postgres", action="store_true", help="仅清理 PostgreSQL") + args = parser.parse_args() + + if not args.qdrant and not args.postgres: + args.qdrant = True + args.postgres = True + + if args.qdrant: + clear_qdrant() + + if args.postgres: + await clear_postgres() + + print("\n完成。运行 `python -m rag_indexer.cli` 重建索引") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rag_indexer/test/test_inspect_vectors.py b/rag_indexer/test/test_inspect_vectors.py new file mode 100644 index 0000000..5e296c0 --- /dev/null +++ b/rag_indexer/test/test_inspect_vectors.py @@ -0,0 +1,63 @@ +"""检查 Qdrant 中存储的向量质量。""" + +import os +import sys +import numpy as np +from dotenv import load_dotenv +from qdrant_client import QdrantClient + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) +from rag_core import LlamaCppEmbedder + +load_dotenv() + +QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") +QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") +COLLECTION_NAME = "rag_documents" + +client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) +embedder = LlamaCppEmbedder() + +# 获取样本 +points, _ = client.scroll( + collection_name=COLLECTION_NAME, + limit=1, + with_vectors=True, + with_payload=True, +) + +if not points: + print(f"集合 '{COLLECTION_NAME}' 为空") + exit() + +sample = points[0] +raw_vec = sample.vector +if isinstance(raw_vec, dict): + stored_vec = list(raw_vec.values())[0] +elif isinstance(raw_vec, list): + stored_vec = raw_vec +else: + stored_vec = [] + +stored_payload = sample.payload or {} +stored_text = str(stored_payload.get("page_content", ""))[:200] + +print(f"内容预览:\n{stored_text}...\n") +print(f"向量维度: {len(stored_vec)}") # type: ignore +print(f"前5个值: {stored_vec[:5]}") # type: ignore +print(f"是否全零: {all(v == 0.0 for v in stored_vec)}") # type: ignore + +# 重新编码对比 +if stored_text: + new_vec = embedder.embed_query(stored_text) + similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec)) # type: ignore + print(f"\n重新编码前5个值: {new_vec[:5]}") + print(f"余弦相似度: {similarity:.4f}") + + if similarity < 0.8: + print("\n⚠️ 相似度过低,建议删除集合并重建索引") + else: + print("\n✅ 向量一致") +else: + print("\n⚠️ 样本无文本内容") diff --git a/rag_indexer/test/test_refactored.py b/rag_indexer/test/test_refactored.py new file mode 100644 index 0000000..ca681d9 --- /dev/null +++ b/rag_indexer/test/test_refactored.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +测试重构后的 IndexBuilder 和 RAGRetriever +""" + +import asyncio +import os +import sys + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from rag_indexer.IndexBuilder import IndexBuilder +from rag_indexer.splitters import SplitterType + +async def test_index_builder(): + """测试索引构建功能""" + print("测试索引构建功能...") + + # 创建 IndexBuilder 实例 + builder = IndexBuilder( + collection_name="test_collection", + splitter_type=SplitterType.PARENT_CHILD, + parent_chunk_size=1000, + child_chunk_size=200 + ) + + # 测试文档路径 + test_file = os.path.join(os.path.dirname(__file__), "..", "data", "corpus", "三国演义.txt") + + if os.path.exists(test_file): + # 构建索引 + print(f"正在为文件 {test_file} 构建索引...") + processed = await builder.build_from_file(test_file) + print(f"索引构建完成,处理了 {processed} 个文档") + + # 获取集合信息 + info = builder.get_collection_info() + print(f"集合信息: {info}") + else: + print(f"测试文件不存在: {test_file}") + + # 测试搜索功能 + print("\n测试搜索功能...") + try: + results = builder.search("吕布", k=3) + print(f"搜索结果数量: {len(results)}") + for i, result in enumerate(results): + print(f"\n结果 {i+1}:") + print(f"内容: {result.page_content[:100]}...") + except Exception as e: + print(f"搜索测试失败: {e}") + + # 测试带父块上下文的搜索 + print("\n测试带父块上下文的搜索...") + try: + results = await builder.search_with_parent_context("吕布", k=3) + print(f"搜索结果数量: {len(results)}") + for i, result in enumerate(results): + print(f"\n结果 {i+1}:") + print(f"内容: {result.page_content[:100]}...") + except Exception as e: + print(f"带父块上下文的搜索测试失败: {e}") + + # 测试统一检索接口 + print("\n测试统一检索接口...") + try: + # 返回父块 + results_parent = await builder.retrieve("吕布", return_parent=True) + print(f"返回父块的结果数量: {len(results_parent)}") + + # 返回子块 + results_child = await builder.retrieve("吕布", return_parent=False) + print(f"返回子块的结果数量: {len(results_child)}") + except Exception as e: + print(f"统一检索接口测试失败: {e}") + + # 关闭资源 + builder.close() + print("\n测试完成") + +if __name__ == "__main__": + asyncio.run(test_index_builder()) \ No newline at end of file diff --git a/rag_indexer/test/test_validate_index.py b/rag_indexer/test/test_validate_index.py new file mode 100644 index 0000000..072cd90 --- /dev/null +++ b/rag_indexer/test/test_validate_index.py @@ -0,0 +1,188 @@ +""" +验证 RAG 索引完整性。 + +检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。 +""" + +import asyncio +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from dotenv import load_dotenv +load_dotenv() + +QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") +QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") +DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable") +COLLECTION_NAME = "rag_documents" +TABLE_NAME = "parent_documents" + + +def check_qdrant(): + """检查 Qdrant 向量库。""" + from qdrant_client import QdrantClient + + print("=" * 60) + print("Qdrant 向量库") + print("=" * 60) + + client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) + + # 集合列表 + collections = client.get_collections().collections + print(f"\n集合数: {len(collections)}") + for c in collections: + print(f" - {c.name}") + + # 目标集合信息 + if not any(c.name == COLLECTION_NAME for c in collections): + print(f"\n集合 '{COLLECTION_NAME}' 不存在") + return + + info = client.get_collection(COLLECTION_NAME) + print(f"\n集合 '{COLLECTION_NAME}':") + print(f" 状态: {info.status}") + print(f" 向量数: {info.points_count}") + + vectors_config = info.config.params.vectors + if isinstance(vectors_config, dict): + for name, vc in vectors_config.items(): + print(f" 向量 '{name}': 维度={vc.size}, 距离={vc.distance}") + else: + print(f" 向量维度: {vectors_config.size}") + + # 抽样查看 + print(f"\n前 3 个向量:") + points = client.scroll( + collection_name=COLLECTION_NAME, + limit=3, + with_payload=True, + with_vectors=False + ) + for i, point in enumerate(points[0]): + print(f"\n {i+1}. ID: {point.id}") + payload = point.payload or {} + print(f" 内容: {payload.get('page_content', '')[:100]}...") + + +async def check_postgres(): + """检查 PostgreSQL 文档存储。""" + import asyncpg + + print("\n" + "=" * 60) + print("PostgreSQL 文档存储") + print("=" * 60) + + conn = await asyncpg.connect(dsn=DB_URI) + + try: + # 表是否存在 + tables = await conn.fetch( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'" + ) + table_names = [t['table_name'] for t in tables] + + if TABLE_NAME not in table_names: + print(f"\n表 '{TABLE_NAME}' 不存在") + return + + # 统计 + count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}") + print(f"\n表 '{TABLE_NAME}': {count} 条记录") + + # 抽样 + print(f"\n前 3 个文档:") + rows = await conn.fetch( + f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3" + ) + for i, row in enumerate(rows): + print(f"\n {i+1}. Key: {row['key']}") + val = row['value'] + if isinstance(val, dict) and 'page_content' in val: + print(f" 内容: {val['page_content'][:100]}...") + + # Key 前缀分布 + key_prefixes = await conn.fetch( + f""" + SELECT + CASE + WHEN key LIKE '%:%' THEN split_part(key, ':', 1) + ELSE 'no_prefix' + END AS prefix, + COUNT(*) AS cnt + FROM {TABLE_NAME} + GROUP BY prefix + ORDER BY cnt DESC + LIMIT 10 + """ + ) + print(f"\nKey 前缀分布:") + for row in key_prefixes: + print(f" {row['prefix']}: {row['cnt']}") + + finally: + await conn.close() + + +async def test_search(): + """测试检索功能。""" + from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig + from rag_indexer.splitters import SplitterType + + print("\n" + "=" * 60) + print("检索测试") + print("=" * 60) + + # 使用配置对象初始化(与默认构建方式一致) + config = IndexBuilderConfig( + collection_name=COLLECTION_NAME, + splitter_type=SplitterType.PARENT_CHILD, + ) + builder = IndexBuilder(config) + + # 确保检索器已初始化 + if builder.retriever is None: + print("错误: 检索器未初始化,请检查切分策略") + return + + query = input("\n查询 (回车使用默认): ").strip() or "你好" + print(f"\n查询: {query}") + + # 标准检索(返回父块,因为 ParentDocumentRetriever 默认返回父块) + print("\n--- 标准检索 (返回父块) ---") + results = await builder.retriever.ainvoke(query) + for i, doc in enumerate(results): + content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200] + print(f"\n {i+1}. {content}...") + if hasattr(doc, 'metadata'): + source = doc.metadata.get('source', '') + if source: + print(f" 来源: {source}") + + # 若需要仅返回子块,可以临时修改检索器的 search_type + # (注意:ParentDocumentRetriever 的 search_type 默认为 "similarity") + print("\n--- 检索子块 (通过修改检索器参数) ---") + # 创建一个新的检索器副本,设置为返回子块 + # 简单起见,直接调用 vectorstore 进行相似度搜索获取子块 + vectorstore = builder.vector_store.get_langchain_vectorstore() + sub_results = await vectorstore.asimilarity_search(query, k=3) + for i, doc in enumerate(sub_results): + content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200] + print(f"\n {i+1}. {content}...") + if hasattr(doc, 'metadata'): + parent_id = doc.metadata.get('parent_id', '') + if parent_id: + print(f" 父块 ID: {parent_id}") + + +async def main(): + check_qdrant() + await check_postgres() + await test_search() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file