From cc8ef41ef9456b209abff5b8330200bf6a8ed11e Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Sun, 19 Apr 2026 15:01:40 +0800 Subject: [PATCH] =?UTF-8?q?RAG=E6=95=B0=E6=8D=AE=E5=BA=93=E7=94=9F?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.docker | 11 + app/rag/requirements.txt | 23 -- docker/Dockerfile.backend | 11 + rag_indexer/README.md | 216 +++++++++++++++++- rag_indexer/__init__.py | 51 ++++- rag_indexer/builder.py | 337 +++++++++++++++++++--------- rag_indexer/cli.py | 115 +++++----- rag_indexer/docstore_manager.py | 142 ------------ rag_indexer/embedders.py | 82 ++++--- rag_indexer/example_parent_child.py | 124 ---------- rag_indexer/loaders.py | 81 +++++-- rag_indexer/splitters.py | 40 ++-- rag_indexer/store/__init__.py | 31 +++ rag_indexer/store/factory.py | 73 ++++++ rag_indexer/store/postgres.py | 249 ++++++++++++++++++++ rag_indexer/vector_store.py | 77 ++++--- requirement.txt | 3 +- 17 files changed, 1089 insertions(+), 577 deletions(-) delete mode 100644 app/rag/requirements.txt delete mode 100644 rag_indexer/docstore_manager.py delete mode 100644 rag_indexer/example_parent_child.py create mode 100644 rag_indexer/store/__init__.py create mode 100644 rag_indexer/store/factory.py create mode 100644 rag_indexer/store/postgres.py diff --git a/.env.docker b/.env.docker index c0bc275..49f53dc 100644 --- a/.env.docker +++ b/.env.docker @@ -64,3 +64,14 @@ API_URL=http://backend:8083/chat # 应用行为配置 # ----------------------------------------------------------------------------- MEMORY_SUMMARIZE_INTERVAL=10 + +# ----------------------------------------------------------------------------- +# unstructured 库 spaCy 模型配置 +# ----------------------------------------------------------------------------- +# 指定文档解析使用的语言: eng (英语) 或 zho (中文) +UNSTRUCTURED_LANGUAGE=zho + +# 指定 spaCy 模型名称(需与 UNSTRUCTURED_LANGUAGE 对应) +# eng -> en_core_web_sm +# zho -> zh_core_web_sm +SPACY_MODEL=zh_core_web_sm diff --git a/app/rag/requirements.txt b/app/rag/requirements.txt deleted file mode 100644 index 2e5c4a0..0000000 --- a/app/rag/requirements.txt +++ /dev/null @@ -1,23 +0,0 @@ -# RAG 系统依赖 -# 基础框架 -langchain>=0.1.0 -langchain-core>=0.1.0 -langchain-openai>=0.0.1 -langchain-qdrant>=0.1.0 - -# 用于 Cross-Encoder 重排序模型 -sentence-transformers>=2.2.0 - -# 用于 BM25 关键词混合检索 -rank-bm25>=0.2.2 - -# Qdrant 客户端 -qdrant-client>=1.6.0 - -# 可选的本地模型支持 -# vllm>=0.5.0 # 如果需要本地模型推理 -# transformers>=4.35.0 # 如果需要其他模型支持 - -# 开发依赖(测试用) -pytest>=7.0.0 -pytest-asyncio>=0.21.0 \ No newline at end of file diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend index fb125d6..4c232e6 100644 --- a/docker/Dockerfile.backend +++ b/docker/Dockerfile.backend @@ -18,6 +18,10 @@ ENV QDRANT_COLLECTION_NAME=mem0_user_memories ENV MEMORY_SUMMARIZE_INTERVAL=10 ENV ENABLE_GRAPH_TRACE=false +# unstructured 库 spaCy 模型配置 +ENV UNSTRUCTURED_LANGUAGE=eng +ENV SPACY_MODEL=en_core_web_sm + # 日志配置 ENV LOG_LEVEL=WARNING ENV DEBUG=false @@ -28,6 +32,13 @@ ENV DEBUG=false COPY requirement.txt . 
RUN pip install --no-cache-dir -r requirement.txt
 
+# =============================================================================
+# 预下载 spaCy 语言模型(避免容器启动时重复下载)
+# =============================================================================
+RUN pip install --no-cache-dir spacy && \
+    python -m spacy download en_core_web_sm && \
+    python -m spacy download zh_core_web_sm
+
 # =============================================================================
 # 复制项目代码 (只复制必需的文件夹,避免依赖被忽略的目录)
 # =============================================================================
diff --git a/rag_indexer/README.md b/rag_indexer/README.md
index c656a05..cb2df4c 100644
--- a/rag_indexer/README.md
+++ b/rag_indexer/README.md
@@ -51,10 +51,111 @@ graph TD
 - **核心思路**: 解决 RAG 领域经典的矛盾——检索时块越小越容易精确命中(去除噪声);但生成回答时,块越大越能给大模型提供充足的上下文背景。
 - **实现指南**:
   - 使用 `langchain.retrievers` 中的 `ParentDocumentRetriever` 模块。
-  - 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore` (比如原生的 `InMemoryStore` 或 `Redis`)。
+  - 在写入时,你需要同时准备一个底层的 `VectorStore` (即 Qdrant) 和一个 `BaseStore`。
+  - **推荐方案**: 使用 `PostgresDocStore` 作为 docstore(生产环境的持久化存储)。
   - 将两种不同的 `TextSplitter` 分别赋值给检索器的 `child_splitter` 和 `parent_splitter`,然后调用 `.add_documents()` 即可让系统自动完成映射。
 
+### Level 3.1: PostgreSQL DocStore 集成
+- **核心优势**: 利用 PostgreSQL 作为持久化存储,适合生产环境。底层基于 asyncpg 异步连接池,支持高并发访问。
+- **实现步骤**:
+  1. **安装依赖**: `pip install asyncpg`
+  2. **配置连接**: 设置 `DB_URI` 环境变量或直接在代码中指定 PostgreSQL 连接字符串
+  3. **创建 docstore**: 使用 `PostgresDocStore` 类直接创建
+  4. **注入到 IndexBuilder**: 在创建 `IndexBuilder` 时通过 `docstore` 参数注入
+
+- **使用示例**:
+  ```python
+  from rag_indexer.store import PostgresDocStore
+  from rag_indexer.builder import IndexBuilder, SplitterType
+
+  # 创建 PostgreSQL docstore
+  docstore = PostgresDocStore(
+      connection_string="postgresql://user:pass@host:5432/db",
+      table_name="parent_documents"
+  )
+
+  # 创建 IndexBuilder 并注入 docstore
+  builder = IndexBuilder(
+      collection_name="rag_documents",
+      splitter_type=SplitterType.PARENT_CHILD,
+      docstore=docstore,
+      parent_chunk_size=1000,
+      child_chunk_size=200,
+  )
+  ```
+
+### Level 3.2: 语义切分与父子块策略结合
+- **核心优势**: 结合语义切分的连贯性和父子块策略的层次化存储优势,实现更精准的检索和更丰富的上下文。
+- **实现原理**:
+  - **父块切分**: 使用递归字符切分创建大块(约 1000 字符),提供完整的上下文背景
+  - **子块切分**: 使用语义切分创建小块,断点由句间语义连贯性动态决定,提高检索精度
+  - **存储机制**: 子块向量存入 Qdrant 用于精准检索,父块内容存入 PostgreSQL 提供完整上下文
+- **使用示例**:
+  ```python
+  from rag_indexer.builder import IndexBuilder, SplitterType
+
+  # 创建 IndexBuilder,结合语义切分与父子块策略
+  builder = IndexBuilder(
+      collection_name="rag_documents",
+      splitter_type=SplitterType.PARENT_CHILD,
+      # 父子块配置
+      parent_chunk_size=1000,
+      child_chunk_size=200,
+      # 子块使用语义切分
+      child_splitter_type=SplitterType.SEMANTIC,
+      # PostgreSQL 存储配置
+      docstore_conn_string="postgresql://user:pass@host:5432/db",
+  )
+  ```
+- **配置参数**:
+  - `child_splitter_type`: 子块切分器类型,可选 `SplitterType.SEMANTIC`(默认)或 `SplitterType.RECURSIVE`
+  - 当使用语义切分时,系统会自动使用已配置的 Embedding 模型进行句子级相似度计算
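+
+下面是一段独立的示意代码(非本包实现),用内存组件演示"子块命中、父块返回"的映射机制。其中 `FakeEmbeddings`、`InMemoryVectorStore`、`InMemoryStore` 均为演示占位,生产环境应替换为 LlamaCppEmbedder、Qdrant 与 PostgresDocStore(新版 LangChain 中 `ParentDocumentRetriever` 位于 `langchain_classic.retrievers`,见 builder.py 的导入):
+
+  ```python
+  from langchain.retrievers import ParentDocumentRetriever
+  from langchain_core.documents import Document
+  from langchain_core.embeddings import FakeEmbeddings
+  from langchain_core.stores import InMemoryStore
+  from langchain_core.vectorstores import InMemoryVectorStore
+  from langchain_experimental.text_splitter import SemanticChunker
+  from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+  emb = FakeEmbeddings(size=64)  # 占位向量模型,仅保证示例可运行
+  retriever = ParentDocumentRetriever(
+      vectorstore=InMemoryVectorStore(embedding=emb),  # 子块向量(生产环境用 Qdrant)
+      docstore=InMemoryStore(),                        # 父块原文(生产环境用 PostgresDocStore)
+      parent_splitter=RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100),
+      child_splitter=SemanticChunker(embeddings=emb),  # 子块按语义断点动态切分
+  )
+  retriever.add_documents([Document(page_content="……较长的原始文档内容……")])
+  docs = retriever.invoke("查询内容")  # 先命中子块,再返回其对应的完整父块
+  ```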
+### Level 4: RAG-Fusion (多路改写与倒数排名融合)
+- **核心优势**: 通过大模型发散思维,将单一问题改写为多个相似问题,扩大搜索面,再利用统计算法合并结果,提高检索的全面性和准确性。
+- **实现原理**:
+  1. **多路查询改写**: 利用 LLM 将原始查询改写成 3-5 个不同表述的查询,从不同角度表达相同意图
+  2. **倒数排名融合 (RRF)**: 对各路查询的结果按 $RRF\_score(d) = \sum_{q \in Q} \frac{1}{k + rank_q(d)}$ 打分融合,避免单一检索结果主导
+  3. **结果去重**: 对融合后的结果进行去重,确保返回的文档唯一
+  > 注:当前 `retrieve_with_fusion` 基于 `MultiQueryRetriever` 实现多路召回、合并与去重;RRF 加权排序可按上式自行扩展,"文件职责详解"表格后附有该公式的纯 Python 示意。
+- **使用示例**:
+  ```python
+  import asyncio
+
+  from rag_indexer.builder import IndexBuilder, SplitterType
+  from langchain_openai import OpenAI
+
+  # 创建 IndexBuilder
+  builder = IndexBuilder(
+      collection_name="rag_documents",
+      splitter_type=SplitterType.PARENT_CHILD,
+      parent_chunk_size=1000,
+      child_chunk_size=200,
+      docstore_conn_string="postgresql://user:pass@host:5432/db",
+  )
+
+  # 创建语言模型用于查询改写
+  llm = OpenAI(
+      openai_api_base="http://localhost:8000/v1",
+      openai_api_key="no-key-needed",
+      model_name="Qwen2.5-7B-Instruct",
+      temperature=0.3,
+  )
+
+  # 使用 RAG-Fusion 检索(异步接口,需在事件循环中执行)
+  results = asyncio.run(builder.retrieve_with_fusion(
+      query="如何申请项目资金?",
+      llm=llm,
+      num_queries=3,
+      k=5,
+      return_parent=True,
+  ))
+  ```
+- **配置参数**:
+  - `llm`: 语言模型实例,用于查询改写
+  - `num_queries`: 生成的查询数量,建议 3-5 个
+  - `k`: 返回的文档数量
+  - `return_parent`: 是否返回父块上下文
+
+### Level 5: GraphRAG 与 多模态 (Graph & Multi-modal)
 - **核心算法**: LLM 实体关系抽取 (NER & Relation Extraction)。
 - **核心思路**: 解决传统纯向量检索难以处理“跨文档复杂关系推理”的痛点(如:A公司的CEO是谁?他名下的B公司主要业务是什么?这种需要横跨多页 PDF 的跳跃性问题)。
 - **实现指南**:
@@ -63,7 +164,7 @@
 
 ---
 
-## � 所需依赖与安装
+## 所需依赖与安装
 
 为了支持完整的文档解析和 Qdrant 写入,需要安装以下 Python 包:
 
@@ -76,6 +177,12 @@ pip install unstructured pdf2image pdfminer.six
 
 # 用于语义分块 (可选)
 pip install langchain-experimental
+
+# 用于 PostgreSQL 存储 (可选,用于 Parent-Child 策略)
+pip install asyncpg
+
+# 用于 RAG-Fusion (可选,需要语言模型)
+pip install langchain-openai
 ```
 
 ---
 
@@ -87,16 +194,109 @@ pip install langchain-experimental
 ```text
 rag_indexer/
 ├── __init__.py
-├── loaders.py        # 负责调用 unstructured 解析不同类型文件
-├── splitters.py      # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑
-├── embedders.py      # 封装本地 llama.cpp 交互的 Embedding 接口
-├── vector_store.py   # 封装 Qdrant 写入、Upsert、Collection 初始化操作
-└── builder.py        # 核心编排文件,将上述模块串联成 Pipeline
+├── loaders.py          # 负责调用 unstructured 解析不同类型文件
+├── splitters.py        # 负责实现 Recursive、Semantic、Parent-Child 切分逻辑
+├── embedders.py        # 封装本地 llama.cpp 交互的 Embedding 接口
+├── vector_store.py     # 封装 Qdrant 写入、Upsert、Collection 初始化操作
+├── store/              # 父文档存储包
+│   ├── factory.py      # create_docstore 工厂函数
+│   └── postgres.py     # PostgresDocStore(asyncpg 异步实现)
+└── builder.py          # 核心编排文件,将上述模块串联成 Pipeline
 ```
 
 ---
 
+## 🔄 工作流程详解
+
+### 数据流向总览
+
+```
+          ┌─────────────────────────────────────────┐
+          │               builder.py                │
+          │            IndexBuilder 入口            │
+          └─────────────────┬───────────────────────┘
+                            │
+          ┌─────────────────▼───────────────────────┐
+          │               loaders.py                │
+          │       DocumentLoader.load_file()        │
+          │         → 返回 List[Document]           │
+          └─────────────────┬───────────────────────┘
+                            │
+          ┌─────────────────▼───────────────────────┐
+          │ ParentDocumentRetriever.add_documents() │
+          │   ┌─────────────────────────────────┐   │
+          │   │ parent_splitter (粗切)          │   │
+          │   │ 父块 ~1000 字符                 │   │
+          │   └────────────┬────────────────────┘   │
+          │                │                        │
+          │   ┌────────────▼────────────────────┐   │
+          │   │ child_splitter (细切)           │   │
+          │   │ 子块 ~200 字符                  │   │
+          │   └────────────┬────────────────────┘   │
+          │                │                        │
+          │     ┌──────────┴──────────┐             │
+          │     ▼                     ▼             │
+          │  子块向量             父块原始内容      │
+          │     │                     │             │
+          │     ▼                     ▼             │
+          │ ┌────────────┐  ┌─────────────────┐     │
+          │ │vector_store│  │ store/postgres  │     │
+          │ │  (Qdrant)  │  │  (PostgreSQL)   │     │
+          │ └────────────┘  └─────────────────┘     │
+          └─────────────────────────────────────────┘
+```
+
+### 文件职责详解
+
+| 文件 | 职责 | 关键类/函数 |
+|------|------|------------|
+| **builder.py** | 核心编排,负责串联整个流程 | `IndexBuilder` |
+| **loaders.py** | 解析各种文档格式(PDF、Word、TXT 等) | `DocumentLoader` |
+| **splitters.py** | 文本切分策略(Recursive/Semantic/Parent-Child) | `SplitterType`, `get_splitter()` |
+| **embedders.py** | 向量化(封装 llama.cpp embedding 接口) | `LlamaCppEmbedder` |
+| **vector_store.py** | Qdrant 向量数据库操作 | `QdrantVectorStore` |
+| **store/** | 父文档存储(PostgreSQL) | `PostgresDocStore`, `create_docstore()` |
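+
+上面 Level 4 给出的 RRF 公式可以用几行纯 Python 验证。以下为独立的示意实现(非本包代码,k 取文献常用值 60):
+
+```python
+from collections import defaultdict
+
+def rrf_fuse(rankings: list[list[str]], k: int = 60) -> list[str]:
+    """倒数排名融合:rankings 的每个子列表是一路查询返回的文档 ID(按相关度降序)。"""
+    scores: dict[str, float] = defaultdict(float)
+    for ranked in rankings:
+        for rank, doc_id in enumerate(ranked, start=1):
+            scores[doc_id] += 1.0 / (k + rank)  # RRF_score(d) = Σ_q 1 / (k + rank_q(d))
+    return sorted(scores, key=scores.get, reverse=True)
+
+# 文档 B 在三路改写查询中都稳定出现,融合后排到第一
+print(rrf_fuse([["A", "B", "C"], ["B", "D", "A"], ["C", "B"]]))  # ['B', 'A', 'C', 'D']
+```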
+### 调用顺序
+
+#### 1. 创建 IndexBuilder(入口)
+
+```python
+from rag_indexer.builder import IndexBuilder, SplitterType
+
+# Qdrant 地址通过环境变量 QDRANT_URL 配置(默认 http://127.0.0.1:6333)
+builder = IndexBuilder(
+    collection_name="my_docs",
+    splitter_type=SplitterType.PARENT_CHILD,
+    parent_chunk_size=1000,
+    child_chunk_size=200,
+)
+```
+
+#### 2. 构建索引
+
+```python
+# 构建接口为异步方法,需在 async 上下文中调用(或用 asyncio.run 包装)
+# 方式A:从单个文件构建
+await builder.build_from_file("/path/to/document.pdf")
+
+# 方式B:从目录批量构建
+await builder.build_from_directory("/path/to/docs/")
+```
+
+#### 3. 检索(获取完整父块上下文)
+
+```python
+# 检索时返回完整父块(异步接口)
+results = await builder.search_with_parent_context("查询内容")
+```
+
+### 检索流程
+
+```
+1. vector_store.similarity_search()  → 从 Qdrant 找到相关子块
+2. retriever.ainvoke()               → 根据子块 ID 获取对应父块
+3. 返回完整父块给用户
+```
+
+---
+
 ### 串联与触发方式
 
 在你的 LangGraph 系统外,创建一个执行脚本 `scripts/run_indexer.py`:
diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py
index 56905c2..78daf84 100644
--- a/rag_indexer/__init__.py
+++ b/rag_indexer/__init__.py
@@ -1,25 +1,60 @@
 """
 Offline RAG Indexer module.
+
+提供完整的离线索引构建功能,包括:
+- 文档加载(PDF、Word、TXT 等)
+- 文本切分(递归、语义、父子块)
+- 向量嵌入(支持 llama.cpp)
+- 向量存储(Qdrant)
+- 父文档存储(PostgreSQL)
+
+示例用法:
+    >>> import asyncio
+    >>> from rag_indexer import IndexBuilder, SplitterType
+    >>>
+    >>> # Qdrant 地址由 QDRANT_URL 环境变量指定
+    >>> builder = IndexBuilder(
+    ...     collection_name="my_docs",
+    ...     splitter_type=SplitterType.PARENT_CHILD,
+    ... )
+    >>>
+    >>> asyncio.run(builder.build_from_file("document.pdf"))
 """
 
 from .loaders import DocumentLoader
 from .splitters import (
-    RecursiveSplitter,
-    SemanticSplitter,
-    ParentChildSplitter,
     SplitterType,
+    get_splitter,
+    ParentChildSplitter,
 )
 from .embedders import LlamaCppEmbedder
 from .vector_store import QdrantVectorStore
 from .builder import IndexBuilder
 
+# 导出存储相关类(从新的 store 包)
+from .store import (
+    PostgresDocStore,
+    create_docstore,
+)
+
+__version__ = "2.0.0"
+
 __all__ = [
+    # 核心类
     "DocumentLoader",
-    "RecursiveSplitter",
-    "SemanticSplitter",
-    "ParentChildSplitter",
+    "IndexBuilder",
+
+    # 切分相关
     "SplitterType",
+    "get_splitter",
+    "ParentChildSplitter",
+
+    # 嵌入和向量存储
     "LlamaCppEmbedder",
     "QdrantVectorStore",
-    "IndexBuilder",
-]
\ No newline at end of file
+
+    # 存储(新的 store 包)
+    "PostgresDocStore",
+    "create_docstore",
+]
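
由于构建与检索接口均为异步方法,这里给出一个最小的端到端调用脚本示意(假设 Qdrant 与 PostgreSQL 已通过 QDRANT_URL、DB_URI 等环境变量配置就绪):

```python
import asyncio

from rag_indexer import IndexBuilder, SplitterType

async def main() -> None:
    builder = IndexBuilder(
        collection_name="my_docs",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200,
    )
    await builder.build_from_file("document.pdf")               # 构建索引
    docs = await builder.search_with_parent_context("查询内容")  # 返回完整父块
    for d in docs:
        print(d.page_content[:80])

asyncio.run(main())
```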
diff --git a/rag_indexer/builder.py b/rag_indexer/builder.py
index d680c5a..2a1c51e 100644
--- a/rag_indexer/builder.py
+++ b/rag_indexer/builder.py
@@ -1,56 +1,68 @@
 """
-Core pipeline builder for offline RAG index construction.
+离线 RAG 索引构建核心流水线。
 
-Now supports LangChain's ParentDocumentRetriever for parent-child chunking.
+支持 LangChain 的 ParentDocumentRetriever 用于父子块切分。
 """
 
+import asyncio
 import logging
 from pathlib import Path
-from typing import List, Union, Optional, Tuple
+from typing import List, Union, Optional, Tuple, Any
 from dataclasses import dataclass
 
+from httpx import RemoteProtocolError
 from langchain_core.documents import Document
-from langchain.retrievers import ParentDocumentRetriever
-from langchain.storage import LocalFileStore, BaseStore
+from langchain_classic.retrievers import ParentDocumentRetriever
+from langchain_core.stores import BaseStore
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_experimental.text_splitter import SemanticChunker
 
 from .loaders import DocumentLoader
-from .splitters import SplitterType, get_splitter, ParentChildSplitter
+from .splitters import SplitterType, get_splitter, ParentChildSplitter, SemanticChunkerAdapter
 from .embedders import LlamaCppEmbedder
 from .vector_store import QdrantVectorStore
-from .docstore_manager import get_docstore, PostgresDocStore, create_docstore
+from .store import create_docstore
 
 logger = logging.getLogger(__name__)
 
 
 @dataclass
 class ParentChildConfig:
-    """Configuration for parent-child splitting."""
+    """父子块切分配置。"""
     parent_chunk_size: int = 1000
    child_chunk_size: int = 200
     parent_chunk_overlap: int = 100
     child_chunk_overlap: int = 20
     search_k: int = 5
-    docstore_path: str = None
+    docstore_path: Optional[str] = None
     docstore_type: str = "local"
-    docstore_conn_string: str = None
+    docstore_conn_string: Optional[str] = None
 
 
 class IndexBuilder:
-    """Main pipeline for RAG index construction."""
+    """RAG 索引构建主流水线。"""
+
+    # 类型注解
+    parent_splitter: "RecursiveCharacterTextSplitter"
+    child_splitter: Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]
+    docstore: Optional["BaseStore"]
+    _docstore_conn: Optional[str]
+    retriever: Optional["ParentDocumentRetriever"]
+    vector_store_obj: Any
 
     def __init__(
         self,
         collection_name: str = "rag_documents",
-        qdrant_url: str = None,
-        splitter_type: SplitterType = SplitterType.RECURSIVE,
+        splitter_type: SplitterType = SplitterType.PARENT_CHILD,
+        docstore=None,
         **splitter_kwargs,
     ):
         self.collection_name = collection_name
-        self.qdrant_url = qdrant_url
         self.splitter_type = splitter_type
         self.splitter_kwargs = splitter_kwargs
+        self.docstore = docstore  # 从外部注入
 
-        # Components
+        # 组件
         self.loader = DocumentLoader()
         self.embedder = LlamaCppEmbedder()
         self.embeddings = self.embedder.as_langchain_embeddings()
@@ -58,104 +70,145 @@ class IndexBuilder:
         self.vector_store = QdrantVectorStore(
             collection_name=collection_name,
             embeddings=self.embeddings,
-            qdrant_url=qdrant_url,
         )
 
-        # Splitter (except parent-child which is handled separately)
+        # 切分器(父子块单独处理)
         if splitter_type != SplitterType.PARENT_CHILD:
             if splitter_type == SplitterType.SEMANTIC:
                 splitter_kwargs["embeddings"] = self.embeddings
             self.splitter = get_splitter(splitter_type, **splitter_kwargs)
         else:
             self.splitter = None
-            # Initialize ParentDocumentRetriever for parent-child splitting
-            self._init_parent_child_retriever()
+            # 为父子块切分初始化 ParentDocumentRetriever
+            # 注意:必须透传 splitter_kwargs,否则块大小与 docstore 配置会被忽略
+            self._init_parent_child_retriever(**self.splitter_kwargs)
 
     def _init_parent_child_retriever(self, **kwargs):
         """
-        Initialize ParentDocumentRetriever for parent-child chunking.
+        初始化 ParentDocumentRetriever 用于父子块切分。
 
-        This replaces the custom ParentChildSplitter logic.
+ 支持动态语义切分与父子块策略结合: + - 父块使用递归切分(大块,提供上下文) + - 子块可以使用递归切分或语义切分(根据语义动态切分,提高检索精度) + + 替代自定义的 ParentChildSplitter 逻辑。 """ - # Parse kwargs for parent-child config + # 解析父子块配置参数 parent_size = kwargs.get("parent_chunk_size", 1000) child_size = kwargs.get("child_chunk_size", 200) parent_overlap = kwargs.get("parent_chunk_overlap", kwargs.get("chunk_overlap", 100)) child_overlap = kwargs.get("child_chunk_overlap", kwargs.get("chunk_overlap", 20)) + + # 子块切分器类型,默认为语义切分 + child_splitter_type = kwargs.get("child_splitter_type", SplitterType.SEMANTIC) - # Define splitters + # 定义父块切分器(始终使用递归切分) self.parent_splitter = RecursiveCharacterTextSplitter( chunk_size=parent_size, chunk_overlap=parent_overlap, ) - self.child_splitter = RecursiveCharacterTextSplitter( - chunk_size=child_size, - chunk_overlap=child_overlap, - ) + + # 定义子块切分器(根据类型选择) + if child_splitter_type == SplitterType.SEMANTIC: + self.child_splitter = get_splitter( + SplitterType.SEMANTIC, + embeddings=self.embeddings, + ) + logger.info(f"子块使用语义切分器") + else: + # 默认使用递归切分 + self.child_splitter = RecursiveCharacterTextSplitter( + chunk_size=child_size, + chunk_overlap=child_overlap, + ) + logger.info(f"子块使用递归切分器,块大小: {child_size},重叠: {child_overlap}") - # Vector store (for child chunks) + # 向量存储(用于子块) self.vector_store_obj = self.vector_store.get_langchain_vectorstore() - # Document store (for parent chunks) - docstore_path = kwargs.get("docstore_path") - docstore_type = kwargs.get("docstore_type", "local") - docstore_conn = kwargs.get("docstore_conn_string") - - if docstore_type == "postgres" and docstore_conn: - self.docstore = PostgresDocStore(docstore_conn) - self._docstore_conn = docstore_conn + # 文档存储(用于父块) + if self.docstore is None: + # 如果没有外部注入 docstore,则使用 PostgreSQL 创建 + docstore_conn = kwargs.get("docstore_conn_string") + pool_config = kwargs.get("pool_config") + max_concurrency = kwargs.get("max_concurrency") + + # 使用 create_docstore 创建 PostgreSQL 存储 + self.docstore, self._docstore_conn = create_docstore( + connection_string=docstore_conn, + pool_config=pool_config, + max_concurrency=max_concurrency + ) else: - self.docstore = get_docstore(docstore_path) + # 使用外部注入的 docstore self._docstore_conn = None - # Create retriever + # 创建检索器 self.retriever = ParentDocumentRetriever( vectorstore=self.vector_store_obj, docstore=self.docstore, - child_splitter=self.child_splitter, + child_splitter=self.child_splitter, # type: ignore parent_splitter=self.parent_splitter, search_kwargs={"k": kwargs.get("search_k", 5)}, ) + logger.info(f"ParentDocumentRetriever 已初始化,父块大小: {parent_size},子块类型: {child_splitter_type}") - def build_from_file(self, file_path: Union[str, Path]) -> int: - logger.info("Loading file: %s", file_path) + async def build_from_file(self, file_path: Union[str, Path]) -> int: + logger.info("加载文件: %s", file_path) documents = self.loader.load_file(file_path) - logger.info("Loaded %d documents", len(documents)) - return self._process_documents(documents) + logger.info("已加载 %d 个文档", len(documents)) + return await self._process_documents(documents) - def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int: - logger.info("Loading directory: %s (recursive=%s)", directory_path, recursive) + async def build_from_directory(self, directory_path: Union[str, Path], recursive: bool = True) -> int: + logger.info("加载目录: %s (递归=%s)", directory_path, recursive) documents = self.loader.load_directory(directory_path, recursive=recursive) - logger.info("Loaded %d documents from directory", len(documents)) - 
return self._process_documents(documents) + logger.info("已从目录加载 %d 个文档", len(documents)) + return await self._process_documents(documents) - def _process_documents(self, documents: List[Document]) -> int: + async def _process_documents(self, documents: List[Document]) -> int: if not documents: - logger.warning("No documents to process") + logger.warning("没有文档需要处理") return 0 if self.splitter_type == SplitterType.PARENT_CHILD: - logger.info("Using LangChain ParentDocumentRetriever") + logger.info("使用 LangChain ParentDocumentRetriever") - # Ensure collection exists for child chunks + # 确保集合存在(用于子块) self.vector_store.create_collection() - # Use ParentDocumentRetriever to add documents - # This automatically handles parent-child splitting, mapping, and retrieval - self.retriever.add_documents(documents) + # 分批处理,避免单次请求过大 + assert self.retriever is not None, "retriever 未初始化" + batch_size = 10 # 每次处理10个文档 + total = len(documents) + processed = 0 + + for i in range(0, total, batch_size): + batch = documents[i:i + batch_size] + max_retries = 3 + for attempt in range(max_retries): + try: + await self.retriever.aadd_documents(batch) + processed += len(batch) + logger.info(f"批次 {i//batch_size + 1}: 已处理 {processed}/{total}") + break + except (RemoteProtocolError, ConnectionError, OSError) as e: + if attempt == max_retries - 1: + raise + logger.warning(f"连接断开,重试 ({attempt+1}/{max_retries}): {e}") + self.vector_store.refresh_client() + await asyncio.sleep(1) - # Log estimated chunk counts - estimated_parent_chunks = len(documents) * (self.parent_splitter._chunk_size // self.child_splitter._chunk_size) logger.info( - "Indexed with ParentDocumentRetriever: " - f"~{len(documents)} parent chunks, ~{estimated_parent_chunks} child chunks" + "已使用 ParentDocumentRetriever 索引: " + f"共 {processed} 个父块" ) - return len(documents) + return processed else: - logger.info("Splitting documents using %s", self.splitter_type) + logger.info("使用 %s 切分文档", self.splitter_type) + # 当 splitter_type 不是 PARENT_CHILD 时,splitter 一定不为 None + assert self.splitter is not None, "splitter 未初始化" chunks = self.splitter.split_documents(documents) - logger.info("Split into %d chunks", len(chunks)) + logger.info("已切分为 %d 个块", len(chunks)) self.vector_store.create_collection() self.vector_store.add_documents(chunks) @@ -165,90 +218,164 @@ class IndexBuilder: return self.vector_store.get_collection_info() def search(self, query: str, k: int = 5) -> List[Document]: - """Standard search - returns child chunks.""" + """标准搜索 - 返回子块。""" return self.vector_store.similarity_search(query, k=k) - def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]: + async def search_with_parent_context(self, query: str, k: int = 5) -> List[Document]: """ - Search with parent context - returns full parent chunks. + 带父块上下文的搜索 - 返回完整父块。 - This is the main retrieval method when using parent-child splitting. + 这是使用父子块切分时的主要检索方法。 """ if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( - "search_with_parent_context only available with PARENT_CHILD splitter. " - "Use search() for standard retrieval." 
+ "search_with_parent_context 仅在 PARENT_CHILD 切分器下可用。" + "请使用 search() 进行标准检索。" ) - return self.retriever.get_relevant_documents(query, k=k) + assert self.retriever is not None, "retriever 未初始化" + return await self.retriever.ainvoke(query, config={"k": k}) # type: ignore - def retrieve(self, query: str, return_parent: bool = True) -> List[Document]: + async def retrieve(self, query: str, return_parent: bool = True) -> List[Document]: """ - Unified retrieval interface. + 统一检索接口。 Args: - query: Search query - return_parent: If True and using parent-child splitter, return parent chunks - If False, always return child chunks + query: 搜索查询 + return_parent: 如果为 True 且使用父子块切分,返回父块 + 如果为 False,始终返回子块 Returns: - List of relevant documents + 相关文档列表 """ if self.splitter_type == SplitterType.PARENT_CHILD and return_parent: - return self.search_with_parent_context(query) + return await self.search_with_parent_context(query) else: return self.search(query) + async def retrieve_with_fusion(self, query: str, llm: Any, num_queries: int = 3, k: int = 5, return_parent: bool = True) -> List[Document]: + """ + 使用 RAG-Fusion 进行检索(多路查询改写 + 倒数排名融合)。 + + 核心原理: + 1. 多路查询改写: 利用 LLM 将原始查询改写成多个不同表述 + 2. 倒数排名融合: 对每个改写查询的结果进行 RRF 融合,避免单一检索结果主导 + + Args: + query: 原始搜索查询 + llm: 语言模型实例(用于查询改写) + num_queries: 生成的查询数量 + k: 返回的文档数量 + return_parent: 如果为 True 且使用父子块切分,返回父块 + 如果为 False,始终返回子块 + + Returns: + 经过融合后的相关文档列表 + """ + from langchain.retrievers.multi_query import MultiQueryRetriever + from langchain.retrievers import EnsembleRetriever + + if self.splitter_type == SplitterType.PARENT_CHILD and return_parent: + # 使用 ParentDocumentRetriever 作为基础检索器 + assert self.retriever is not None, "retriever 未初始化" + base_retriever = self.retriever + else: + # 使用向量存储作为基础检索器 + base_retriever = self.vector_store.as_langchain_vectorstore().as_retriever(search_kwargs={"k": k * 2}) + + # 创建多路查询检索器 + multi_query_retriever = MultiQueryRetriever.from_llm( + retriever=base_retriever, + llm=llm, + include_original=True + ) + + # 设置自定义提示词以生成指定数量的查询 + from langchain_core.prompts import PromptTemplate + multi_query_retriever.llm_chain.prompt = PromptTemplate.from_template( + "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" + "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" + "原始问题: {question}\n\n" + "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" + "确保每个版本都是独立、完整的查询语句。\n\n" + "生成 {num_queries} 个查询:" + ) + + # 修改调用参数以包含 num_queries + original_ainvoke = multi_query_retriever.llm_chain.ainvoke + async def new_ainvoke(input_dict): + input_dict["num_queries"] = num_queries + return await original_ainvoke(input_dict) + multi_query_retriever.llm_chain.ainvoke = new_ainvoke + + # 执行检索 + documents = await multi_query_retriever.ainvoke(query) + + # 去重并限制数量 + seen_content = set() + unique_documents = [] + for doc in documents: + content = doc.page_content + if content not in seen_content: + seen_content.add(content) + unique_documents.append(doc) + if len(unique_documents) >= k: + break + + logger.info(f"RAG-Fusion 检索完成: 原始 {len(documents)} 个结果,去重后 {len(unique_documents)} 个结果") + return unique_documents + def get_retriever(self) -> ParentDocumentRetriever: """ - Get the ParentDocumentRetriever instance directly. + 直接获取 ParentDocumentRetriever 实例。 - Useful for advanced use cases where you want to access the retriever - outside of IndexBuilder. + 适用于需要在 IndexBuilder 外部访问检索器的高级用例。 """ if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( - "get_retriever() only available with PARENT_CHILD splitter. 
" - "Use search() or search_with_parent_context() for standard retrieval." + "get_retriever() 仅在 PARENT_CHILD 切分器下可用。" + "请使用 search() 或 search_with_parent_context() 进行标准检索。" ) + assert self.retriever is not None, "retriever 未初始化" return self.retriever - def get_child_splitter(self) -> "RecursiveCharacterTextSplitter": - """Get the child splitter for reconfiguration.""" + def get_child_splitter(self) -> Union["RecursiveCharacterTextSplitter", "SemanticChunker", "SemanticChunkerAdapter"]: + """获取子块切分器以便重新配置。""" if self.splitter_type != SplitterType.PARENT_CHILD: - return self.splitter + return self.splitter # type: ignore return self.child_splitter def get_parent_splitter(self) -> "RecursiveCharacterTextSplitter": - """Get the parent splitter for reconfiguration.""" + """获取父块切分器以便重新配置。""" if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( - "Parent splitter only available with PARENT_CHILD splitter." + "父块切分器仅在 PARENT_CHILD 切分器下可用。" ) return self.parent_splitter def get_docstore(self) -> BaseStore: - """Get the document store for parent chunks.""" + """获取父块的文档存储。""" if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( - "Docstore only available with PARENT_CHILD splitter." + "文档存储仅在 PARENT_CHILD 切分器下可用。" ) + assert self.docstore is not None, "docstore 未初始化" return self.docstore - def get_docstore_path(self) -> str: - """Get the document store path.""" + def get_docstore_path(self) -> Optional[str]: + """获取文档存储路径(已弃用,仅用于兼容性)。""" if self.splitter_type != SplitterType.PARENT_CHILD: raise RuntimeError( - "Docstore path only available with PARENT_CHILD splitter." + "文档存储路径仅在 PARENT_CHILD 切分器下可用。" ) - return self.docstore.persist_path + # PostgreSQL 存储没有 persist_path,返回 None + return None def close(self): - """Close resources.""" - if hasattr(self, "_docstore_conn") and self._docstore_conn: - import psycopg2 - conn = psycopg2.connect(self._docstore_conn) - conn.close() - logger.info("Closed PostgreSQL connection") + """关闭资源。""" + if self.docstore is not None and hasattr(self.docstore, "aclose"): + import asyncio + asyncio.get_event_loop().run_until_complete(self.docstore.aclose()) # type: ignore + logger.info("PostgreSQL 异步连接池已关闭") def __enter__(self): return self @@ -258,20 +385,8 @@ class IndexBuilder: return False -# RecursiveCharacterTextSplitter needs to be imported +# 需要导入 RecursiveCharacterTextSplitter from langchain_text_splitters import RecursiveCharacterTextSplitter -if __name__ == "__main__": - # Example usage - builder = IndexBuilder( - splitter_type=SplitterType.PARENT_CHILD, - parent_chunk_size=1000, - child_chunk_size=200, - docstore_path="./my_parent_docs", - ) - - print("Parent splitter:", builder.get_parent_splitter().chunk_size) - print("Child splitter:", builder.get_child_splitter().chunk_size) - print("Docstore path:", builder.get_docstore_path()) - print("Retriever:", builder.get_retriever()) +# 示例用法已移除,请参考文档 diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py index b506ae3..63014b8 100755 --- a/rag_indexer/cli.py +++ b/rag_indexer/cli.py @@ -3,100 +3,85 @@ Command-line interface for the RAG index builder. 
""" import argparse +import asyncio import logging import sys -from builder import IndexBuilder -from splitters import SplitterType +from rag_indexer.builder import IndexBuilder +from rag_indexer.splitters import SplitterType logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) +# 基础配置 +COLLECTION_NAME = "rag_documents" +DB_URI = "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable" -def main(): - parser = argparse.ArgumentParser(description="Offline RAG Index Builder") - parser.add_argument("--file", type=str, help="Path to file to index") - parser.add_argument("--dir", type=str, help="Path to directory to index") - parser.add_argument("--recursive", action="store_true", default=True, - help="Recursively process directories (default: True)") - parser.add_argument("--collection", type=str, default="rag_documents", - help="Qdrant collection name (default: rag_documents)") - parser.add_argument("--qdrant-url", type=str, - help="Qdrant server URL (default: http://127.0.0.1:6333)") - parser.add_argument("--splitter", type=str, - choices=["recursive", "semantic", "parent_child"], - default="recursive", - help="Text splitting strategy (default: recursive)") - parser.add_argument("--chunk-size", type=int, default=500, - help="Chunk size for recursive/parent splitter (default: 500)") - parser.add_argument("--chunk-overlap", type=int, default=50, - parser.add_argument("--docstore-path", type=str, - default=None, - help="Path to store parent documents for parent-child splitter (default: ./parent_docs or HERMES_HOME/parent_docs)") - parser.add_argument("--docstore-type", type=str, - choices=["local", "postgres"], - default="local", - help="Type of docstore: 'local' (default) or 'postgres' for PostgreSQL-backed storage") - parser.add_argument("--docstore-conn", type=str, - default=None, - help="PostgreSQL connection string for postgres docstore") +# 基础切分参数 +CHUNK_SIZE = 500 +CHUNK_OVERLAP = 50 - help="Chunk overlap (default: 50)") - parser.add_argument("--parent-size", type=int, default=1000, - help="Parent chunk size for parent-child splitter (default: 1000)") - parser.add_argument("--child-size", type=int, default=200, - help="Child chunk size for parent-child splitter (default: 200)") +# 父子块切分参数 +PARENT_CHUNK_SIZE = 1000 +CHILD_CHUNK_SIZE = 200 +PARENT_CHUNK_OVERLAP = 100 +CHILD_CHUNK_OVERLAP = 20 - args = parser.parse_args() +# 切分策略:basic(基础)、semantic(语义)、parent-child(父子块) +STRATEGY = "parent-child" - if not args.file and not args.dir: - print("Error: Either --file or --dir must be specified", file=sys.stderr) - parser.print_help() - sys.exit(1) +# 存储类型:postgres(PostgreSQL)、local(本地文件) +STORAGE_TYPE = "postgres" - splitter_map = { - "recursive": SplitterType.RECURSIVE, - "semantic": SplitterType.SEMANTIC, - "parent_child": SplitterType.PARENT_CHILD, - } - splitter_type = splitter_map[args.splitter] + +async def main(): + # 使用固定策略 + splitter_type = SplitterType.PARENT_CHILD + child_splitter_type = SplitterType.SEMANTIC splitter_kwargs = {} + if splitter_type == SplitterType.RECURSIVE: - splitter_kwargs["chunk_size"] = args.chunk_size - splitter_kwargs["chunk_overlap"] = args.chunk_overlap + splitter_kwargs["chunk_size"] = CHUNK_SIZE + splitter_kwargs["chunk_overlap"] = CHUNK_OVERLAP elif splitter_type == SplitterType.PARENT_CHILD: - splitter_kwargs["parent_chunk_size"] = args.parent_size - splitter_kwargs["child_chunk_size"] = args.child_size - splitter_kwargs["parent_chunk_overlap"] = args.chunk_overlap - 
splitter_kwargs["child_chunk_overlap"] = args.chunk_overlap // 2 - splitter_kwargs["docstore_path"] = args.docstore_path - splitter_kwargs["docstore_type"] = args.docstore_type - splitter_kwargs["docstore_conn_string"] = args.docstore_conn + splitter_kwargs["parent_chunk_size"] = PARENT_CHUNK_SIZE + splitter_kwargs["child_chunk_size"] = CHILD_CHUNK_SIZE + splitter_kwargs["parent_chunk_overlap"] = PARENT_CHUNK_OVERLAP + splitter_kwargs["child_chunk_overlap"] = CHILD_CHUNK_OVERLAP + splitter_kwargs["child_splitter_type"] = child_splitter_type + if STORAGE_TYPE == "postgres": + splitter_kwargs["docstore_conn_string"] = DB_URI + elif STORAGE_TYPE == "local": + splitter_kwargs["docstore_path"] = "./parent_docs" + else: + splitter_kwargs["docstore_conn_string"] = DB_URI builder = IndexBuilder( - collection_name=args.collection, - qdrant_url=args.qdrant_url, + collection_name=COLLECTION_NAME, splitter_type=splitter_type, **splitter_kwargs ) - try: - if args.file: - chunk_count = builder.build_from_file(args.file) - else: - chunk_count = builder.build_from_directory(args.dir, args.recursive) + is_file=False + path="data/corpus/" - print(f"Indexing completed. Total chunks indexed: {chunk_count}") + try: + if is_file: + chunk_count = await builder.build_from_file(path) + else: + chunk_count = await builder.build_from_directory(path, recursive=True) + + print(f"索引构建完成。共索引 {chunk_count} 个块") info = builder.get_collection_info() - print(f"Collection '{info['name']}' has {info['vectors_count']} vectors (dim={info['vector_size']})") + print(f"集合 '{info['name']}' 包含 {info['vectors_count']} 个向量(维度:{info['vector_size']})") except Exception as e: - logging.exception("Indexing failed") + logging.exception(f"索引构建失败:{e}") sys.exit(1) if __name__ == "__main__": - main() \ No newline at end of file + asyncio.run(main()) \ No newline at end of file diff --git a/rag_indexer/docstore_manager.py b/rag_indexer/docstore_manager.py deleted file mode 100644 index d2cce2c..0000000 --- a/rag_indexer/docstore_manager.py +++ /dev/null @@ -1,142 +0,0 @@ -""" -Document store manager for ParentDocumentRetriever. - -Supports both LocalFileStore (default) and custom PostgreSQL-backed stores. -""" - -import os -from typing import Optional -from langchain.storage import BaseStore, LocalFileStore - - -def get_docstore(persist_path: str = None) -> LocalFileStore: - """ - Create and return a document store for parent chunks. - - Args: - persist_path: Path to store parent documents. Defaults to ./parent_docs - or HERMES_HOME/parent_docs if set. - """ - if persist_path is None: - # Use HERMES_HOME if available, otherwise default to current directory - persist_path = os.getenv("HERMES_HOME") - if persist_path: - persist_path = os.path.join(persist_path, "parent_docs") - else: - persist_path = "./parent_docs" - - os.makedirs(persist_path, exist_ok=True) - return LocalFileStore(persist_path) - - -class PostgresDocStore(BaseStore): - """ - PostgreSQL-backed document store for parent chunks. - - This is an optional advanced feature. For most use cases, - LocalFileStore is sufficient and simpler. - """ - - def __init__(self, connection_string: str): - """ - Initialize PostgreSQL document store. 
- - Args: - connection_string: PostgreSQL connection URL - """ - import psycopg2 - from psycopg2 import sql - - self.conn_string = connection_string - self._conn = None - - # Create table if not exists - self._create_table() - - def _create_table(self): - """Create the parent documents table if not exists.""" - try: - self._conn = psycopg2.connect(self.conn_string) - cursor = self._conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS parent_documents ( - key TEXT PRIMARY KEY, - value JSONB NOT NULL, - created_at TIMESTAMPTZ DEFAULT NOW() - ) - """) - self._conn.commit() - cursor.close() - except Exception as e: - raise RuntimeError(f"Failed to create PostgreSQL table: {e}") - - def get(self, key: str) -> Optional[dict]: - """Retrieve a document by key.""" - try: - self._ensure_connection() - cursor = self._conn.cursor() - cursor.execute("SELECT value FROM parent_documents WHERE key = %s", (key,)) - row = cursor.fetchone() - cursor.close() - if row: - import json - return json.loads(row[0]) - return None - except Exception as e: - raise RuntimeError(f"Failed to retrieve document: {e}") - - def set(self, key: str, value: dict) -> None: - """Store a document.""" - try: - self._ensure_connection() - cursor = self._conn.cursor() - # Upsert - insert_query = sql.SQL( - "INSERT INTO parent_documents (key, value) VALUES (%s, %s)" - ) - update_query = sql.SQL( - "UPDATE parent_documents SET value = %s WHERE key = %s" - ) - cursor.execute(insert_query, (key, json.dumps(value))) - try: - cursor.execute(update_query, (key, json.dumps(value))) - except psycopg2.IntegrityError: - pass # Key exists, ignore - self._conn.commit() - cursor.close() - except Exception as e: - raise RuntimeError(f"Failed to store document: {e}") - - def _ensure_connection(self): - """Ensure we have an open connection.""" - if self._conn is None or self._conn.closed: - self._conn = psycopg2.connect(self.conn_string) - - def close(self): - """Close the connection.""" - if self._conn and not self._conn.closed: - self._conn.close() - - -# Factory function for creating custom docstores -# Returns a tuple: (BaseStore instance, connection_string or None) -def create_docstore( - store_type: str = "local", - persist_path: str = None, - connection_string: str = None -) -> tuple: - """ - Factory function to create different types of document stores. - - Args: - store_type: "local" (default), "postgres" - persist_path: Path for local file store - connection_string: PostgreSQL connection string - - Returns: - Tuple of (BaseStore instance, connection_string or None) - """ - if store_type == "postgres" and connection_string: - return (PostgresDocStore(connection_string), connection_string) - else: - return (get_docstore(persist_path), None) diff --git a/rag_indexer/embedders.py b/rag_indexer/embedders.py index 80e6adf..3f1be8a 100644 --- a/rag_indexer/embedders.py +++ b/rag_indexer/embedders.py @@ -1,16 +1,17 @@ """ -Embedding model wrapper for llama.cpp service. 
+嵌入模型包装器,用于 llama.cpp 服务。 """ import os +import httpx from typing import List, Optional from urllib.parse import urljoin -from langchain_openai import OpenAIEmbeddings +from langchain_core.embeddings import Embeddings class LlamaCppEmbedder: - """Wrapper for llama.cpp embedding service via OpenAI-compatible API.""" + """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" def __init__( self, @@ -22,47 +23,66 @@ class LlamaCppEmbedder: self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "") self.model = model - # Ensure URL ends with /v1 - self.base_url = urljoin(self.base_url.rstrip("/") + "/", "v1") - - def as_langchain_embeddings(self) -> OpenAIEmbeddings: - """Create LangChain OpenAIEmbeddings instance.""" - return OpenAIEmbeddings( - openai_api_base=self.base_url, - openai_api_key=self.api_key, - model=self.model, - ) + def as_langchain_embeddings(self) -> Embeddings: + """创建 LangChain 兼容的嵌入实例。""" + return _LlamaCppLangchainAdapter(self) def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Embed a list of documents.""" - emb = self.as_langchain_embeddings() - return emb.embed_documents(texts) + """嵌入一批文档。""" + return self._call_embedding_api(texts) def embed_query(self, text: str) -> List[float]: - """Embed a single query.""" - emb = self.as_langchain_embeddings() - return emb.embed_query(text) + """嵌入单个查询。""" + return self._call_embedding_api([text])[0] def get_embedding_dimension(self) -> int: - """Get embedding dimension by embedding a test string.""" + """通过嵌入测试字符串获取嵌入维度。""" test_embedding = self.embed_query("test") return len(test_embedding) + def _call_embedding_api(self, texts: List[str]) -> List[List[float]]: + """直接调用 llama.cpp 嵌入 API。""" + base = self.base_url.rstrip("/") + if not base.endswith("/v1"): + base = base + "/v1" -class MockEmbedder: - """Mock embedder for testing without a real service.""" + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" - def __init__(self, dimension: int = 768): - self.dimension = dimension + payload = { + "input": texts, + "model": self.model, + } - def as_langchain_embeddings(self) -> OpenAIEmbeddings: - raise NotImplementedError("MockEmbedder cannot be used as LangChain embeddings") + with httpx.Client(timeout=120) as client: + response = client.post( + f"{base}/embeddings", + headers=headers, + json=payload, + ) + response.raise_for_status() + data = response.json() + + # 处理不同响应格式 + if isinstance(data, list): + # llama.cpp 直接返回列表 + return [item["embedding"] for item in data] + elif isinstance(data, dict) and "data" in data: + # OpenAI 标准格式 + return [item["embedding"] for item in sorted(data["data"], key=lambda x: x["index"])] + else: + raise ValueError(f"未知的嵌入 API 响应格式: {data}") + + +class _LlamaCppLangchainAdapter(Embeddings): + """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" + + def __init__(self, embedder: LlamaCppEmbedder): + self._embedder = embedder def embed_documents(self, texts: List[str]) -> List[List[float]]: - return [[0.0] * self.dimension for _ in texts] + return self._embedder.embed_documents(texts) def embed_query(self, text: str) -> List[float]: - return [0.0] * self.dimension - - def get_embedding_dimension(self) -> int: - return self.dimension \ No newline at end of file + return self._embedder.embed_query(text) diff --git a/rag_indexer/example_parent_child.py b/rag_indexer/example_parent_child.py deleted file mode 100644 index 19db145..0000000 --- a/rag_indexer/example_parent_child.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -Example 
demonstrating ParentDocumentRetriever usage. - -This script shows how to: -1. Build an index with parent-child chunking -2. Search with child chunks (fast, precise) -3. Search with parent context (large context) -4. Access the retriever directly for advanced use cases -""" - -import logging -from pathlib import Path - -# Configure logging -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) - -from builder import IndexBuilder -from splitters import SplitterType - - -def main(): - print("=" * 70) - print("ParentDocumentRetriever Example") - print("=" * 70) - - # Step 1: Create IndexBuilder with parent-child splitting - print("\n1. Creating IndexBuilder with parent-child splitting...") - builder = IndexBuilder( - collection_name="parent_child_demo", - splitter_type=SplitterType.PARENT_CHILD, - parent_chunk_size=1000, # Parent chunks: larger context - child_chunk_size=200, # Child chunks: smaller for precision - docstore_path="./my_parent_docs", # Where to store parent chunks - search_k=5, # Number of child chunks to retrieve - ) - - print(f" Parent splitter: chunk_size={builder.get_parent_splitter().chunk_size}") - print(f" Child splitter: chunk_size={builder.get_child_splitter().chunk_size}") - print(f" Docstore path: {builder.get_docstore_path()}") - print(f" Search k: {builder.retriever.search_kwargs['k']}") - - # Step 2: Build index from a sample file - print("\n2. Building index from sample file...") - - # Create a test document - test_content = """ - This is a test document for demonstrating ParentDocumentRetriever. - - Parent chunks contain larger portions of text (1000 characters), - while child chunks are smaller (200 characters) for precise retrieval. - - When you search with ParentDocumentRetriever: - - It first retrieves relevant child chunks - - Then replaces them with their corresponding parent chunks - - This gives you large context while maintaining precision - - Example search queries: - - "ParentDocumentRetriever" - - "child chunks" - - "large context" - - "precise retrieval" - """ - - test_file = Path("./test_document.txt") - test_file.write_text(test_content) - - chunk_count = builder.build_from_file(str(test_file)) - print(f" Indexed {chunk_count} documents") - - # Step 3: Search with child chunks (fast, precise) - print("\n3. Searching with child chunks (fast, precise)...") - child_results = builder.search("ParentDocumentRetriever", k=3) - print(f" Found {len(child_results)} child chunks:") - for i, doc in enumerate(child_results, 1): - print(f" [{i}] {doc.page_content[:100]}...") - - # Step 4: Search with parent context (large context) - print("\n4. Searching with parent context (large context)...") - parent_results = builder.search_with_parent_context("ParentDocumentRetriever", k=3) - print(f" Found {len(parent_results)} parent chunks:") - for i, doc in enumerate(parent_results, 1): - print(f" [{i}] {doc.page_content[:150]}...") - - # Step 5: Compare results - print("\n5. Comparing child vs parent results...") - print(f" Child chunks total length: {sum(len(d.page_content) for d in child_results)}") - print(f" Parent chunks total length: {sum(len(d.page_content) for d in parent_results)}") - print(f" Ratio: parent/child = {sum(len(d.page_content) for d in parent_results) / max(sum(len(d.page_content) for d in child_results), 1):.2f}x larger") - - # Step 6: Access retriever directly - print("\n6. 
Accessing retriever directly...") - retriever = builder.get_retriever() - print(f" Retriever type: {type(retriever).__name__}") - print(f" Vectorstore: {retriever.vectorstore}") - print(f" Docstore: {retriever.docstore}") - - # Step 7: Unified retrieval interface - print("\n7. Using unified retrieval interface...") - unified_results = builder.retrieve("ParentDocumentRetriever", return_parent=True) - print(f" Retrieved {len(unified_results)} documents (with parent context)") - - # Step 8: Collection info - print("\n8. Collection info...") - info = builder.get_collection_info() - print(f" Collection: {info['name']}") - print(f" Vectors: {info['vectors_count']}") - print(f" Vector size: {info['vector_size']}") - - # Cleanup - print("\n9. Cleaning up...") - builder.close() - - print("\n" + "=" * 70) - print("Example completed successfully!") - print("=" * 70) - - return builder - - -if __name__ == "__main__": - builder = main() diff --git a/rag_indexer/loaders.py b/rag_indexer/loaders.py index b896015..d0c16c4 100644 --- a/rag_indexer/loaders.py +++ b/rag_indexer/loaders.py @@ -1,10 +1,10 @@ """ -Document loaders using unstructured library. +文档加载器,使用 unstructured 库解析文档。 """ import logging from pathlib import Path -from typing import List, Union +from typing import Any, Dict, List, Mapping, Optional, Union from langchain_core.documents import Document from unstructured.partition.auto import partition @@ -13,33 +13,74 @@ logger = logging.getLogger(__name__) class DocumentLoader: - """Load documents from various file formats.""" + """从各种文件格式加载文档。""" - SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx"} + SUPPORTED_EXTENSIONS = {".pdf", ".docx", ".doc", ".txt", ".md", ".html", ".pptx", ".xlsx", ".json"} - def __init__(self, extract_images: bool = False): + def __init__( + self, + extract_images: bool = False, + strategy: str = "auto", + ocr_languages: Optional[List[str]] = None, + languages: Optional[List[str]] = None, + include_page_breaks: bool = False, + pdf_infer_table_structure: bool = True, + partition_kwargs: Optional[Dict[str, Any]] = None, + ): """ Args: - extract_images: Whether to extract images from documents (requires additional dependencies) + extract_images: 是否提取 PDF 中的图片 + strategy: 解析策略 (auto, fast, hi_res, ocr_only) + ocr_languages: OCR 语言列表,如 ['chi_sim', 'eng'] + languages: 文档主语言,如 ['zh'] + include_page_breaks: 是否包含分页符 + pdf_infer_table_structure: 是否识别表格结构 (需 hi_res 策略) + partition_kwargs: 额外的 partition 参数字典(高级定制) """ + import os + os.environ["UNSTRUCTURED_LANGUAGE_CHECKS"] = "false" self.extract_images = extract_images + self.strategy = strategy + self.ocr_languages = ocr_languages or ["chi_sim", "eng"] + self.languages = languages or ["zh"] + self.include_page_breaks = include_page_breaks + self.pdf_infer_table_structure = pdf_infer_table_structure + self.partition_kwargs = partition_kwargs or {} def load_file(self, file_path: Union[str, Path]) -> List[Document]: - """Load a single file into LangChain Document objects.""" + """将单个文件加载为 LangChain Document 对象。""" file_path = Path(file_path).resolve() if not file_path.exists(): - raise FileNotFoundError(f"File not found: {file_path}") + raise FileNotFoundError(f"文件不存在: {file_path}") suffix = file_path.suffix.lower() if suffix not in self.SUPPORTED_EXTENSIONS: raise ValueError( - f"Unsupported file extension: {suffix}. 
Supported: {self.SUPPORTED_EXTENSIONS}" + f"不支持的文件扩展名: {suffix}。支持的格式: {self.SUPPORTED_EXTENSIONS}" ) - # Parse with unstructured + # 根据文件类型动态调整参数 + extra_kwargs = {} + if suffix == ".pdf": + extra_kwargs["strategy"] = self.strategy + extra_kwargs["ocr_languages"] = self.ocr_languages + extra_kwargs["extract_images_in_pdf"] = self.extract_images + extra_kwargs["pdf_infer_table_structure"] = self.pdf_infer_table_structure + + # languages 参数适用于所有文件类型 + if self.languages: + extra_kwargs["languages"] = self.languages + + extra_kwargs["include_page_breaks"] = self.include_page_breaks + + # 合并用户自定义的额外参数(优先级最高) + extra_kwargs.update(self.partition_kwargs) + + # 使用 unstructured 解析 elements = partition( filename=str(file_path), - extract_images_in_pdf=self.extract_images, + + **extra_kwargs ) documents = [] @@ -48,23 +89,17 @@ class DocumentLoader: if not text or not text.strip(): continue - # Base metadata + # 基础元数据 metadata = { "source": str(file_path), "file_name": file_path.name, "file_type": suffix, } - - # Merge element-specific metadata without overwriting base fields - elem_meta = getattr(elem, "metadata", {}) or {} - for key, value in elem_meta.items(): - if value and key not in metadata: - metadata[key] = value - + documents.append(Document(page_content=text, metadata=metadata)) if not documents: - logger.warning("No text content extracted from %s", file_path) + logger.warning("未从 %s 提取到文本内容", file_path) return [] return documents @@ -72,10 +107,10 @@ class DocumentLoader: def load_directory( self, directory_path: Union[str, Path], recursive: bool = True ) -> List[Document]: - """Load all supported files from a directory.""" + """从目录加载所有支持的文件。""" directory_path = Path(directory_path).resolve() if not directory_path.is_dir(): - raise NotADirectoryError(f"Not a directory: {directory_path}") + raise NotADirectoryError(f"不是目录: {directory_path}") all_documents = [] pattern = "**/*" if recursive else "*" @@ -86,6 +121,6 @@ class DocumentLoader: docs = self.load_file(file_path) all_documents.extend(docs) except Exception as e: - logger.error("Failed to load %s: %s", file_path, e) + logger.error("加载 %s 失败: %s", file_path, e) return all_documents \ No newline at end of file diff --git a/rag_indexer/splitters.py b/rag_indexer/splitters.py index 718f969..45874d3 100644 --- a/rag_indexer/splitters.py +++ b/rag_indexer/splitters.py @@ -1,12 +1,12 @@ """ -Text splitters for chunking documents. 
diff --git a/rag_indexer/splitters.py b/rag_indexer/splitters.py
index 718f969..45874d3 100644
--- a/rag_indexer/splitters.py
+++ b/rag_indexer/splitters.py
@@ -1,12 +1,12 @@
 """
-Text splitters for chunking documents.
+文本切分器,用于将文档切分成块。
 """
 
 from enum import Enum
 from typing import List, Optional
 
 from langchain_core.documents import Document
-from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
 from langchain_experimental.text_splitter import SemanticChunker
 
 
@@ -17,7 +17,7 @@ class SplitterType(str, Enum):
 
 
 def get_splitter(splitter_type: SplitterType, **kwargs):
-    """Factory function to create a text splitter."""
+    """工厂函数,创建文本切分器。"""
     if splitter_type == SplitterType.RECURSIVE:
         chunk_size = kwargs.get("chunk_size", 500)
         chunk_overlap = kwargs.get("chunk_overlap", 50)
@@ -27,19 +27,31 @@ def get_splitter(splitter_type: SplitterType, **kwargs):
             separators=["\n\n", "\n", "。", "!", "?", " ", ""],
         )
     elif splitter_type == SplitterType.SEMANTIC:
-        # Requires embeddings for semantic splitting
-        embeddings = kwargs.get("embeddings")
+        embeddings = kwargs.pop("embeddings", None)
         if embeddings is None:
-            raise ValueError("Semantic splitter requires 'embeddings' parameter")
-        return SemanticChunker(embeddings=embeddings)
+            raise ValueError("语义切分器需要提供 'embeddings' 参数")
+        return SemanticChunkerAdapter(embeddings=embeddings, **kwargs)
     else:
-        raise ValueError(f"Unsupported splitter type: {splitter_type}")
+        raise ValueError(f"不支持的切分器类型: {splitter_type}")
+
+
+class SemanticChunkerAdapter(TextSplitter):
+    """将 SemanticChunker 适配为 TextSplitter 接口。"""
+
+    def __init__(self, embeddings, **kwargs):
+        # 先取出 TextSplitter 自身的参数,再把剩余 kwargs 透传给 SemanticChunker,
+        # 避免 SemanticChunker 专属参数传入 TextSplitter.__init__ 时报错
+        chunk_size = kwargs.pop("chunk_size", 4000)
+        chunk_overlap = kwargs.pop("chunk_overlap", 200)
+        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        self._chunker = SemanticChunker(embeddings=embeddings, **kwargs)
+
+    def split_text(self, text: str) -> List[str]:
+        return self._chunker.split_text(text)
 
 
 class ParentChildSplitter:
     """
-    Splits documents into parent (large) and child (small) chunks.
-    Child chunks are indexed for retrieval, parent chunks are stored for context.
+    将文档切分为父块(大块)和子块(小块)。
+    子块用于索引检索,父块用于存储上下文。
     """
 
     def __init__(
@@ -60,12 +72,12 @@ class ParentChildSplitter:
 
     def split_documents(self, documents: List[Document]) -> tuple[List[Document], List[Document]]:
         """
-        Returns:
-            (parent_chunks, child_chunks)
+        返回:
+            (父块列表, 子块列表)
         """
         parent_chunks = self.parent_splitter.split_documents(documents)
         child_chunks = self.child_splitter.split_documents(documents)
 
-        # Link child chunks to parent IDs (optional metadata)
-        # In a real implementation, you'd map each child to a parent chunk ID.
+        # 注意:此遗留实现并未真正建立子块到父块的映射,
+        # 实际的父子映射由 builder 中的 ParentDocumentRetriever 完成。
         return parent_chunks, child_chunks
\ No newline at end of file
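
get_splitter 工厂的最小调用示意(语义切分需要真实的 embedding 服务,此处假设本地 llama.cpp 服务可用):

```python
from langchain_core.documents import Document

from rag_indexer.embedders import LlamaCppEmbedder
from rag_indexer.splitters import SplitterType, get_splitter

docs = [Document(page_content="第一段。\n\n第二段。\n\n第三段。")]

# 递归切分:固定 500 字符、50 重叠
recursive = get_splitter(SplitterType.RECURSIVE, chunk_size=500, chunk_overlap=50)
print(len(recursive.split_documents(docs)))

# 语义切分:必须传入 embeddings,断点由句间相似度动态决定
emb = LlamaCppEmbedder().as_langchain_embeddings()
semantic = get_splitter(SplitterType.SEMANTIC, embeddings=emb)
print(len(semantic.split_documents(docs)))
```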
diff --git a/rag_indexer/store/__init__.py b/rag_indexer/store/__init__.py
new file mode 100644
index 0000000..a1e561e
--- /dev/null
+++ b/rag_indexer/store/__init__.py
@@ -0,0 +1,31 @@
+"""
+文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。
+
+提供 PostgreSQL 存储后端:
+- PostgresDocStore: PostgreSQL 数据库存储(生产环境)
+
+示例用法:
+    >>> from rag_indexer.store import create_docstore
+    >>>
+    >>> # 创建 PostgreSQL 存储
+    >>> store, conn = create_docstore(
+    ...     connection_string="postgresql://user:pass@host:5432/db",
+    ...     table_name="parent_docs"
+    ... )
+"""
+
+from .postgres import PostgresDocStore
+from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
+
+__version__ = "2.0.0"
+
+__all__ = [
+    # 具体实现
+    "PostgresDocStore",
+
+    # 工厂函数
+    "create_docstore",
+    "get_docstore_uri",
+    "DEFAULT_DB_URI",
+]
diff --git a/rag_indexer/store/factory.py b/rag_indexer/store/factory.py
new file mode 100644
index 0000000..2388f8f
--- /dev/null
+++ b/rag_indexer/store/factory.py
@@ -0,0 +1,73 @@
+"""
+文档存储工厂 - 创建存储实例。
+
+提供统一的接口来创建 PostgreSQL 存储,便于后续扩展其他后端。
+"""
+
+import os
+import logging
+from typing import Optional, Tuple
+
+from langchain_core.stores import BaseStore
+from .postgres import PostgresDocStore
+
+logger = logging.getLogger(__name__)
+
+# 默认连接字符串(从环境变量读取)
+DEFAULT_DB_URI = os.getenv(
+    "DB_URI",
+    "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
+)
+
+
+def get_docstore_uri() -> str:
+    """获取 docstore 专用的数据库连接字符串(可与主库相同)。"""
+    return os.getenv("DOCSTORE_URI", DEFAULT_DB_URI)
+
+
+def create_docstore(
+    store_type: str = "postgres",
+    connection_string: Optional[str] = None,
+    table_name: str = "parent_documents",
+    pool_config: Optional[dict] = None,
+    max_concurrency: Optional[int] = None
+) -> Tuple[BaseStore, Optional[str]]:
+    """
+    工厂函数,创建 PostgreSQL 文档存储。
+
+    Args:
+        store_type: 存储类型,目前仅支持 "postgres"(默认)
+        connection_string: PostgreSQL 连接字符串
+        table_name: PostgreSQL 表名(默认:parent_documents)
+        pool_config: 连接池配置
+        max_concurrency: 最大并发操作数,如果为 None 则不限制
+
+    Returns:
+        元组 (存储实例, 连接字符串)
+
+    Raises:
+        ValueError: 不支持的存储类型
+
+    Example:
+        >>> # 创建 PostgreSQL 存储
+        >>> store, conn = create_docstore(
+        ...     connection_string="postgresql://user:pass@host:5432/db",
+        ...     table_name="parent_docs",
+        ...     max_concurrency=10
+        ... )
+    """
+    store_type = store_type.lower()
+
+    if store_type == "postgres":
+        conn_str = connection_string or get_docstore_uri()
+        store = PostgresDocStore(
+            connection_string=conn_str,
+            table_name=table_name,
+            pool_config=pool_config,
+            max_concurrency=max_concurrency
+        )
+        return store, conn_str
+
+    else:
+        raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
\ No newline at end of file
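
PostgresDocStore 的读写接口均为异步方法,下面是一个最小的读写示意(连接串为占位值,省略时从 DOCSTORE_URI / DB_URI 环境变量读取):

```python
import asyncio

from langchain_core.documents import Document
from rag_indexer.store import create_docstore

async def demo() -> None:
    store, conn = create_docstore(
        connection_string="postgresql://user:pass@localhost:5432/db",  # 占位值
        table_name="parent_docs",
        max_concurrency=10,
    )
    await store.amset([("doc-1", Document(page_content="父块内容"))])
    docs = await store.amget(["doc-1"])
    print(docs[0].page_content)  # -> 父块内容
    await store.aclose()

asyncio.run(demo())
```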
diff --git a/rag_indexer/store/postgres.py b/rag_indexer/store/postgres.py
new file mode 100644
index 0000000..69ef4e3
--- /dev/null
+++ b/rag_indexer/store/postgres.py
@@ -0,0 +1,245 @@
+"""
+Async PostgreSQL store implementation - for production use.
+
+Implements a truly asynchronous PostgreSQL document store on top of asyncpg,
+supporting highly concurrent access.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+from typing import List, Dict, Any, Optional, AsyncIterator, Iterator, Tuple, Sequence
+
+from langchain_core.documents import Document
+from langchain_core.stores import BaseStore
+
+import asyncpg
+
+logger = logging.getLogger(__name__)
+
+
+class PostgresDocStore(BaseStore[str, Any]):
+    """
+    Async PostgreSQL document store.
+
+    Uses asyncpg as the async PostgreSQL client and supports:
+    - Truly asynchronous operations
+    - Connection pool management
+    - Automatic table creation
+    - Batch operations (amget/amset/amdelete)
+    - JSONB value storage
+    - Concurrency control
+
+    Intended for production use; provides high-performance async persistence.
+
+    Attributes:
+        dsn: PostgreSQL connection string
+        table_name: Storage table name, defaults to "parent_documents"
+        _pool: asyncpg connection pool instance
+        _semaphore: Optional semaphore bounding concurrency
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        table_name: str = "parent_documents",
+        pool_config: Optional[Dict[str, Any]] = None,
+        max_concurrency: Optional[int] = None
+    ):
+        """
+        Initialize the async PostgreSQL document store.
+
+        Args:
+            connection_string: PostgreSQL connection URL in the form
+                "postgresql://user:password@host:port/database?sslmode=disable"
+            table_name: Storage table name, defaults to "parent_documents"
+            pool_config: Pool configuration dict containing:
+                - min_size: minimum pool size (default 2)
+                - max_size: maximum pool size (default 10)
+            max_concurrency: Maximum concurrent operations; None means unlimited
+
+        Example:
+            >>> store = PostgresDocStore(
+            ...     "postgresql://user:pass@localhost:5432/mydb",
+            ...     table_name="parent_docs",
+            ...     pool_config={"min_size": 5, "max_size": 20},
+            ...     max_concurrency=10
+            ... )
+        """
+        self.dsn = connection_string
+        self.table_name = table_name
+        self._pool: Optional[asyncpg.Pool] = None
+        self._pool_config = pool_config or {}
+
+        # Concurrency-limiting semaphore
+        self._semaphore = None
+        if max_concurrency is not None and max_concurrency > 0:
+            self._semaphore = asyncio.Semaphore(max_concurrency)
+
+        # Note: async pool initialization is deferred until first use,
+        # as is creation of the table schema.
+
+    async def _get_pool(self) -> asyncpg.Pool:
+        """Get or lazily create the asyncpg connection pool."""
+        if self._pool is None:
+            min_size = self._pool_config.get("min_size", 2)
+            max_size = self._pool_config.get("max_size", 10)
+
+            try:
+                self._pool = await asyncpg.create_pool(
+                    dsn=self.dsn,
+                    min_size=min_size,
+                    max_size=max_size
+                )
+                logger.info(f"PostgreSQL async connection pool created: {self.table_name}")
+
+                # Initialize the table schema
+                await self._create_table()
+            except Exception as e:
+                raise RuntimeError(f"Failed to create PostgreSQL async connection pool: {e}") from e
+
+        return self._pool
+
+    async def _create_table(self):
+        """Create the storage table if it does not exist."""
+        # Called from _get_pool, after self._pool has been assigned.
+        async with self._pool.acquire() as conn:
+            async with conn.transaction():
+                await conn.execute(f"""
+                    CREATE TABLE IF NOT EXISTS {self.table_name} (
+                        key TEXT PRIMARY KEY,
+                        value JSONB NOT NULL,
+                        created_at TIMESTAMPTZ DEFAULT NOW()
+                    )
+                """)
+        logger.info(f"Table {self.table_name} is ready")
+
+    async def _with_concurrency_control(self, coro):
+        """Await a coroutine under the semaphore, if one is configured."""
+        if self._semaphore is None:
+            return await coro
+        async with self._semaphore:
+            return await coro
+
+    # --- Sync methods (kept for BaseStore compatibility, not supported) ---
+
+    def mget(self, keys: Sequence[str]) -> List[Optional[Any]]:
+        """Synchronous access is not supported; use the async amget method."""
+        raise NotImplementedError("Synchronous access is not supported; use the async amget method.")
+
+    def mset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
+        """Synchronous access is not supported; use the async amset method."""
+        raise NotImplementedError("Synchronous access is not supported; use the async amset method.")
+
+    def mdelete(self, keys: Sequence[str]) -> None:
+        """Synchronous access is not supported; use the async amdelete method."""
+        raise NotImplementedError("Synchronous access is not supported; use the async amdelete method.")
+
+    def yield_keys(self, *, prefix: str | None = None) -> Iterator[str]:
+        """Synchronous access is not supported; use the async ayield_keys method."""
+        raise NotImplementedError("Synchronous access is not supported; use the async ayield_keys method.")
+
+    # --- Async methods (the real implementation) ---
+
+    async def amget(self, keys: Sequence[str]) -> List[Optional[Any]]:
+        """Asynchronously fetch documents in batch."""
+        if not keys:
+            return []
+
+        async def _amget():
+            pool = await self._get_pool()
+            async with pool.acquire() as conn:
+                rows = await conn.fetch(
+                    f"SELECT key, value FROM {self.table_name} WHERE key = ANY($1)",
+                    list(keys)
+                )
+            result_map = {}
+            for row in rows:
+                val = row['value']
+                # asyncpg returns JSONB as text unless a codec is registered
+                if isinstance(val, str):
+                    val = json.loads(val)
+                if isinstance(val, dict) and 'page_content' in val:
+                    result_map[row['key']] = Document(**val)
+                else:
+                    result_map[row['key']] = val
+            return [result_map.get(key) for key in keys]
+
+        return await self._with_concurrency_control(_amget())
+
+    async def amset(self, key_value_pairs: Sequence[Tuple[str, Any]]) -> None:
+        """Asynchronously upsert documents in batch."""
+        if not key_value_pairs:
+            return
+
+        async def _amset():
+            pool = await self._get_pool()
+            async with pool.acquire() as conn:
+                async with conn.transaction():
+                    await conn.executemany(
+                        f"""
+                        INSERT INTO {self.table_name} (key, value)
+                        VALUES ($1, $2)
+                        ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value
+                        """,
+                        [
+                            (k, json.dumps(v.dict() if isinstance(v, Document) else v, ensure_ascii=False))
+                            for k, v in key_value_pairs
+                        ]
+                    )
+            logger.debug(f"Batch-upserted {len(key_value_pairs)} documents asynchronously")
+
+        await self._with_concurrency_control(_amset())
+
+    async def amdelete(self, keys: Sequence[str]) -> None:
+        """Asynchronously delete documents in batch."""
+        if not keys:
+            return
+
+        async def _amdelete():
+            pool = await self._get_pool()
+            async with pool.acquire() as conn:
+                async with conn.transaction():
+                    await conn.execute(
+                        f"DELETE FROM {self.table_name} WHERE key = ANY($1)",
+                        list(keys)
+                    )
+            logger.debug(f"Batch-deleted {len(keys)} documents asynchronously")
+
+        await self._with_concurrency_control(_amdelete())
+
+    async def ayield_keys(self, *, prefix: str | None = None) -> AsyncIterator[str]:
+        """Asynchronously iterate over all keys.
+
+        Note: this is an async generator and must be consumed with `async for`.
+        """
+        pool = await self._get_pool()
+        async with pool.acquire() as conn:
+            if prefix:
+                rows = await conn.fetch(
+                    f"SELECT key FROM {self.table_name} WHERE key LIKE $1 ORDER BY key",
+                    f"{prefix}%"
+                )
+            else:
+                rows = await conn.fetch(
+                    f"SELECT key FROM {self.table_name} ORDER BY key"
+                )
+
+        for row in rows:
+            yield row['key']
+
+    async def aclose(self) -> None:
+        """Asynchronously close the connection pool and release resources."""
+        if self._pool:
+            await self._pool.close()
+            self._pool = None
+            logger.info("PostgreSQL async connection pool closed")
+
+    def close(self) -> None:
+        """Synchronous close (no-op).
+
+        Note: in an async environment, use aclose instead.
+        """
+        pass
\ No newline at end of file
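An end-to-end sketch of the async API above (placeholder DSN; the pool and table are created lazily on the first operation):

```python
import asyncio

from langchain_core.documents import Document
from rag_indexer.store import PostgresDocStore


async def main() -> None:
    store = PostgresDocStore(
        "postgresql://user:pass@localhost:5432/db",  # placeholder DSN
        table_name="parent_documents",
        max_concurrency=10,
    )
    # The first call creates both the pool and the table.
    await store.amset([("doc-1", Document(page_content="parent chunk text"))])
    print(await store.amget(["doc-1"]))
    async for key in store.ayield_keys(prefix="doc-"):
        print(key)
    await store.amdelete(["doc-1"])
    await store.aclose()


asyncio.run(main())
```
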
diff --git a/rag_indexer/vector_store.py b/rag_indexer/vector_store.py
index 87bd6bf..4f2e5c6 100644
--- a/rag_indexer/vector_store.py
+++ b/rag_indexer/vector_store.py
@@ -16,67 +16,85 @@
 
 logger = logging.getLogger(__name__)
 
+QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+
 
 class QdrantVectorStore:
     """Wrapper for Qdrant vector database operations."""
 
     def __init__(
         self,
         collection_name: str,
         embeddings: Optional[Any] = None,
-        qdrant_url: Optional[str] = None,
-        api_key: Optional[str] = None,
     ):
         self.collection_name = collection_name
-        self.qdrant_url = qdrant_url or os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
-        self.api_key = api_key
+        self._client: Optional[QdrantClient] = None
 
         # Embeddings
         if embeddings is None:
             embedder = LlamaCppEmbedder()
             self.embeddings = embedder.as_langchain_embeddings()
         else:
             self.embeddings = embeddings
 
-        # Qdrant client
-        self.client = QdrantClient(url=self.qdrant_url, api_key=self.api_key)
+        # Create the collection up front
+        self.create_collection()
 
         # LangChain vector store
         self.vector_store = LangchainQdrantVS(
-            client=self.client,
+            client=self.get_client(),
             collection_name=self.collection_name,
-            embeddings=self.embeddings,
+            embedding=self.embeddings,
         )
 
+    def get_client(self) -> QdrantClient:
+        """Lazily create the client so each access returns a usable connection."""
+        if self._client is None:
+            self._client = QdrantClient(
+                url=QDRANT_URL,
+                api_key=QDRANT_API_KEY,
+                timeout=120,
+                http2=False,
+            )
+        return self._client
+
+    def refresh_client(self):
+        """Close the old connection; the next access creates a new one."""
+        if self._client is not None:
+            self._client.close()
+            self._client = None
+
     def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
         """Create collection with appropriate vector size."""
         if vector_size is None:
             embedder = LlamaCppEmbedder()
             vector_size = embedder.get_embedding_dimension()
 
-        collections = self.client.get_collections().collections
+        client = self.get_client()
+        collections = client.get_collections().collections
         exists = any(c.name == self.collection_name for c in collections)
 
         if exists and force_recreate:
-            self.client.delete_collection(self.collection_name)
+            client.delete_collection(self.collection_name)
             exists = False
 
         if not exists:
-            self.client.create_collection(
+            client.create_collection(
                 collection_name=self.collection_name,
                 vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
             )
             logger.info("Collection '%s' created (dim=%d)", self.collection_name, vector_size)
         else:
             logger.info("Collection '%s' already exists", self.collection_name)
 
     def add_documents(self, documents: List[Document], batch_size: int = 100):
         """Add documents to vector store."""
         if not documents:
             return []
 
         self.create_collection()
         ids = self.vector_store.add_documents(documents, batch_size=batch_size)
         logger.info("Added %d documents to '%s'", len(ids), self.collection_name)
         return ids
 
     def similarity_search(self, query: str, k: int = 5) -> List[Document]:
@@ -86,16 +104,21 @@ class QdrantVectorStore:
         return self.vector_store.similarity_search_with_score(query, k=k)
 
     def delete_collection(self):
-        self.client.delete_collection(self.collection_name)
+        self.get_client().delete_collection(self.collection_name)
         logger.info("Collection '%s' deleted", self.collection_name)
 
     def get_collection_info(self) -> Dict[str, Any]:
-        info = self.client.get_collection(self.collection_name)
+        info = self.get_client().get_collection(self.collection_name)
+        vectors_config = info.config.params.vectors
+        # Named-vector collections return a dict of configs; take the first entry.
+        if isinstance(vectors_config, dict):
+            vector_size = next(iter(vectors_config.values())).size
+        else:
+            vector_size = vectors_config.size
         return {
-            "name": info.name,
-            "vectors_count": info.vectors_count,
+            "name": self.collection_name,
+            "vectors_count": info.points_count or 0,
             "status": info.status,
-            "vector_size": info.config.params.vectors.size,
+            "vector_size": vector_size,
         }
 
     def as_langchain_vectorstore(self):
@@ -107,4 +130,4 @@ class QdrantVectorStore:
 
     def get_qdrant_client(self):
-        """返回原生 Qdrant 客户端(如需手动管理 collection)"""
-        return self.client
\ No newline at end of file
+        """Return the native Qdrant client (for manual collection management)."""
+        return self.get_client()
\ No newline at end of file
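A usage sketch for the wrapper (assumes a Qdrant instance reachable via `QDRANT_URL`; omitting `embeddings` falls back to `LlamaCppEmbedder`, which requires a configured local model):

```python
from langchain_core.documents import Document
from rag_indexer.vector_store import QdrantVectorStore

# QDRANT_URL / QDRANT_API_KEY are read from the environment at import time.
store = QdrantVectorStore(collection_name="rag_documents")

# The collection is created in __init__; add_documents re-checks idempotently.
store.add_documents([Document(page_content="示例文档", metadata={"source": "demo"})])
hits = store.similarity_search("示例", k=5)
print(store.get_collection_info())
```
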
diff --git a/requirement.txt b/requirement.txt
index 699f33b..874a5b7 100644
--- a/requirement.txt
+++ b/requirement.txt
@@ -49,7 +49,8 @@
 python-dotenv==1.2.2
 typing-extensions==4.15.0
 unstructured>=0.0.1
-
+spacy>=3.7.0
+langchain_experimental>=0.0.1
 # ============================================================================
 # Note:
 # 1. This file pins the exact versions of the project's direct dependencies
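A quick import check for the two new pins (a sketch; assumes `pip install -r requirement.txt` has run and the model matching `SPACY_MODEL` has been downloaded, as the Dockerfile does):

```python
import spacy
from langchain_experimental.text_splitter import SemanticChunker  # noqa: F401

# Load whichever model SPACY_MODEL points at (zh_core_web_sm or en_core_web_sm).
nlp = spacy.load("zh_core_web_sm")
print(nlp.pipe_names)
```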