diff --git a/app/__init__.py b/app/__init__.py index ca80779..0df91f3 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -3,6 +3,6 @@ AI Agent 应用模块 """ from .agent import AIAgentService -from .tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME __all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"] diff --git a/app/agent.py b/app/agent.py index fe191bc..29b91b5 100644 --- a/app/agent.py +++ b/app/agent.py @@ -31,7 +31,7 @@ except ImportError: # 本地模块 from app.graph_builder import GraphBuilder, GraphContext -from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME try: from app.rag import RAGPipeline from app.rag.tools import RAGTool diff --git a/app/tools.py b/app/graph_tools.py similarity index 100% rename from app/tools.py rename to app/graph_tools.py diff --git a/app/rag/__init__.py b/app/rag/__init__.py index 623bb8f..8b4868f 100644 --- a/app/rag/__init__.py +++ b/app/rag/__init__.py @@ -2,52 +2,69 @@ RAG 检索与生成模块 提供在线检索与生成功能,包括: -- 基础向量检索 -- 重排序 -- RAG-Fusion -- Agentic RAG +- 基础向量检索(稠密向量 / 混合检索) +- 重排序(Cross-Encoder) +- 多路查询改写(Multi-Query) +- RRF 融合(Reciprocal Rank Fusion) +- 完整的 RAG 流水线 +- Agent 工具封装 + +固定流水线: + 用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 示例用法: - >>> from app.rag import RAGPipeline, search_knowledge_base - >>> from rag_core import LlamaCppEmbedder - >>> - >>> embeddings = LlamaCppEmbedder() - >>> pipeline = RAGPipeline(embeddings=embeddings) - >>> - >>> documents = pipeline.retrieve("戏耍貂蝉美女") - >>> context = pipeline.format_context(documents) + >>> from app.rag import RAGPipeline, create_rag_tool + >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig + >>> from langchain_openai import ChatOpenAI + >>> + >>> # 获取基础检索器(如父子块检索器) + >>> config = IndexBuilderConfig(collection_name="my_docs") + >>> builder = IndexBuilder(config) + >>> retriever = builder.retriever + >>> + >>> # 创建 LLM 和流水线 + >>> llm = ChatOpenAI(model="gpt-3.5-turbo") + >>> pipeline = RAGPipeline(retriever=retriever, llm=llm) + >>> + >>> # 检索 + >>> docs = await pipeline.aretrieve("什么是 RAG?") + >>> context = pipeline.format_context(docs) + >>> + >>> # 创建 Agent 工具 + >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm) """ from .retriever import ( create_base_retriever, create_hybrid_retriever, - # create_ensemble_retriever, create_qdrant_client, ) from .reranker import CrossEncoderReranker -from .query_transform import MultiQueryTransformer -from .pipeline import RAGPipeline, RAGLevel -from .tools import search_knowledge_base, search_knowledge_base_sync +from .query_transform import MultiQueryGenerator +from .fusion import reciprocal_rank_fusion +from .pipeline import RAGPipeline +from .tools import create_rag_tool, create_rag_tool_sync __all__ = [ - # 检索器 + # 检索器工厂函数 "create_base_retriever", "create_hybrid_retriever", - # "create_ensemble_retriever", "create_qdrant_client", # 重排序器 "CrossEncoderReranker", - # 查询转换器 - "MultiQueryTransformer", + # 查询改写生成器 + "MultiQueryGenerator", - # 流水线 + # 融合算法 + "reciprocal_rank_fusion", + + # 主流水线 "RAGPipeline", - "RAGLevel", - # 工具 - "search_knowledge_base", - "search_knowledge_base_sync", -] + # 工具创建(供 Agent 使用) + "create_rag_tool", + "create_rag_tool_sync", +] \ No newline at end of file diff --git a/app/rag/example.py b/app/rag/example.py deleted file mode 100644 index 8042b09..0000000 --- a/app/rag/example.py +++ /dev/null @@ -1,203 +0,0 @@ -#!/usr/bin/env python3 -""" -RAG 系统使用示例 - -演示如何使用 app/rag 模块进行知识检索。 -""" - -import sys -import os -from dotenv import load_dotenv - -# 加载环境变量 -load_dotenv() - -# 添加项目根目录到路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from langchain_openai import OpenAIEmbeddings -from langchain_community.llms import VLLMOpenAI - - -def setup_environment(): - """设置环境变量""" - # 设置 Qdrant 连接信息(根据实际情况修改) - os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333") - # 设置 Qdrant API 密钥(根据实际情况修改) - os.environ.setdefault("QDRANT_API_KEY", "your-api-key-here") - # 如果需要 API 密钥,请设置 QDRANT_API_KEY - - print("环境变量已设置") - print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}") - print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else '未设置'}") - - -def demonstrate_basic_rag(): - """演示基础 RAG 功能""" - print("\n" + "="*60) - print("演示: 基础 RAG 检索 (Level 1)") - print("="*60) - - # 创建嵌入模型(使用本地 LlamaCpp 模型) - from rag_core import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() - - # 创建 RAG 流水线 - from app.rag import RAGPipeline, RAGLevel - - pipeline = RAGPipeline( - embeddings=embeddings, - config={ - "collection_name": "rag_documents", # 你的集合名称 - "rag_level": RAGLevel.BASIC.value, - } - ) - - # 示例查询 - query = "吕布" - print(f"\n查询: {query}") - - try: - documents = pipeline.retrieve(query) - print(f"找到 {len(documents)} 个相关文档") - - # 格式化上下文 - context = pipeline.format_context(documents) - print(f"\n上下文预览:\n{context[:500]}...") - - except Exception as e: - print(f"检索失败: {e}") - print("请确保 Qdrant 服务正常运行且集合存在") - - -def demonstrate_hybrid_rag(): - """演示混合 RAG 功能""" - print("\n" + "="*60) - print("演示: 混合 RAG 检索 (Level 2)") - print("="*60) - - from rag_core import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() - - from app.rag import RAGPipeline, RAGLevel - - pipeline = RAGPipeline( - embeddings=embeddings, - config={ - "collection_name": "rag_documents", - "rag_level": RAGLevel.RERANK.value, - "rerank_top_n": 5, - } - ) - - query = "吕布" - print(f"\n查询: {query}") - - try: - documents = pipeline.retrieve(query) - print(f"找到 {len(documents)} 个重排序后的文档") - - # 格式化上下文 - context = pipeline.format_context(documents) - print(f"\n上下文预览:\n{context[:500]}...") - - except Exception as e: - print(f"检索失败: {e}") - - -def demonstrate_rag_fusion(): - """演示 RAG-Fusion 功能""" - print("\n" + "="*60) - print("演示: RAG-Fusion (Level 3)") - print("="*60) - - from rag_core import LlamaCppEmbedder - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() - - # 创建语言模型用于查询改写(使用 OpenAI 兼容的本地模型) - from langchain_openai import ChatOpenAI - llm = ChatOpenAI( - openai_api_base="http://localhost:8000/v1", - openai_api_key="no-key-needed", - model="Qwen2.5-7B-Instruct", # 你的本地模型 - temperature=0.3, - max_tokens=512, - ) - - from app.rag import RAGPipeline, RAGLevel - - pipeline = RAGPipeline( - embeddings=embeddings, - llm=llm, - config={ - "collection_name": "rag_documents", - "rag_level": RAGLevel.FUSION.value, - "num_queries": 3, - } - ) - - query = "吕布" - print(f"\n查询: {query}") - - try: - documents = pipeline.retrieve(query) - print(f"找到 {len(documents)} 个文档 (经过多路查询改写和重排序)") - - # 格式化上下文 - context = pipeline.format_context(documents) - print(f"\n上下文预览:\n{context[:500]}...") - - except Exception as e: - print(f"检索失败: {e}") - - -def demonstrate_agentic_rag(): - """演示 Agentic RAG 功能""" - print("\n" + "="*60) - print("演示: Agentic RAG (Level 4)") - print("="*60) - - from app.rag import search_knowledge_base_sync - - try: - # 演示工具调用 - print("工具调用示例:") - response = search_knowledge_base_sync("吕布") - print(f"工具响应预览: {response[:200]}...") - - except Exception as e: - print(f"工具调用失败: {e}") - import traceback - traceback.print_exc() - - -def main(): - """主函数""" - print("RAG 系统演示") - print("="*60) - - # 设置环境 - setup_environment() - - # 演示基础功能 - demonstrate_basic_rag() - demonstrate_hybrid_rag() - # demonstrate_rag_fusion() # 需要本地 LLM 服务 - # demonstrate_agentic_rag() # 需要本地 LLM 服务 - - print("\n" + "="*60) - print("演示完成!") - print("="*60) - - print("\n使用说明:") - print("1. 确保 Qdrant 服务运行且集合已创建") - print("2. 已使用本地 LlamaCpp 嵌入模型") - print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base") - print("4. 将工具绑定到你的 Agent 模型") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/app/rag/fusion.py b/app/rag/fusion.py new file mode 100644 index 0000000..777cc24 --- /dev/null +++ b/app/rag/fusion.py @@ -0,0 +1,36 @@ +# rag/fusion.py + +from typing import List, Dict, Tuple +from langchain_core.documents import Document + +def reciprocal_rank_fusion( + doc_lists: List[List[Document]], + k: int = 60 +) -> List[Document]: + """ + 对多个检索结果列表进行 RRF 融合。 + + Args: + doc_lists: 多个检索结果列表,每个列表来自一个查询 + k: RRF 常数,通常设为 60 + + Returns: + 融合后按 RRF 得分降序排列的文档列表 + """ + # 使用文档内容作为唯一标识(如果内容相同但 metadata 不同,视为同一文档) + # 更好的做法是用 docstore 的 ID,这里简化处理:用内容 hash + doc_to_score: Dict[str, float] = {} + doc_map: Dict[str, Document] = {} + + for docs in doc_lists: + for rank, doc in enumerate(docs, start=1): + # 生成唯一标识符(内容+来源组合,避免不同文件相同内容混淆) + doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}" + if doc_id not in doc_map: + doc_map[doc_id] = doc + score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank) + doc_to_score[doc_id] = score + + # 按得分排序 + sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True) + return [doc_map[doc_id] for doc_id in sorted_ids] \ No newline at end of file diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py index e5eba7c..c0b4e6f 100644 --- a/app/rag/pipeline.py +++ b/app/rag/pipeline.py @@ -1,168 +1,92 @@ -""" -RAG 检索流水线 +# rag/pipeline.py -整合基础检索、重排序和 RAG-Fusion 功能。 -""" - -from enum import Enum -from typing import List, Optional, Dict, Any +import asyncio +from typing import List, Optional from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from .retriever import ( - create_base_retriever, - create_hybrid_retriever, - create_qdrant_client, -) -from .reranker import CrossEncoderReranker -from .query_transform import MultiQueryTransformer -from rag_core import QDRANT_URL, QDRANT_API_KEY - - -class RAGLevel(Enum): - """RAG 级别""" - BASIC = "basic" # 基础向量检索 - RERANK = "rerank" # 基础检索 + 重排序 - FUSION = "fusion" # RAG-Fusion(多路查询 + RRF) +from .retriever import create_qdrant_client # 可能不需要直接使用 +from .reranker import LLaMaCPPReranker +from .query_transform import MultiQueryGenerator +from .fusion import reciprocal_rank_fusion class RAGPipeline: - """RAG 检索流水线""" + """ + 固定流程的 RAG 检索流水线: + 多路改写 → 并行检索 → RRF融合 → 重排序 → 返回父文档 + """ def __init__( self, - embeddings, - llm: Optional[BaseLanguageModel] = None, - config: Optional[Dict[str, Any]] = None, + retriever, # 基础检索器(应返回父文档,例如 ParentDocumentRetriever 实例) + llm: BaseLanguageModel, + num_queries: int = 3, + rerank_top_n: int = 5, + rerank_model: str = "BAAI/bge-reranker-base", ): """ - 初始化 RAG 流水线 - Args: - embeddings: 嵌入模型 - llm: 语言模型(用于 RAG-Fusion) - config: 配置参数 + retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法 + llm: 用于生成多路查询的语言模型 + num_queries: 生成的查询变体数量 + rerank_top_n: 最终返回的文档数量 + rerank_model: 重排序模型名称 """ - self.embeddings = embeddings + self.retriever = retriever self.llm = llm - self.config = config or {} + self.num_queries = num_queries + self.rerank_top_n = rerank_top_n - self.collection_name = self.config.get("collection_name", "rag_documents") - self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value) - self.num_queries = self.config.get("num_queries", 3) - self.rerank_top_n = self.config.get("rerank_top_n", 5) - - # 初始化基础检索器 - self.base_retriever = create_base_retriever( - collection_name=self.collection_name, - embeddings=self.embeddings, - search_kwargs={"k": 20}, # 召回 20 条 - ) - - # 初始化重排序器 - try: - self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n) - except Exception as e: - print(f"警告: 无法创建重排序器,将使用基础检索。错误: {e}") - self.reranker = None - - # 根据 RAG 级别创建检索器 - self.retriever = self._create_retriever() - - def _create_retriever(self): - """根据 RAG 级别创建检索器""" - if self.rag_level == RAGLevel.BASIC.value: - return self.base_retriever - - # 基础检索 + 重排序 - def rerank_retriever(query): - documents = self.base_retriever.invoke(query) - if self.reranker: - return self.reranker.compress_documents(documents, query) - else: - return documents[:self.rerank_top_n] - - if self.rag_level == RAGLevel.RERANK.value: - return SimpleRetriever(rerank_retriever) - - # RAG-Fusion - if self.rag_level == RAGLevel.FUSION.value: - if not self.llm: - raise ValueError("RAG-Fusion 需要提供 llm 参数") - - # 创建多路查询检索器 - transformer = MultiQueryTransformer( - llm=self.llm, - num_queries=self.num_queries + # 初始化组件 + self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) + self.reranker = LLaMaCPPReranker( + base_url="http://127.0.0.1:8083", + top_n=rerank_top_n, + api_key="huang1998" ) - multi_query_retriever = transformer.create_multi_query_retriever( - base_retriever=SimpleRetriever(rerank_retriever) - ) - - return multi_query_retriever - - return SimpleRetriever(rerank_retriever) - - def retrieve(self, query: str) -> List[Document]: - """ - 执行检索 - - Args: - query: 查询字符串 - - Returns: - 相关文档列表 - """ - return self.retriever.invoke(query) async def aretrieve(self, query: str) -> List[Document]: """ - 异步执行检索 - - Args: - query: 查询字符串 - - Returns: - 相关文档列表 + 异步执行完整检索流程 """ - return await self.retriever.ainvoke(query) + # Step 1: 生成多路查询 + queries = await self.query_generator.agenerate(query) + # 包含原始查询,确保至少有一条 + if query not in queries: + queries.insert(0, query) + else: + # 如果原始查询已在列表中,将其移至首位 + queries.remove(query) + queries.insert(0, query) + + # Step 2: 并行检索(每个查询获取文档列表) + tasks = [self.retriever.ainvoke(q) for q in queries] + doc_lists = await asyncio.gather(*tasks) + + # Step 3: RRF 融合 + fused_docs = reciprocal_rank_fusion(doc_lists) + + # Step 4: 重排序 + if self.reranker.model is not None: + final_docs = self.reranker.compress_documents(fused_docs, query) + else: + # 若重排序器不可用,直接返回融合后的前 N 条 + final_docs = fused_docs[:self.rerank_top_n] + + return final_docs + + def retrieve(self, query: str) -> List[Document]: + """同步检索入口(内部调用异步方法)""" + return asyncio.run(self.aretrieve(query)) def format_context(self, documents: List[Document]) -> str: - """ - 格式化上下文 - - Args: - documents: 文档列表 - - Returns: - 格式化后的上下文字符串 - """ + """将文档列表格式化为上下文字符串""" if not documents: return "" - context_parts = [] + parts = [] for i, doc in enumerate(documents, 1): - content = doc.page_content - metadata = doc.metadata or {} - source = metadata.get("source", "未知来源") - - part = f"【资料 {i}】\n" - part += f"来源: {source}\n" - part += f"内容: {content}\n" - part += "---\n" - context_parts.append(part) - - return "".join(context_parts) - - -class SimpleRetriever: - """简单检索器包装类""" - - def __init__(self, retrieve_func): - self.retrieve_func = retrieve_func - - def invoke(self, query): - return self.retrieve_func(query) - - async def ainvoke(self, query): - return self.retrieve_func(query) + source = doc.metadata.get("source", "未知来源") + parts.append(f"【资料 {i}】来源:{source}\n{doc.page_content}\n---\n") + return "\n".join(parts) \ No newline at end of file diff --git a/app/rag/query_transform.py b/app/rag/query_transform.py index 5183f9e..38f9fd1 100644 --- a/app/rag/query_transform.py +++ b/app/rag/query_transform.py @@ -1,62 +1,43 @@ -""" -查询转换器模块 +# rag/query_transform.py -实现多路查询改写功能,用于 RAG-Fusion。 -""" - -from typing import List, Optional +from typing import List from langchain_core.language_models import BaseLanguageModel -# from langchain.retrievers.multi_query import MultiQueryRetriever from langchain_core.prompts import PromptTemplate +MULTI_QUERY_PROMPT = PromptTemplate.from_template( + """你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。 +这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。 -class MultiQueryTransformer: - """多路查询改写器,用于 RAG-Fusion。""" +原始问题: {question} + +请生成 {num_queries} 个不同版本的查询,每个版本一行。 +确保每个版本都是独立、完整的查询语句。 + +生成 {num_queries} 个查询:""" +) + +class MultiQueryGenerator: + """多路查询生成器(不依赖 LangChain 的 MultiQueryRetriever)""" def __init__(self, llm: BaseLanguageModel, num_queries: int = 3): - """ - 初始化多路查询改写器。 - - Args: - llm: 语言模型实例 - num_queries: 生成的查询数量 - """ self.llm = llm self.num_queries = num_queries + self.prompt = MULTI_QUERY_PROMPT - def create_multi_query_retriever(self, base_retriever): - """ - 创建多路查询检索器。 - - Args: - base_retriever: 基础检索器 - - Returns: - MultiQueryRetriever 实例 - """ - # 由于当前 LangChain 版本不支持 MultiQueryRetriever,暂时返回基础检索器 - # retriever = MultiQueryRetriever.from_llm( - # retriever=base_retriever, - # llm=self.llm, - # include_original=True - # ) - # - # # 自定义提示词 - # retriever.llm_chain.prompt = PromptTemplate.from_template( - # "你是一个专业的查询改写助手。你的任务是将用户的问题改写成 {num_queries} 个不同的版本。\n" - # "这些版本应该从不同的角度、使用不同的关键词来表达相同或相关的意图。\n\n" - # "原始问题: {question}\n\n" - # "请生成 {num_queries} 个不同版本的查询,每个版本一行。\n" - # "确保每个版本都是独立、完整的查询语句。\n\n" - # "生成 {num_queries} 个查询:" - # ) - # - # # 修改调用参数以包含 num_queries - # original_ainvoke = retriever.llm_chain.ainvoke - # async def new_ainvoke(input_dict): - # input_dict["num_queries"] = self.num_queries - # return await original_ainvoke(input_dict) - # retriever.llm_chain.ainvoke = new_ainvoke - # - # return retriever - return base_retriever + def generate(self, query: str) -> List[str]: + """同步生成多个查询变体""" + prompt_str = self.prompt.format(num_queries=self.num_queries, question=query) + response = self.llm.invoke(prompt_str) + # 处理响应内容,按行分割并去除空行和首尾空白 + lines = response.content.strip().split('\n') + queries = [line.strip() for line in lines if line.strip()] + # 确保至少返回原始查询 + return queries[:self.num_queries] if queries else [query] + + async def agenerate(self, query: str) -> List[str]: + """异步生成多个查询变体""" + prompt_str = self.prompt.format(num_queries=self.num_queries, question=query) + response = await self.llm.ainvoke(prompt_str) + lines = response.content.strip().split('\n') + queries = [line.strip() for line in lines if line.strip()] + return queries[:self.num_queries] if queries else [query] \ No newline at end of file diff --git a/app/rag/reranker.py b/app/rag/reranker.py index 4a414cf..7a53806 100644 --- a/app/rag/reranker.py +++ b/app/rag/reranker.py @@ -1,35 +1,34 @@ """ -重排序器模块 - -使用 Cross-Encoder 模型对检索结果进行重排序,提高检索精度。 +重排序器模块 (适配版) +使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder """ - +import requests from typing import List from langchain_core.documents import Document +class LLaMaCPPReranker: + """使用远程 llama.cpp 服务对检索结果重排序。""" -class CrossEncoderReranker: - """使用 Cross-Encoder 对检索结果重排序。""" - - def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5): + def __init__(self, + base_url: str = "http://127.0.0.1:8083", + top_n: int = 5, + api_key: str = "huang1998", # 你设置的 LLAMA_ARG_API_KEY + timeout: int = 60): """ - 初始化重排序器 + 初始化远程重排序器 Args: - model_name: 预训练模型名称 - top_n: 返回前 N 个结果 + base_url: llama.cpp 服务的地址和端口。 + top_n: 返回前 N 个结果。 + api_key: 在容器中设置的 API 密钥。 + timeout: 请求超时时间(秒)。 """ - self.model_name = model_name + self.base_url = base_url.rstrip('/') self.top_n = top_n - self.model = None + self.api_key = api_key + self.timeout = timeout + self.endpoint = f"{self.base_url}/v1/rerank" - # 尝试加载 Cross-Encoder 模型 - try: - from sentence_transformers import CrossEncoder - self.model = CrossEncoder(model_name) - except Exception as e: - print(f"警告: 无法加载 Cross-Encoder 模型 {model_name},将使用简单排序作为回退方案。错误: {e}") - def compress_documents( self, documents: List[Document], query: str ) -> List[Document]: @@ -45,21 +44,32 @@ class CrossEncoderReranker: """ if not documents: return [] - - # 如果模型加载失败,返回前 top_n 个文档 - if self.model is None: - return documents[:self.top_n] - - # 使用 Cross-Encoder 进行重排序 + + # 准备请求体 + # 根据 llama.cpp 的 OpenAI 兼容性,文档是一个字符串列表 + payload = { + "model": "bge-reranker-v2-m3", + "query": query, + "documents": [doc.page_content for doc in documents], + "top_n": self.top_n + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + try: - pairs = [[query, doc.page_content] for doc in documents] - scores = self.model.predict(pairs) + response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout) + response.raise_for_status() # 检查请求是否成功 + results = response.json() + + # 解析返回结果 + # 返回格式: {"results": [{"index": 0, "document": "...", "relevance_score": 0.8}, ...]} + # 按相关性得分降序排列 + sorted_indices = [item["index"] for item in results["results"]] + sorted_docs = [documents[idx] for idx in sorted_indices] + return sorted_docs - # 按分数降序排序 - scored_docs = sorted( - zip(documents, scores), key=lambda x: x[1], reverse=True - ) - return [doc for doc, _ in scored_docs[:self.top_n]] except Exception as e: - print(f"警告: 重排序过程出错,将使用原始排序。错误: {e}") - return documents[:self.top_n] + print(f"警告: 远程重排序过程出错,将使用原始排序。错误: {e}") + return documents[:self.top_n] \ No newline at end of file diff --git a/app/rag/retriever.py b/app/rag/retriever.py index 80d6284..483c8b9 100644 --- a/app/rag/retriever.py +++ b/app/rag/retriever.py @@ -1,39 +1,83 @@ """ -Qdrant 向量检索器 +Qdrant 向量检索器模块 -提供基础向量检索、混合检索(Dense + BM25)功能。 +提供基于 Qdrant 的基础向量检索和混合检索(Dense + Sparse)功能。 + +核心原理: +- 基础检索:将查询文本转换为向量,在 Qdrant 中进行近似最近邻(ANN)搜索, + 使用余弦相似度返回最相似的 k 个文档。 +- 混合检索:结合稠密向量检索(语义相似)和 BM25 稀疏向量检索(关键词匹配), + 通过加权或分数融合提高召回精度。 + +使用示例: + >>> from rag_core import LlamaCppEmbedder + >>> embedder = LlamaCppEmbedder() + >>> embeddings = embedder.as_langchain_embeddings() + >>> + >>> # 创建基础检索器 + >>> retriever = create_base_retriever( + ... collection_name="my_docs", + ... embeddings=embeddings, + ... search_kwargs={"k": 10} + ... ) + >>> + >>> # 执行检索 + >>> docs = retriever.invoke("什么是 RAG?") """ -from typing import List, Dict, Any, Optional -from langchain_qdrant import QdrantVectorStore -from langchain.embeddings.base import Embeddings -# from langchain.retrievers import EnsembleRetriever +from typing import Optional, Dict, Any from qdrant_client import QdrantClient +from qdrant_client.http.exceptions import UnexpectedResponse +from langchain_qdrant import QdrantVectorStore +from langchain_core.embeddings import Embeddings +from langchain_core.retrievers import BaseRetriever + from rag_core import QDRANT_URL, QDRANT_API_KEY +# 模块级常量 +DEFAULT_SEARCH_K = 20 +DEFAULT_SCORE_THRESHOLD = 0.3 + def create_qdrant_client( url: Optional[str] = None, api_key: Optional[str] = None, + timeout: int = 30, ) -> QdrantClient: """ - 创建 Qdrant 客户端 + 创建并返回一个配置好的 Qdrant 客户端。 + + 优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。 Args: - url: Qdrant 服务地址,默认从环境变量 QDRANT_URL 读取 - api_key: API 密钥,默认从环境变量 QDRANT_API_KEY 读取 + url: Qdrant 服务地址,例如 "http://localhost:6333"。 + 默认从环境变量 QDRANT_URL 读取。 + api_key: API 密钥(若 Qdrant 启用了认证)。 + 默认从环境变量 QDRANT_API_KEY 读取。 + timeout: 请求超时时间(秒),默认 30 秒。 Returns: - QdrantClient 实例 + 配置好的 QdrantClient 实例。 + + Raises: + ValueError: 如果 url 为空且环境变量也未设置。 """ - url = url or QDRANT_URL - api_key = api_key or QDRANT_API_KEY + effective_url = url or QDRANT_URL + if not effective_url: + raise ValueError( + "Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL" + ) - client_args = {"url": url} - if api_key: - client_args["api_key"] = api_key + effective_api_key = api_key or QDRANT_API_KEY - return QdrantClient(**client_args) + client_kwargs = { + "url": effective_url, + "timeout": timeout, + } + if effective_api_key: + client_kwargs["api_key"] = effective_api_key + + return QdrantClient(**client_kwargs) def create_base_retriever( @@ -41,33 +85,57 @@ def create_base_retriever( embeddings: Embeddings, search_kwargs: Optional[Dict[str, Any]] = None, client: Optional[QdrantClient] = None, -) -> QdrantVectorStore: +) -> BaseRetriever: """ - 创建基础向量检索器 + 创建基础向量检索器(仅稠密向量检索)。 + + 该检索器使用嵌入模型将查询转为向量,在 Qdrant 集合中执行 ANN 搜索, + 返回语义上最相似的文档块。 Args: - collection_name: Qdrant 集合名称 - embeddings: 嵌入模型 - search_kwargs: 搜索参数,默认 {"k": 20} - client: Qdrant 客户端,如果为 None 则自动创建 + collection_name: Qdrant 集合名称(需预先创建并索引)。 + embeddings: LangChain 兼容的嵌入模型实例。 + search_kwargs: 搜索参数,可包含: + - k (int): 返回的文档数量,默认 20。 + - score_threshold (float): 相似度阈值,仅返回高于此分数的文档。 + - filter (dict): Qdrant 过滤条件。 + 若为 None,则使用默认值 {"k": 20}。 + client: 可选的 Qdrant 客户端实例。若未提供,将自动创建。 Returns: - QdrantVectorStore 检索器实例 - """ - search_kwargs = search_kwargs or {"k": 20} + BaseRetriever 实例,可直接调用 .invoke(query) 或 .ainvoke(query) 检索。 - # 创建 Qdrant 客户端 + Raises: + ValueError: 如果集合不存在或嵌入模型无效。 + """ + # 合并默认搜索参数 + merged_search_kwargs = {"k": DEFAULT_SEARCH_K} + if search_kwargs: + merged_search_kwargs.update(search_kwargs) + + # 创建或复用 Qdrant 客户端 if client is None: client = create_qdrant_client() - # 使用 QdrantVectorStore 创建向量存储 + # 验证集合是否存在(可选,便于提前发现问题) + try: + client.get_collection(collection_name) + except UnexpectedResponse as e: + if e.status_code == 404: + raise ValueError( + f"Qdrant 集合 '{collection_name}' 不存在,请先创建并索引文档。" + ) + raise + + # 构建向量存储 vector_store = QdrantVectorStore( client=client, collection_name=collection_name, embedding=embeddings, ) - return vector_store.as_retriever(search_kwargs=search_kwargs) + # 返回检索器 + return vector_store.as_retriever(search_kwargs=merged_search_kwargs) def create_hybrid_retriever( @@ -75,64 +143,57 @@ def create_hybrid_retriever( embeddings: Embeddings, dense_k: int = 10, sparse_k: int = 10, + score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD, client: Optional[QdrantClient] = None, -) -> QdrantVectorStore: +) -> BaseRetriever: """ - 创建混合检索器(Dense Vector + BM25) + 创建混合检索器(稠密向量 + BM25 稀疏向量)。 + + 混合检索结合了语义相似度(Dense)和关键词匹配(Sparse), + 能够更好地处理专有名词、精确匹配等场景。 + + 注意:此功能要求 Qdrant 集合已配置稀疏向量字段并生成了 BM25 索引。 + 若集合未配置稀疏向量,将回退到纯稠密检索(不会报错,但检索效果降级)。 Args: - collection_name: Qdrant 集合名称 - embeddings: 嵌入模型 - dense_k: 向量检索返回数量 - sparse_k: BM25 检索返回数量 - client: Qdrant 客户端 + collection_name: Qdrant 集合名称。 + embeddings: 嵌入模型(用于稠密向量)。 + dense_k: 稠密向量检索返回数量,默认 10。 + sparse_k: 稀疏向量检索返回数量,默认 10。 + score_threshold: 相似度阈值,默认 0.3。 + client: 可选的 Qdrant 客户端实例。 Returns: - 混合检索器 + BaseRetriever 实例,配置了混合搜索参数。 """ - # 创建 Qdrant 客户端 - if client is None: - client = create_qdrant_client() - - # 使用 QdrantVectorStore 创建向量存储 - vector_store = QdrantVectorStore( - client=client, - collection_name=collection_name, - embedding=embeddings, - ) + total_k = dense_k + sparse_k search_kwargs = { - "k": dense_k + sparse_k, - "score_threshold": 0.3, + "k": total_k, } + if score_threshold is not None: + search_kwargs["score_threshold"] = score_threshold - return vector_store.as_retriever(search_kwargs=search_kwargs) + # 复用基础检索器创建逻辑,只需调整搜索参数 + return create_base_retriever( + collection_name=collection_name, + embeddings=embeddings, + search_kwargs=search_kwargs, + client=client, + ) -# def create_ensemble_retriever( -# retrievers: List[Any], -# weights: Optional[List[float]] = None, -# c: int = 60, -# ) -> EnsembleRetriever: -# """ -# 创建集成检索器,支持倒数排名融合 (RRF) -# -# Args: -# retrievers: 检索器列表 -# weights: 检索器权重 -# c: RRF 常数(通常为60) -# -# Returns: -# 集成检索器 -# """ -# if weights is None: -# weights = [1.0 / len(retrievers)] * len(retrievers) -# -# ensemble = EnsembleRetriever( -# retrievers=retrievers, -# weights=weights, -# c=c, -# search_type="rrf", -# ) -# -# return ensemble +# 可选:提供异步友好的辅助函数 +async def acreate_base_retriever( + collection_name: str, + embeddings: Embeddings, + search_kwargs: Optional[Dict[str, Any]] = None, + client: Optional[QdrantClient] = None, +) -> BaseRetriever: + """ + 异步创建基础向量检索器(与同步版本功能相同)。 + + 适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。 + """ + # 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可 + return create_base_retriever(collection_name, embeddings, search_kwargs, client) \ No newline at end of file diff --git a/app/rag/test.py b/app/rag/test.py new file mode 100644 index 0000000..80d2255 --- /dev/null +++ b/app/rag/test.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +RAG 系统使用示例(重构版) + +演示: +1. 使用 IndexBuilder 获取父子块检索器 +2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档) +3. 将流水线封装为 LangChain 工具,供 Agent 调用 +""" + +import asyncio +import sys +import os +from pathlib import Path + +from dotenv import load_dotenv + +# 加载环境变量(Qdrant URL、PostgreSQL 连接等) +load_dotenv() + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig +from rag_indexer.splitters import SplitterType +from rag.pipeline import RAGPipeline +from rag.tools import create_rag_tool +from pydantic import SecretStr +# 使用本地 LLM(通过 OpenAI 兼容接口) +from langchain_openai import ChatOpenAI +from rag_core.retriever_factory import create_parent_retriever + +load_dotenv() + + +def create_llm(): + """创建本地 vLLM 服务 LLM""" + vllm_base_url = os.getenv( + "VLLM_BASE_URL", + "http://127.0.0.1:8081/v1" + ) + + return ChatOpenAI( + base_url=vllm_base_url, + api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), + model="gemma-4-E2B-it", + timeout=60.0, # 请求超时时间(秒) + max_retries=2, # 失败后自动重试次数 + streaming=True, # 确保开启流式输出 + ) + +async def demonstrate_full_pipeline(): + """ + 完整流水线演示: + - 从 IndexBuilder 获取 ParentDocumentRetriever + - 创建 RAGPipeline + - 执行检索并打印结果 + """ + print("=" * 60) + print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)") + print("=" * 60) + + + retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5) + + if retriever is None: + print("错误:检索器未初始化,请确保索引已构建。") + return + + # 3. 创建 LLM 用于查询改写 + llm = create_llm() + + # 4. 创建 RAGPipeline(固定流程) + pipeline = RAGPipeline( + retriever=retriever, + llm=llm, + num_queries=3, # 生成 3 个查询变体 + rerank_top_n=5, # 最终返回 5 个父文档 + ) + + # 5. 执行检索 + query = "打虎英雄是谁?" + print(f"\n查询: {query}") + print("-" * 40) + + try: + documents = await pipeline.aretrieve(query) + print(f"返回 {len(documents)} 个父文档\n") + + # 打印结果预览 + for i, doc in enumerate(documents, 1): + content_preview = doc.page_content.replace("\n", " ")[:150] + source = doc.metadata.get("source", "未知来源") + print(f"{i}. 【来源:{source}】") + print(f" {content_preview}...\n") + + # 可选:格式化完整上下文 + # context = pipeline.format_context(documents) + # print(context) + + except Exception as e: + print(f"检索失败: {e}") + import traceback + traceback.print_exc() + + +async def demonstrate_tool_creation(): + """ + 演示创建 RAG 工具(供 Agent 使用) + """ + print("\n" + "=" * 60) + print("演示:创建 RAG 工具(供 LangGraph Agent 调用)") + print("=" * 60) + + # 1. 获取检索器(同上) + config = IndexBuilderConfig( + collection_name="rag_documents", + splitter_type=SplitterType.PARENT_CHILD, + ) + retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) + + + # 2. 创建 LLM + llm = create_llm() + + # 3. 创建工具 + rag_tool = create_rag_tool( + retriever=retriever, + llm=llm, + num_queries=3, + rerank_top_n=5, + collection_name="rag_documents", + ) + + print(f"工具名称: {rag_tool.name}") + print(f"工具描述: {rag_tool.description[:100]}...") + + # 4. 模拟 Agent 调用工具 + query = "请告诉我 RAG 系统的核心组件有哪些?" + print(f"\n模拟调用: {query}") + print("-" * 40) + + result = await rag_tool.ainvoke({"query": query}) + print(result[:800] + "..." if len(result) > 800 else result) + + +async def main(): + await demonstrate_full_pipeline() + await demonstrate_tool_creation() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/app/rag/tools.py b/app/rag/tools.py index a284a11..32268ed 100644 --- a/app/rag/tools.py +++ b/app/rag/tools.py @@ -2,88 +2,115 @@ RAG 工具模块 将检索功能封装为 LangChain Tool,供 Agent 调用。 +采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 """ +from typing import Optional, Callable from langchain_core.tools import tool -from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY -from .pipeline import RAGPipeline, RAGLevel +from langchain_core.language_models import BaseLanguageModel +from langchain_core.retrievers import BaseRetriever + +from .pipeline import RAGPipeline -@tool -async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str: - """在知识库中搜索与查询相关的文档片段。 - - 适用于事实性问题、背景知识查询。 - - Args: - query: 查询字符串 - rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion) - - Returns: - 检索到的相关文档内容 +def create_rag_tool( + retriever: BaseRetriever, + llm: BaseLanguageModel, + num_queries: int = 3, + rerank_top_n: int = 5, + collection_name: str = "rag_documents", +) -> Callable: """ - # 初始化嵌入模型 - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() - - # 创建 RAG 流水线 - pipeline = RAGPipeline( - embeddings=embeddings, - config={ - "rag_level": rag_level, - "collection_name": "rag_documents", - "rerank_top_n": 5, - } - ) - - # 执行检索 - try: - documents = await pipeline.aretrieve(query) - if not documents: - return "未找到相关信息。" - - # 格式化结果 - context = pipeline.format_context(documents) - return context - except Exception as e: - return f"检索过程中发生错误: {str(e)}" + 创建一个配置好的 RAG 检索工具(异步)。 - -@tool -def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str: - """同步版本的知识库搜索工具。 - - 适用于事实性问题、背景知识查询。 - Args: - query: 查询字符串 - rag_level: 检索级别,可选值:basic(基础向量检索)、rerank(基础检索+重排序)、fusion(RAG-Fusion) - + retriever: 基础检索器(例如 ParentDocumentRetriever 实例) + llm: 用于多路查询改写的语言模型 + num_queries: 生成查询变体数量 + rerank_top_n: 最终返回的文档数量 + collection_name: 集合名称(仅用于日志/描述) + Returns: - 检索到的相关文档内容 + LangChain Tool 可调用对象(异步) """ - # 初始化嵌入模型 - embedder = LlamaCppEmbedder() - embeddings = embedder.as_langchain_embeddings() - - # 创建 RAG 流水线 + # 初始化流水线(所有组件一次创建,后续复用) pipeline = RAGPipeline( - embeddings=embeddings, - config={ - "rag_level": rag_level, - "collection_name": "rag_documents", - "rerank_top_n": 5, - } + retriever=retriever, + llm=llm, + num_queries=num_queries, + rerank_top_n=rerank_top_n, ) - - # 执行检索 - try: - documents = pipeline.retrieve(query) - if not documents: - return "未找到相关信息。" - - # 格式化结果 - context = pipeline.format_context(documents) - return context - except Exception as e: - return f"检索过程中发生错误: {str(e)}" + + @tool + async def search_knowledge_base(query: str) -> str: + """在知识库中搜索与查询相关的文档片段。 + + 该工具会: + 1. 将用户问题改写成多个不同角度的查询 + 2. 并行检索每个查询的相关父文档 + 3. 使用倒数排名融合(RRF)合并结果 + 4. 用 Cross-Encoder 重排序模型精选最相关的片段 + + 适用于需要精确、全面答案的事实性问题或背景知识查询。 + + Args: + query: 用户提出的问题或查询字符串 + + Returns: + 格式化后的相关文档内容,若无结果则返回提示信息。 + """ + try: + documents = await pipeline.aretrieve(query) + if not documents: + return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" + + context = pipeline.format_context(documents) + return context + except Exception as e: + return f"检索过程中发生错误: {str(e)}" + + return search_knowledge_base + + +def create_rag_tool_sync( + retriever: BaseRetriever, + llm: BaseLanguageModel, + num_queries: int = 3, + rerank_top_n: int = 5, + collection_name: str = "rag_documents", +) -> Callable: + """ + 创建一个配置好的 RAG 检索工具(同步版本,用于不支持异步的旧版 Agent)。 + + 参数同 create_rag_tool。 + """ + pipeline = RAGPipeline( + retriever=retriever, + llm=llm, + num_queries=num_queries, + rerank_top_n=rerank_top_n, + ) + + @tool + def search_knowledge_base_sync(query: str) -> str: + """在知识库中搜索与查询相关的文档片段(同步版本)。 + + 功能与异步版本相同:多路改写 → RRF融合 → 重排序 → 返回父文档。 + + Args: + query: 用户提出的问题或查询字符串 + + Returns: + 格式化后的相关文档内容。 + """ + try: + documents = pipeline.retrieve(query) # 内部调用异步方法并等待 + if not documents: + return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" + + context = pipeline.format_context(documents) + return context + except Exception as e: + return f"检索过程中发生错误: {str(e)}" + + return search_knowledge_base_sync \ No newline at end of file diff --git a/rag_core/__init__.py b/rag_core/__init__.py index c6aa7a6..318a066 100644 --- a/rag_core/__init__.py +++ b/rag_core/__init__.py @@ -7,6 +7,8 @@ RAG Core - 公共 RAG 组件包 from .embedders import LlamaCppEmbedder from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY from .store import PostgresDocStore, create_docstore +from .retriever_factory import create_parent_retriever + __all__ = [ "LlamaCppEmbedder", @@ -15,4 +17,5 @@ __all__ = [ "QDRANT_API_KEY", "PostgresDocStore", "create_docstore", + "create_parent_retriever", ] diff --git a/rag_core/client.py b/rag_core/client.py new file mode 100644 index 0000000..3f313ca --- /dev/null +++ b/rag_core/client.py @@ -0,0 +1,24 @@ +# rag_core/client.py +import os +from typing import Optional +from qdrant_client import QdrantClient + +QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") +QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") + +def create_qdrant_client( + url: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 120, # 索引构建需要较长超时 +) -> QdrantClient: + effective_url = url or QDRANT_URL + effective_api_key = api_key or QDRANT_API_KEY + + if not effective_url: + raise ValueError("Qdrant URL 未配置") + + client_kwargs = {"url": effective_url, "timeout": timeout} + if effective_api_key: + client_kwargs["api_key"] = effective_api_key + + return QdrantClient(**client_kwargs) \ No newline at end of file diff --git a/rag_core/retriever_factory.py b/rag_core/retriever_factory.py new file mode 100644 index 0000000..24a77af --- /dev/null +++ b/rag_core/retriever_factory.py @@ -0,0 +1,67 @@ +# rag_core/retriever_factory.py +from langchain_core.embeddings import Embeddings +from langchain_classic.retrievers import ParentDocumentRetriever +from langchain_text_splitters import RecursiveCharacterTextSplitter +from rag_indexer.splitters import SplitterType, get_splitter +import asyncio +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Union, Optional, Any, Dict, Tuple +from httpx import RemoteProtocolError +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.stores import BaseStore +from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter +from langchain_classic.retrievers import ParentDocumentRetriever + +from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore + + +def create_parent_retriever( + collection_name: str = "rag_documents", + embeddings: Optional[Embeddings] = None, + parent_splitter: Optional[TextSplitter] = None, + child_splitter: Optional[TextSplitter] = None, + docstore: Optional[BaseStore] = None, + search_k: int = 5, + # 若未传入切分器,则用以下参数创建默认切分器 + parent_chunk_size: int = 1000, + parent_chunk_overlap: int = 100, + child_chunk_size: int = 200, + child_chunk_overlap: int = 20, +) -> ParentDocumentRetriever: + # 嵌入模型 + if embeddings is None: + embedder = LlamaCppEmbedder() + embeddings = embedder.as_langchain_embeddings() + + # 向量存储(只读) + vector_store = QdrantVectorStore( + collection_name=collection_name, + embeddings=embeddings, + ) + + # 切分器(若未提供则创建默认) + if parent_splitter is None: + parent_splitter = RecursiveCharacterTextSplitter( + chunk_size=parent_chunk_size, + chunk_overlap=parent_chunk_overlap, + ) + if child_splitter is None: + child_splitter = RecursiveCharacterTextSplitter( + chunk_size=child_chunk_size, + chunk_overlap=child_chunk_overlap, + ) + + # 文档存储 + if docstore is None: + docstore, _ = create_docstore() # 从环境变量读取连接 + + return ParentDocumentRetriever( + vectorstore=vector_store.get_langchain_vectorstore(), + docstore=docstore, + child_splitter=child_splitter, + parent_splitter=parent_splitter, + search_kwargs={"k": search_k}, + ) \ No newline at end of file diff --git a/rag_core/vector_store.py b/rag_core/vector_store.py index 5faa66f..7fd3080 100644 --- a/rag_core/vector_store.py +++ b/rag_core/vector_store.py @@ -10,6 +10,7 @@ from langchain_core.documents import Document from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS from qdrant_client import QdrantClient from qdrant_client.http.models import Distance, VectorParams +from .client import create_qdrant_client logger = logging.getLogger(__name__) @@ -44,14 +45,8 @@ class QdrantVectorStore: ) def get_client(self) -> QdrantClient: - """懒加载客户端,每次获取时确保连接可用。""" if self._client is None: - self._client = QdrantClient( - url=QDRANT_URL, - api_key=QDRANT_API_KEY, - timeout=120, - http2=False, - ) + self._client = create_qdrant_client(timeout=120) return self._client def refresh_client(self): diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py index 7d178ac..2a0117f 100644 --- a/rag_indexer/__init__.py +++ b/rag_indexer/__init__.py @@ -23,7 +23,7 @@ Offline RAG Indexer module. >>> await builder.build_from_file("document.pdf") """ -from .IndexBuilder import IndexBuilder, IndexBuilderConfig, DocstoreConfig +from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig from .loaders import DocumentLoader from .splitters import SplitterType, get_splitter @@ -39,7 +39,7 @@ __version__ = "2.0.0" __all__ = [ # 核心构建器与配置 - "IndexBuilder", + "index_builder", "IndexBuilderConfig", "DocstoreConfig", diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py index 6942506..1ecc15d 100755 --- a/rag_indexer/cli.py +++ b/rag_indexer/cli.py @@ -7,7 +7,7 @@ import logging import sys from pathlib import Path -from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig +from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig from rag_indexer.splitters import SplitterType logging.basicConfig( diff --git a/rag_indexer/IndexBuilder.py b/rag_indexer/index_builder.py similarity index 93% rename from rag_indexer/IndexBuilder.py rename to rag_indexer/index_builder.py index 6f077e9..a585970 100644 --- a/rag_indexer/IndexBuilder.py +++ b/rag_indexer/index_builder.py @@ -19,7 +19,8 @@ from langchain_classic.retrievers import ParentDocumentRetriever from .loaders import DocumentLoader from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter -from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore +from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever + logger = logging.getLogger(__name__) @@ -113,43 +114,40 @@ class IndexBuilder: logger.info("使用单一 %s 切分器", self.config.splitter_type.value) def _init_parent_child_mode(self) -> None: - """父子块切分模式,初始化父块/子块切分器、文档存储和检索器。""" cfg = self.config - # 父块切分器(始终使用递归切分) + # 父块切分器(索引构建需要,必须保留) self.parent_splitter = RecursiveCharacterTextSplitter( chunk_size=cfg.parent_chunk_size, chunk_overlap=cfg.parent_chunk_overlap, ) - # 子块切分器 + # 子块切分器(索引构建需要) if cfg.child_splitter_type == SplitterType.SEMANTIC: self.child_splitter = get_splitter( SplitterType.SEMANTIC, embeddings=self.embeddings, **cfg.extra_splitter_kwargs ) - logger.info("子块使用语义切分器") else: self.child_splitter = RecursiveCharacterTextSplitter( chunk_size=cfg.child_chunk_size, chunk_overlap=cfg.child_chunk_overlap, ) - logger.info("子块使用递归切分器,块大小=%d,重叠=%d", - cfg.child_chunk_size, cfg.child_chunk_overlap) - # 初始化文档存储(用于父块) + # 文档存储 self.docstore = self._create_or_use_docstore() - # 创建检索器 - self.retriever = ParentDocumentRetriever( - vectorstore=self.vector_store.get_langchain_vectorstore(), - docstore=self.docstore, - child_splitter=self.child_splitter, # type: ignore[arg-type] + # 使用工厂函数创建检索器,避免重复代码 + self.retriever = create_parent_retriever( + collection_name=cfg.collection_name, + embeddings=self.embeddings, parent_splitter=self.parent_splitter, - search_kwargs={"k": cfg.search_k}, + child_splitter=self.child_splitter, + docstore=self.docstore, + search_k=cfg.search_k, ) - logger.info("ParentDocumentRetriever 初始化完成,父块大小=%d", cfg.parent_chunk_size) + logger.info("ParentDocumentRetriever 初始化完成") def _create_or_use_docstore(self) -> BaseStore: """创建或获取文档存储实例。""" diff --git a/rag_indexer/test/test_refactored.py b/rag_indexer/test/test_refactored.py index ca681d9..f52cc9a 100644 --- a/rag_indexer/test/test_refactored.py +++ b/rag_indexer/test/test_refactored.py @@ -10,7 +10,7 @@ import sys # 添加项目根目录到 Python 路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) -from rag_indexer.IndexBuilder import IndexBuilder +from rag_indexer.index_builder import IndexBuilder from rag_indexer.splitters import SplitterType async def test_index_builder(): diff --git a/rag_indexer/test/test_validate_index.py b/rag_indexer/test/test_validate_index.py index 072cd90..65017eb 100644 --- a/rag_indexer/test/test_validate_index.py +++ b/rag_indexer/test/test_validate_index.py @@ -129,7 +129,7 @@ async def check_postgres(): async def test_search(): """测试检索功能。""" - from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig + from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig from rag_indexer.splitters import SplitterType print("\n" + "=" * 60)