@@ -3,6 +3,6 @@ AI Agent application module
 """
 
 from .agent import AIAgentService
-from .tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
+from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
 
 __all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
@@ -31,7 +31,7 @@ except ImportError:
 
 # Local modules
 from app.graph_builder import GraphBuilder, GraphContext
-from app.tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
+from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
 try:
     from app.rag import RAGPipeline
     from app.rag.tools import RAGTool
@@ -2,52 +2,69 @@
 RAG retrieval and generation module
 
 Provides online retrieval and generation, including:
-- Basic vector retrieval
-- Reranking
-- RAG-Fusion
-- Agentic RAG
+- Basic vector retrieval (dense / hybrid)
+- Reranking (Cross-Encoder)
+- Multi-query rewriting (Multi-Query)
+- RRF fusion (Reciprocal Rank Fusion)
+- The complete RAG pipeline
+- Agent tool wrappers
+
+Fixed pipeline:
+    user query → multi-query rewriting → parallel retrieval → RRF fusion → reranking → return parent documents
 
 Example usage:
-    >>> from app.rag import RAGPipeline, search_knowledge_base
-    >>> from rag_core import LlamaCppEmbedder
-    >>>
-    >>> embeddings = LlamaCppEmbedder()
-    >>> pipeline = RAGPipeline(embeddings=embeddings)
-    >>>
-    >>> documents = pipeline.retrieve("戏耍貂蝉美女")
-    >>> context = pipeline.format_context(documents)
+    >>> from app.rag import RAGPipeline, create_rag_tool
+    >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig
+    >>> from langchain_openai import ChatOpenAI
+    >>>
+    >>> # Get a base retriever (e.g. a parent-child chunk retriever)
+    >>> config = IndexBuilderConfig(collection_name="my_docs")
+    >>> builder = IndexBuilder(config)
+    >>> retriever = builder.retriever
+    >>>
+    >>> # Create the LLM and the pipeline
+    >>> llm = ChatOpenAI(model="gpt-3.5-turbo")
+    >>> pipeline = RAGPipeline(retriever=retriever, llm=llm)
+    >>>
+    >>> # Retrieve
+    >>> docs = await pipeline.aretrieve("What is RAG?")
+    >>> context = pipeline.format_context(docs)
+    >>>
+    >>> # Create the Agent tool
+    >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm)
 """
 
 from .retriever import (
     create_base_retriever,
     create_hybrid_retriever,
     # create_ensemble_retriever,
     create_qdrant_client,
 )
 from .reranker import CrossEncoderReranker
-from .query_transform import MultiQueryTransformer
-from .pipeline import RAGPipeline, RAGLevel
-from .tools import search_knowledge_base, search_knowledge_base_sync
+from .query_transform import MultiQueryGenerator
+from .fusion import reciprocal_rank_fusion
+from .pipeline import RAGPipeline
+from .tools import create_rag_tool, create_rag_tool_sync
 
 
 __all__ = [
-    # Retrievers
+    # Retriever factory functions
     "create_base_retriever",
     "create_hybrid_retriever",
     # "create_ensemble_retriever",
     "create_qdrant_client",
 
     # Reranker
     "CrossEncoderReranker",
 
-    # Query transformer
-    "MultiQueryTransformer",
+    # Multi-query generator
+    "MultiQueryGenerator",
 
-    # Pipeline
+    # Fusion algorithm
+    "reciprocal_rank_fusion",
+
+    # Main pipeline
     "RAGPipeline",
-    "RAGLevel",
 
-    # Tools
-    "search_knowledge_base",
-    "search_knowledge_base_sync",
-]
+    # Tool factories (for Agent use)
+    "create_rag_tool",
+    "create_rag_tool_sync",
+]
@@ -1,203 +0,0 @@
-#!/usr/bin/env python3
-"""
-RAG system usage example
-
-Demonstrates how to use the app/rag module for knowledge retrieval.
-"""
-
-import sys
-import os
-from dotenv import load_dotenv
-
-# Load environment variables
-load_dotenv()
-
-# Add the project root to the path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
-
-from langchain_openai import OpenAIEmbeddings
-from langchain_community.llms import VLLMOpenAI
-
-
-def setup_environment():
-    """Set up environment variables."""
-    # Qdrant connection info (adjust to your deployment)
-    os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
-    # Qdrant API key (set QDRANT_API_KEY if your server requires one)
-    os.environ.setdefault("QDRANT_API_KEY", "your-api-key-here")
-
-    print("Environment variables set")
-    print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
-    print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else 'not set'}")
-
-
-def demonstrate_basic_rag():
-    """Demonstrate basic RAG retrieval."""
-    print("\n" + "=" * 60)
-    print("Demo: basic RAG retrieval (Level 1)")
-    print("=" * 60)
-
-    # Create the embedding model (local LlamaCpp model)
-    from rag_core import LlamaCppEmbedder
-    embedder = LlamaCppEmbedder()
-    embeddings = embedder.as_langchain_embeddings()
-
-    # Create the RAG pipeline
-    from app.rag import RAGPipeline, RAGLevel
-
-    pipeline = RAGPipeline(
-        embeddings=embeddings,
-        config={
-            "collection_name": "rag_documents",  # your collection name
-            "rag_level": RAGLevel.BASIC.value,
-        }
-    )
-
-    # Example query
-    query = "吕布"
-    print(f"\nQuery: {query}")
-
-    try:
-        documents = pipeline.retrieve(query)
-        print(f"Found {len(documents)} relevant documents")
-
-        # Format the context
-        context = pipeline.format_context(documents)
-        print(f"\nContext preview:\n{context[:500]}...")
-
-    except Exception as e:
-        print(f"Retrieval failed: {e}")
-        print("Make sure the Qdrant service is running and the collection exists")
-
-
-def demonstrate_hybrid_rag():
-    """Demonstrate hybrid RAG retrieval."""
-    print("\n" + "=" * 60)
-    print("Demo: hybrid RAG retrieval (Level 2)")
-    print("=" * 60)
-
-    from rag_core import LlamaCppEmbedder
-    embedder = LlamaCppEmbedder()
-    embeddings = embedder.as_langchain_embeddings()
-
-    from app.rag import RAGPipeline, RAGLevel
-
-    pipeline = RAGPipeline(
-        embeddings=embeddings,
-        config={
-            "collection_name": "rag_documents",
-            "rag_level": RAGLevel.RERANK.value,
-            "rerank_top_n": 5,
-        }
-    )
-
-    query = "吕布"
-    print(f"\nQuery: {query}")
-
-    try:
-        documents = pipeline.retrieve(query)
-        print(f"Found {len(documents)} reranked documents")
-
-        # Format the context
-        context = pipeline.format_context(documents)
-        print(f"\nContext preview:\n{context[:500]}...")
-
-    except Exception as e:
-        print(f"Retrieval failed: {e}")
-
-
-def demonstrate_rag_fusion():
-    """Demonstrate RAG-Fusion."""
-    print("\n" + "=" * 60)
-    print("Demo: RAG-Fusion (Level 3)")
-    print("=" * 60)
-
-    from rag_core import LlamaCppEmbedder
-    embedder = LlamaCppEmbedder()
-    embeddings = embedder.as_langchain_embeddings()
-
-    # Create the language model for query rewriting (an OpenAI-compatible local model)
-    from langchain_openai import ChatOpenAI
-    llm = ChatOpenAI(
-        openai_api_base="http://localhost:8000/v1",
-        openai_api_key="no-key-needed",
-        model="Qwen2.5-7B-Instruct",  # your local model
-        temperature=0.3,
-        max_tokens=512,
-    )
-
-    from app.rag import RAGPipeline, RAGLevel
-
-    pipeline = RAGPipeline(
-        embeddings=embeddings,
-        llm=llm,
-        config={
-            "collection_name": "rag_documents",
-            "rag_level": RAGLevel.FUSION.value,
-            "num_queries": 3,
-        }
-    )
-
-    query = "吕布"
-    print(f"\nQuery: {query}")
-
-    try:
-        documents = pipeline.retrieve(query)
-        print(f"Found {len(documents)} documents (after multi-query rewriting and reranking)")
-
-        # Format the context
-        context = pipeline.format_context(documents)
-        print(f"\nContext preview:\n{context[:500]}...")
-
-    except Exception as e:
-        print(f"Retrieval failed: {e}")
-
-
-def demonstrate_agentic_rag():
-    """Demonstrate Agentic RAG."""
-    print("\n" + "=" * 60)
-    print("Demo: Agentic RAG (Level 4)")
-    print("=" * 60)
-
-    from app.rag import search_knowledge_base_sync
-
-    try:
-        # Demonstrate a tool call
-        print("Tool call example:")
-        response = search_knowledge_base_sync("吕布")
-        print(f"Tool response preview: {response[:200]}...")
-
-    except Exception as e:
-        print(f"Tool call failed: {e}")
-        import traceback
-        traceback.print_exc()
-
-
-def main():
-    """Entry point."""
-    print("RAG system demo")
-    print("=" * 60)
-
-    # Set up the environment
-    setup_environment()
-
-    # Run the demos
-    demonstrate_basic_rag()
-    demonstrate_hybrid_rag()
-    # demonstrate_rag_fusion()    # requires a local LLM service
-    # demonstrate_agentic_rag()   # requires a local LLM service
-
-    print("\n" + "=" * 60)
-    print("Demo complete!")
-    print("=" * 60)
-
-    print("\nUsage notes:")
-    print("1. Make sure the Qdrant service is running and the collection has been created")
-    print("2. The local LlamaCpp embedding model is used")
-    print("3. In the Agent system, import and use app.rag.tools.search_knowledge_base")
-    print("4. Bind the tool to your Agent model")
-
-
-if __name__ == "__main__":
-    main()
app/rag/fusion.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+# rag/fusion.py
+
+from typing import Dict, List
+from langchain_core.documents import Document
+
+
+def reciprocal_rank_fusion(
+    doc_lists: List[List[Document]],
+    k: int = 60
+) -> List[Document]:
+    """
+    Fuse multiple retrieval result lists with Reciprocal Rank Fusion (RRF).
+
+    Args:
+        doc_lists: Retrieval result lists, one per query.
+        k: RRF constant, conventionally 60.
+
+    Returns:
+        Fused documents sorted by RRF score in descending order.
+    """
+    # Identify documents by content; a docstore ID would be more robust,
+    # but we simplify here with a content-based key.
+    doc_to_score: Dict[str, float] = {}
+    doc_map: Dict[str, Document] = {}
+
+    for docs in doc_lists:
+        for rank, doc in enumerate(docs, start=1):
+            # Build a key from a content prefix plus the source, so identical
+            # text coming from different files is not conflated.
+            doc_id = f"{doc.page_content[:200]}_{doc.metadata.get('source', '')}"
+            if doc_id not in doc_map:
+                doc_map[doc_id] = doc
+            score = doc_to_score.get(doc_id, 0.0) + 1.0 / (k + rank)
+            doc_to_score[doc_id] = score
+
+    # Sort by score
+    sorted_ids = sorted(doc_to_score.keys(), key=lambda x: doc_to_score[x], reverse=True)
+    return [doc_map[doc_id] for doc_id in sorted_ids]
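
For reference, a minimal sketch of how reciprocal_rank_fusion behaves: a document that appears in two ranked lists collects two RRF contributions, so it outranks documents that appear in only one. The documents below are made up for illustration:

    from langchain_core.documents import Document
    from app.rag.fusion import reciprocal_rank_fusion

    a = Document(page_content="A", metadata={"source": "x"})
    b = Document(page_content="B", metadata={"source": "x"})
    c = Document(page_content="C", metadata={"source": "y"})

    # B is ranked 2nd in one list and 1st in the other: score 1/62 + 1/61,
    # which beats A (1/61) and C (1/62) at the default k=60.
    fused = reciprocal_rank_fusion([[a, b], [b, c]])
    assert [d.page_content for d in fused] == ["B", "A", "C"]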
@@ -1,168 +1,92 @@
-"""
-RAG retrieval pipeline
-
-Integrates basic retrieval, reranking, and RAG-Fusion.
-"""
+# rag/pipeline.py
 
-from enum import Enum
-from typing import List, Optional, Dict, Any
+import asyncio
+from typing import List, Optional
 from langchain_core.documents import Document
 from langchain_core.language_models import BaseLanguageModel
 
-from .retriever import (
-    create_base_retriever,
-    create_hybrid_retriever,
-    create_qdrant_client,
-)
-from .reranker import CrossEncoderReranker
-from .query_transform import MultiQueryTransformer
-from rag_core import QDRANT_URL, QDRANT_API_KEY
-
-
-class RAGLevel(Enum):
-    """RAG levels"""
-    BASIC = "basic"    # basic vector retrieval
-    RERANK = "rerank"  # basic retrieval + reranking
-    FUSION = "fusion"  # RAG-Fusion (multi-query + RRF)
+from .retriever import create_qdrant_client  # may not be needed directly
+from .reranker import LLaMaCPPReranker
+from .query_transform import MultiQueryGenerator
+from .fusion import reciprocal_rank_fusion
 
 
 class RAGPipeline:
-    """RAG retrieval pipeline"""
+    """
+    RAG retrieval pipeline with a fixed flow:
+    multi-query rewriting → parallel retrieval → RRF fusion → reranking → return parent documents
+    """
 
     def __init__(
         self,
-        embeddings,
-        llm: Optional[BaseLanguageModel] = None,
-        config: Optional[Dict[str, Any]] = None,
+        retriever,  # base retriever (should return parent documents, e.g. a ParentDocumentRetriever)
+        llm: BaseLanguageModel,
+        num_queries: int = 3,
+        rerank_top_n: int = 5,
+        rerank_model: str = "BAAI/bge-reranker-base",
     ):
         """
         Initialize the RAG pipeline
 
         Args:
-            embeddings: embedding model
-            llm: language model (used by RAG-Fusion)
-            config: configuration parameters
+            retriever: base retriever object; must implement an async ainvoke(query) method
+            llm: language model used to generate the query variants
+            num_queries: number of query variants to generate
+            rerank_top_n: number of documents to return in the end
+            rerank_model: reranking model name
         """
-        self.embeddings = embeddings
+        self.retriever = retriever
         self.llm = llm
-        self.config = config or {}
+        self.num_queries = num_queries
+        self.rerank_top_n = rerank_top_n
 
-        self.collection_name = self.config.get("collection_name", "rag_documents")
-        self.rag_level = self.config.get("rag_level", RAGLevel.RERANK.value)
-        self.num_queries = self.config.get("num_queries", 3)
-        self.rerank_top_n = self.config.get("rerank_top_n", 5)
-
-        # Initialize the base retriever
-        self.base_retriever = create_base_retriever(
-            collection_name=self.collection_name,
-            embeddings=self.embeddings,
-            search_kwargs={"k": 20},  # recall 20 candidates
-        )
-
-        # Initialize the reranker
-        try:
-            self.reranker = CrossEncoderReranker(top_n=self.rerank_top_n)
-        except Exception as e:
-            print(f"Warning: could not create the reranker, falling back to basic retrieval. Error: {e}")
-            self.reranker = None
-
-        # Create the retriever according to the RAG level
-        self.retriever = self._create_retriever()
-
-    def _create_retriever(self):
-        """Create the retriever according to the RAG level."""
-        if self.rag_level == RAGLevel.BASIC.value:
-            return self.base_retriever
-
-        # Basic retrieval + reranking
-        def rerank_retriever(query):
-            documents = self.base_retriever.invoke(query)
-            if self.reranker:
-                return self.reranker.compress_documents(documents, query)
-            else:
-                return documents[:self.rerank_top_n]
-
-        if self.rag_level == RAGLevel.RERANK.value:
-            return SimpleRetriever(rerank_retriever)
-
-        # RAG-Fusion
-        if self.rag_level == RAGLevel.FUSION.value:
-            if not self.llm:
-                raise ValueError("RAG-Fusion requires the llm argument")
-
-            # Create the multi-query retriever
-            transformer = MultiQueryTransformer(
-                llm=self.llm,
-                num_queries=self.num_queries
-            )
-            multi_query_retriever = transformer.create_multi_query_retriever(
-                base_retriever=SimpleRetriever(rerank_retriever)
-            )
-
-            return multi_query_retriever
-
-        return SimpleRetriever(rerank_retriever)
-
-    def retrieve(self, query: str) -> List[Document]:
-        """
-        Run retrieval
-
-        Args:
-            query: the query string
-
-        Returns:
-            A list of relevant documents
-        """
-        return self.retriever.invoke(query)
+        # Initialize components
+        self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
+        self.reranker = LLaMaCPPReranker(
+            base_url="http://127.0.0.1:8083",
+            top_n=rerank_top_n,
+            api_key="huang1998"
+        )
 
     async def aretrieve(self, query: str) -> List[Document]:
         """
-        Run retrieval asynchronously
-
-        Args:
-            query: the query string
-
-        Returns:
-            A list of relevant documents
+        Run the full retrieval flow asynchronously
         """
-        return await self.retriever.ainvoke(query)
+        # Step 1: generate the query variants
+        queries = await self.query_generator.agenerate(query)
+        # Include the original query so there is always at least one
+        if query not in queries:
+            queries.insert(0, query)
+        else:
+            # If the original query is already in the list, move it to the front
+            queries.remove(query)
+            queries.insert(0, query)
+
+        # Step 2: retrieve in parallel (one document list per query)
+        tasks = [self.retriever.ainvoke(q) for q in queries]
+        doc_lists = await asyncio.gather(*tasks)
+
+        # Step 3: RRF fusion
+        fused_docs = reciprocal_rank_fusion(doc_lists)
+
+        # Step 4: reranking
+        # (LLaMaCPPReranker has no .model attribute to probe, so check the instance itself)
+        if self.reranker is not None:
+            final_docs = self.reranker.compress_documents(fused_docs, query)
+        else:
+            # If the reranker is unavailable, return the top N fused documents
+            final_docs = fused_docs[:self.rerank_top_n]
+
+        return final_docs
+
+    def retrieve(self, query: str) -> List[Document]:
+        """Synchronous entry point (runs the async method internally)."""
+        return asyncio.run(self.aretrieve(query))
 
     def format_context(self, documents: List[Document]) -> str:
-        """
-        Format the context
-
-        Args:
-            documents: the document list
-
-        Returns:
-            The formatted context string
-        """
+        """Format a document list into a context string."""
         if not documents:
             return ""
 
-        context_parts = []
+        parts = []
         for i, doc in enumerate(documents, 1):
-            content = doc.page_content
-            metadata = doc.metadata or {}
-            source = metadata.get("source", "unknown source")
-
-            part = f"[Doc {i}]\n"
-            part += f"Source: {source}\n"
-            part += f"Content: {content}\n"
-            part += "---\n"
-            context_parts.append(part)
-
-        return "".join(context_parts)
-
-
-class SimpleRetriever:
-    """A simple retriever wrapper."""
-
-    def __init__(self, retrieve_func):
-        self.retrieve_func = retrieve_func
-
-    def invoke(self, query):
-        return self.retrieve_func(query)
-
-    async def ainvoke(self, query):
-        return self.retrieve_func(query)
+            source = doc.metadata.get("source", "unknown source")
+            parts.append(f"[Doc {i}] Source: {source}\n{doc.page_content}\n---\n")
+        return "\n".join(parts)
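
Read end to end, the new aretrieve implements exactly the fixed flow from the module docstring. A minimal driver, assuming a parent-document retriever and an LLM are already available (the names here are illustrative):

    import asyncio
    from app.rag.pipeline import RAGPipeline

    async def demo(retriever, llm):
        pipeline = RAGPipeline(retriever=retriever, llm=llm, num_queries=3, rerank_top_n=5)
        docs = await pipeline.aretrieve("Who is the tiger-slaying hero?")
        print(pipeline.format_context(docs))

    # asyncio.run(demo(retriever, llm))

Note that the synchronous retrieve() wraps asyncio.run(), so it will raise if called from inside an already-running event loop; use aretrieve() in async contexts such as FastAPI handlers.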
@@ -1,62 +1,43 @@
-"""
-Query transformer module
-
-Implements multi-query rewriting for RAG-Fusion.
-"""
+# rag/query_transform.py
 
-from typing import List, Optional
+from typing import List
 from langchain_core.language_models import BaseLanguageModel
 # from langchain.retrievers.multi_query import MultiQueryRetriever
 from langchain_core.prompts import PromptTemplate
 
+MULTI_QUERY_PROMPT = PromptTemplate.from_template(
+    """You are a professional query-rewriting assistant. Your task is to rewrite the user's question into {num_queries} different versions.
+These versions should express the same or related intent from different angles and with different keywords.
 
-class MultiQueryTransformer:
-    """Multi-query rewriter for RAG-Fusion."""
+Original question: {question}
 
+Generate {num_queries} different versions of the query, one per line.
+Make sure each version is an independent, complete query.
+
+Generate {num_queries} queries:"""
+)
+
+
+class MultiQueryGenerator:
+    """Multi-query generator (does not depend on LangChain's MultiQueryRetriever)."""
 
     def __init__(self, llm: BaseLanguageModel, num_queries: int = 3):
         """
         Initialize the multi-query rewriter.
 
         Args:
             llm: the language model instance
             num_queries: number of queries to generate
         """
         self.llm = llm
         self.num_queries = num_queries
+        self.prompt = MULTI_QUERY_PROMPT
 
-    def create_multi_query_retriever(self, base_retriever):
-        """
-        Create a multi-query retriever.
-
-        Args:
-            base_retriever: the base retriever
-
-        Returns:
-            A MultiQueryRetriever instance
-        """
-        # The current LangChain version does not support MultiQueryRetriever,
-        # so return the base retriever for now.
-        # retriever = MultiQueryRetriever.from_llm(
-        #     retriever=base_retriever,
-        #     llm=self.llm,
-        #     include_original=True
-        # )
-        #
-        # # Custom prompt
-        # retriever.llm_chain.prompt = PromptTemplate.from_template(
-        #     "You are a professional query-rewriting assistant. Your task is to rewrite the user's question into {num_queries} different versions.\n"
-        #     "These versions should express the same or related intent from different angles and with different keywords.\n\n"
-        #     "Original question: {question}\n\n"
-        #     "Generate {num_queries} different versions of the query, one per line.\n"
-        #     "Make sure each version is an independent, complete query.\n\n"
-        #     "Generate {num_queries} queries:"
-        # )
-        #
-        # # Patch the call to include num_queries
-        # original_ainvoke = retriever.llm_chain.ainvoke
-        # async def new_ainvoke(input_dict):
-        #     input_dict["num_queries"] = self.num_queries
-        #     return await original_ainvoke(input_dict)
-        # retriever.llm_chain.ainvoke = new_ainvoke
-        #
-        # return retriever
-        return base_retriever
+    def generate(self, query: str) -> List[str]:
+        """Generate multiple query variants synchronously."""
+        prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
+        response = self.llm.invoke(prompt_str)
+        # Split the response into lines, dropping blanks and surrounding whitespace
+        lines = response.content.strip().split('\n')
+        queries = [line.strip() for line in lines if line.strip()]
+        # Fall back to the original query if nothing usable came back
+        return queries[:self.num_queries] if queries else [query]
+
+    async def agenerate(self, query: str) -> List[str]:
+        """Generate multiple query variants asynchronously."""
+        prompt_str = self.prompt.format(num_queries=self.num_queries, question=query)
+        response = await self.llm.ainvoke(prompt_str)
+        lines = response.content.strip().split('\n')
+        queries = [line.strip() for line in lines if line.strip()]
+        return queries[:self.num_queries] if queries else [query]
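
A quick way to sanity-check the line-based parsing is to drive MultiQueryGenerator with a canned chat model; FakeListChatModel below comes from langchain_core's testing helpers, and the response text is made up:

    from langchain_core.language_models import FakeListChatModel
    from app.rag.query_transform import MultiQueryGenerator

    fake_llm = FakeListChatModel(responses=[
        "Who is the tiger-slaying hero?\nThe story of Wu Song fighting the tiger\nTiger-fighting characters in Water Margin"
    ])
    gen = MultiQueryGenerator(llm=fake_llm, num_queries=3)
    print(gen.generate("tiger-slaying hero"))
    # -> ['Who is the tiger-slaying hero?', 'The story of Wu Song fighting the tiger',
    #     'Tiger-fighting characters in Water Margin']

One caveat: if the model numbers its output ("1. ...", "2. ..."), those prefixes survive the split; stripping leading enumeration markers would be a reasonable hardening step.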
@@ -1,35 +1,34 @@
 """
-Reranker module
-
-Reranks retrieval results with a Cross-Encoder model to improve precision.
+Reranker module (adapted)
+Reranks retrieval results via a remote llama.cpp service (OpenAI-compatible
+rerank API) instead of a local Cross-Encoder.
 """
 
+import requests
 from typing import List
 from langchain_core.documents import Document
 
 
-class CrossEncoderReranker:
-    """Reranks retrieval results with a Cross-Encoder."""
+class LLaMaCPPReranker:
+    """Reranks retrieval results through a remote llama.cpp service."""
 
-    def __init__(self, model_name: str = "BAAI/bge-reranker-base", top_n: int = 5):
+    def __init__(self,
+                 base_url: str = "http://127.0.0.1:8083",
+                 top_n: int = 5,
+                 api_key: str = "huang1998",  # the LLAMA_ARG_API_KEY you configured
+                 timeout: int = 60):
         """
-        Initialize the reranker
+        Initialize the remote reranker
 
         Args:
-            model_name: pretrained model name
-            top_n: number of results to return
+            base_url: address and port of the llama.cpp service.
+            top_n: number of results to return.
+            api_key: the API key configured in the container.
+            timeout: request timeout in seconds.
         """
-        self.model_name = model_name
+        self.base_url = base_url.rstrip('/')
         self.top_n = top_n
-        self.model = None
+        self.api_key = api_key
+        self.timeout = timeout
+        self.endpoint = f"{self.base_url}/v1/rerank"
 
-        # Try to load the Cross-Encoder model
-        try:
-            from sentence_transformers import CrossEncoder
-            self.model = CrossEncoder(model_name)
-        except Exception as e:
-            print(f"Warning: could not load Cross-Encoder model {model_name}, falling back to simple ordering. Error: {e}")
-
     def compress_documents(
         self, documents: List[Document], query: str
     ) -> List[Document]:
@@ -45,21 +44,32 @@ class CrossEncoderReranker:
         """
         if not documents:
             return []
 
-        # If the model failed to load, return the first top_n documents
-        if self.model is None:
-            return documents[:self.top_n]
-
-        # Rerank with the Cross-Encoder
+        # Build the request body.
+        # Per llama.cpp's OpenAI compatibility, documents is a list of strings.
+        payload = {
+            "model": "bge-reranker-v2-m3",
+            "query": query,
+            "documents": [doc.page_content for doc in documents],
+            "top_n": self.top_n
+        }
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}"
+        }
 
         try:
-            pairs = [[query, doc.page_content] for doc in documents]
-            scores = self.model.predict(pairs)
+            response = requests.post(self.endpoint, json=payload, headers=headers, timeout=self.timeout)
+            response.raise_for_status()  # fail fast on HTTP errors
+            results = response.json()
 
-            # Sort by score in descending order
-            scored_docs = sorted(
-                zip(documents, scores), key=lambda x: x[1], reverse=True
-            )
-            return [doc for doc, _ in scored_docs[:self.top_n]]
+            # Parse the response. Format:
+            # {"results": [{"index": 0, "document": "...", "relevance_score": 0.8}, ...]}
+            # sorted by relevance score in descending order.
+            sorted_indices = [item["index"] for item in results["results"]]
+            sorted_docs = [documents[idx] for idx in sorted_indices]
+            return sorted_docs
         except Exception as e:
-            print(f"Warning: reranking failed, keeping the original order. Error: {e}")
+            print(f"Warning: remote reranking failed, keeping the original order. Error: {e}")
             return documents[:self.top_n]
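
A minimal standalone check of the remote reranker, assuming a llama.cpp server with a reranking model loaded is listening on the configured port (the URL, API key, and model name must match your deployment):

    from langchain_core.documents import Document
    from app.rag.reranker import LLaMaCPPReranker

    reranker = LLaMaCPPReranker(base_url="http://127.0.0.1:8083", top_n=2, api_key="huang1998")
    docs = [Document(page_content=t) for t in
            ["Lü Bu wields a halberd", "Wu Song slays a tiger", "RAG pipelines"]]
    top = reranker.compress_documents(docs, "Who killed the tiger?")
    # If the server is down or the key is wrong, the class logs a warning and
    # falls back to the original order truncated to top_n.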
@@ -1,39 +1,83 @@
 """
-Qdrant vector retriever
+Qdrant vector retriever module
 
-Provides basic vector retrieval and hybrid retrieval (Dense + BM25).
+Provides Qdrant-based basic vector retrieval and hybrid retrieval (Dense + Sparse).
+
+Core ideas:
+- Basic retrieval: embed the query text into a vector and run an approximate
+  nearest-neighbour (ANN) search in Qdrant, returning the k most similar
+  documents by cosine similarity.
+- Hybrid retrieval: combine dense vector retrieval (semantic similarity) with
+  BM25 sparse vector retrieval (keyword matching), improving recall precision
+  through weighting or score fusion.
+
+Example usage:
+    >>> from rag_core import LlamaCppEmbedder
+    >>> embedder = LlamaCppEmbedder()
+    >>> embeddings = embedder.as_langchain_embeddings()
+    >>>
+    >>> # Create a base retriever
+    >>> retriever = create_base_retriever(
+    ...     collection_name="my_docs",
+    ...     embeddings=embeddings,
+    ...     search_kwargs={"k": 10}
+    ... )
+    >>>
+    >>> # Run a search
+    >>> docs = retriever.invoke("What is RAG?")
 """
 
-from typing import List, Dict, Any, Optional
-from langchain_qdrant import QdrantVectorStore
-from langchain.embeddings.base import Embeddings
-# from langchain.retrievers import EnsembleRetriever
+from typing import Optional, Dict, Any
+from qdrant_client import QdrantClient
+from qdrant_client.http.exceptions import UnexpectedResponse
+from langchain_qdrant import QdrantVectorStore
+from langchain_core.embeddings import Embeddings
+from langchain_core.retrievers import BaseRetriever
 
 from rag_core import QDRANT_URL, QDRANT_API_KEY
 
+# Module-level constants
+DEFAULT_SEARCH_K = 20
+DEFAULT_SCORE_THRESHOLD = 0.3
+
 
 def create_qdrant_client(
     url: Optional[str] = None,
     api_key: Optional[str] = None,
+    timeout: int = 30,
 ) -> QdrantClient:
     """
-    Create a Qdrant client
+    Create and return a configured Qdrant client.
+
+    Arguments take precedence; if not given, fall back to the QDRANT_URL and
+    QDRANT_API_KEY environment variables.
 
     Args:
-        url: Qdrant service address, read from the QDRANT_URL environment variable by default
-        api_key: API key, read from the QDRANT_API_KEY environment variable by default
+        url: Qdrant service address, e.g. "http://localhost:6333".
+            Read from the QDRANT_URL environment variable by default.
+        api_key: API key (if Qdrant has authentication enabled).
+            Read from the QDRANT_API_KEY environment variable by default.
+        timeout: request timeout in seconds, 30 by default.
 
     Returns:
-        A QdrantClient instance
+        A configured QdrantClient instance.
+
+    Raises:
+        ValueError: if url is empty and the environment variable is also unset.
     """
-    url = url or QDRANT_URL
-    api_key = api_key or QDRANT_API_KEY
+    effective_url = url or QDRANT_URL
+    if not effective_url:
+        raise ValueError(
+            "No Qdrant URL given; set the url argument or the QDRANT_URL environment variable"
+        )
 
-    client_args = {"url": url}
-    if api_key:
-        client_args["api_key"] = api_key
+    effective_api_key = api_key or QDRANT_API_KEY
 
-    return QdrantClient(**client_args)
+    client_kwargs = {
+        "url": effective_url,
+        "timeout": timeout,
+    }
+    if effective_api_key:
+        client_kwargs["api_key"] = effective_api_key
+
+    return QdrantClient(**client_kwargs)
 
 
 def create_base_retriever(
@@ -41,33 +85,57 @@ def create_base_retriever(
     embeddings: Embeddings,
     search_kwargs: Optional[Dict[str, Any]] = None,
     client: Optional[QdrantClient] = None,
-) -> QdrantVectorStore:
+) -> BaseRetriever:
     """
-    Create a basic vector retriever
+    Create a basic vector retriever (dense vectors only).
+
+    The retriever embeds the query with the embedding model, runs an ANN search
+    over the Qdrant collection, and returns the semantically closest chunks.
 
     Args:
-        collection_name: Qdrant collection name
-        embeddings: embedding model
-        search_kwargs: search parameters, {"k": 20} by default
-        client: Qdrant client; created automatically when None
+        collection_name: Qdrant collection name (must already exist and be indexed).
+        embeddings: a LangChain-compatible embedding model instance.
+        search_kwargs: search parameters, which may include:
+            - k (int): number of documents to return, 20 by default.
+            - score_threshold (float): similarity cutoff; only documents above it are returned.
+            - filter (dict): Qdrant filter conditions.
+            Defaults to {"k": 20} when None.
+        client: optional Qdrant client instance; created automatically if omitted.
 
     Returns:
-        A QdrantVectorStore retriever instance
-    """
-    search_kwargs = search_kwargs or {"k": 20}
+        A BaseRetriever instance; call .invoke(query) or .ainvoke(query) to search.
 
-    # Create the Qdrant client
+    Raises:
+        ValueError: if the collection does not exist or the embedding model is invalid.
+    """
+    # Merge the default search parameters
+    merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
+    if search_kwargs:
+        merged_search_kwargs.update(search_kwargs)
+
+    # Create or reuse the Qdrant client
     if client is None:
         client = create_qdrant_client()
 
-    # Build the vector store with QdrantVectorStore
+    # Verify that the collection exists (optional, surfaces problems early)
+    try:
+        client.get_collection(collection_name)
+    except UnexpectedResponse as e:
+        if e.status_code == 404:
+            raise ValueError(
+                f"Qdrant collection '{collection_name}' does not exist; create and index it first."
+            )
+        raise
+
+    # Build the vector store
     vector_store = QdrantVectorStore(
         client=client,
        collection_name=collection_name,
        embedding=embeddings,
    )
 
-    return vector_store.as_retriever(search_kwargs=search_kwargs)
+    # Return the retriever
+    return vector_store.as_retriever(search_kwargs=merged_search_kwargs)
 
 
 def create_hybrid_retriever(
@@ -75,64 +143,57 @@ def create_hybrid_retriever(
     embeddings: Embeddings,
     dense_k: int = 10,
     sparse_k: int = 10,
+    score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
     client: Optional[QdrantClient] = None,
-) -> QdrantVectorStore:
+) -> BaseRetriever:
     """
-    Create a hybrid retriever (Dense Vector + BM25)
+    Create a hybrid retriever (dense vectors + BM25 sparse vectors).
+
+    Hybrid retrieval combines semantic similarity (dense) with keyword matching
+    (sparse), and handles proper nouns and exact-match scenarios better.
+
+    Note: this requires the Qdrant collection to carry a sparse-vector field
+    with a BM25 index. Without one, retrieval silently degrades to dense-only
+    (no error, but reduced quality).
 
     Args:
-        collection_name: Qdrant collection name
-        embeddings: embedding model
-        dense_k: number of results from vector retrieval
-        sparse_k: number of results from BM25 retrieval
-        client: Qdrant client
+        collection_name: Qdrant collection name.
+        embeddings: embedding model (for the dense vectors).
+        dense_k: number of dense retrieval results, 10 by default.
+        sparse_k: number of sparse retrieval results, 10 by default.
+        score_threshold: similarity cutoff, 0.3 by default.
+        client: optional Qdrant client instance.
 
     Returns:
-        A hybrid retriever
+        A BaseRetriever instance configured with hybrid search parameters.
     """
-    # Create the Qdrant client
-    if client is None:
-        client = create_qdrant_client()
-
-    # Build the vector store with QdrantVectorStore
-    vector_store = QdrantVectorStore(
-        client=client,
-        collection_name=collection_name,
-        embedding=embeddings,
-    )
+    total_k = dense_k + sparse_k
 
     search_kwargs = {
-        "k": dense_k + sparse_k,
-        "score_threshold": 0.3,
+        "k": total_k,
     }
+    if score_threshold is not None:
+        search_kwargs["score_threshold"] = score_threshold
 
-    return vector_store.as_retriever(search_kwargs=search_kwargs)
+    # Reuse the base retriever factory; only the search parameters differ
+    return create_base_retriever(
+        collection_name=collection_name,
+        embeddings=embeddings,
+        search_kwargs=search_kwargs,
+        client=client,
+    )
 
 
-# def create_ensemble_retriever(
-#     retrievers: List[Any],
-#     weights: Optional[List[float]] = None,
-#     c: int = 60,
-# ) -> EnsembleRetriever:
-#     """
-#     Create an ensemble retriever with Reciprocal Rank Fusion (RRF)
-#
-#     Args:
-#         retrievers: list of retrievers
-#         weights: retriever weights
-#         c: RRF constant (usually 60)
-#
-#     Returns:
-#         An ensemble retriever
-#     """
-#     if weights is None:
-#         weights = [1.0 / len(retrievers)] * len(retrievers)
-#
-#     ensemble = EnsembleRetriever(
-#         retrievers=retrievers,
-#         weights=weights,
-#         c=c,
-#         search_type="rrf",
-#     )
-#
-#     return ensemble
+# Optional: an async-friendly helper
+async def acreate_base_retriever(
+    collection_name: str,
+    embeddings: Embeddings,
+    search_kwargs: Optional[Dict[str, Any]] = None,
+    client: Optional[QdrantClient] = None,
+) -> BaseRetriever:
+    """
+    Create a basic vector retriever asynchronously (same behaviour as the sync version).
+
+    Useful where async initialization is expected (e.g. a FastAPI startup event).
+    """
+    # QdrantVectorStore initialization itself is synchronous, so just delegate
+    return create_base_retriever(collection_name, embeddings, search_kwargs, client)
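
As the note above says, hybrid search is only real when the collection carries a sparse-vector field; with langchain_qdrant that is normally expressed through its hybrid retrieval mode rather than through search_kwargs alone. A hedged sketch of that variant (assumes the fastembed extra is installed and the collection was indexed with a matching sparse field):

    from langchain_qdrant import QdrantVectorStore, RetrievalMode, FastEmbedSparse

    sparse = FastEmbedSparse(model_name="Qdrant/bm25")
    store = QdrantVectorStore(
        client=client,
        collection_name=collection_name,
        embedding=embeddings,
        sparse_embedding=sparse,
        retrieval_mode=RetrievalMode.HYBRID,
    )
    retriever = store.as_retriever(search_kwargs={"k": 20})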
app/rag/test.py (new file, 153 lines)
@@ -0,0 +1,153 @@
+#!/usr/bin/env python3
+"""
+RAG system usage example (refactored)
+
+Demonstrates:
+1. Getting a parent-child chunk retriever
+2. Creating the fixed-flow RAGPipeline (multi-query rewriting → RRF fusion → reranking → parent documents)
+3. Wrapping the pipeline as a LangChain tool for Agent use
+"""
+
+import asyncio
+import sys
+import os
+
+from dotenv import load_dotenv
+
+# Load environment variables (Qdrant URL, PostgreSQL connection, etc.)
+load_dotenv()
+
+# Add the project root to the path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
+from rag_indexer.splitters import SplitterType
+from rag.pipeline import RAGPipeline
+from rag.tools import create_rag_tool
+from pydantic import SecretStr
+# Use a local LLM through the OpenAI-compatible interface
+from langchain_openai import ChatOpenAI
+from rag_core.retriever_factory import create_parent_retriever
+
+
+def create_llm():
+    """Create the LLM backed by a local vLLM service."""
+    vllm_base_url = os.getenv(
+        "VLLM_BASE_URL",
+        "http://127.0.0.1:8081/v1"
+    )
+
+    return ChatOpenAI(
+        base_url=vllm_base_url,
+        api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
+        model="gemma-4-E2B-it",
+        timeout=60.0,      # request timeout in seconds
+        max_retries=2,     # automatic retries on failure
+        streaming=True,    # make sure streaming output is enabled
+    )
+
+
+async def demonstrate_full_pipeline():
+    """
+    Full pipeline demo:
+    - get a ParentDocumentRetriever from the factory
+    - create a RAGPipeline
+    - run a retrieval and print the results
+    """
+    print("=" * 60)
+    print("Demo: fixed-flow RAG retrieval (multi-query + RRF + reranking + parent documents)")
+    print("=" * 60)
+
+    # 1. Get the retriever
+    retriever = create_parent_retriever(collection_name="my_docs", search_k=5)
+
+    if retriever is None:
+        print("Error: the retriever was not initialized; make sure the index has been built.")
+        return
+
+    # 2. Create the LLM for query rewriting
+    llm = create_llm()
+
+    # 3. Create the RAGPipeline (fixed flow)
+    pipeline = RAGPipeline(
+        retriever=retriever,
+        llm=llm,
+        num_queries=3,    # generate 3 query variants
+        rerank_top_n=5,   # return 5 parent documents in the end
+    )
+
+    # 4. Run the retrieval
+    query = "打虎英雄是谁?"
+    print(f"\nQuery: {query}")
+    print("-" * 40)
+
+    try:
+        documents = await pipeline.aretrieve(query)
+        print(f"Returned {len(documents)} parent documents\n")
+
+        # Print result previews
+        for i, doc in enumerate(documents, 1):
+            content_preview = doc.page_content.replace("\n", " ")[:150]
+            source = doc.metadata.get("source", "unknown source")
+            print(f"{i}. [Source: {source}]")
+            print(f"   {content_preview}...\n")
+
+        # Optional: format the full context
+        # context = pipeline.format_context(documents)
+        # print(context)
+
+    except Exception as e:
+        print(f"Retrieval failed: {e}")
+        import traceback
+        traceback.print_exc()
+
+
+async def demonstrate_tool_creation():
+    """
+    Demonstrate creating the RAG tool (for Agent use).
+    """
+    print("\n" + "=" * 60)
+    print("Demo: creating the RAG tool (for LangGraph Agent calls)")
+    print("=" * 60)
+
+    # 1. Get the retriever (as above)
+    config = IndexBuilderConfig(
+        collection_name="rag_documents",
+        splitter_type=SplitterType.PARENT_CHILD,
+    )
+    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
+
+    # 2. Create the LLM
+    llm = create_llm()
+
+    # 3. Create the tool
+    rag_tool = create_rag_tool(
+        retriever=retriever,
+        llm=llm,
+        num_queries=3,
+        rerank_top_n=5,
+        collection_name="rag_documents",
+    )
+
+    print(f"Tool name: {rag_tool.name}")
+    print(f"Tool description: {rag_tool.description[:100]}...")
+
+    # 4. Simulate an Agent tool call
+    query = "请告诉我 RAG 系统的核心组件有哪些?"
+    print(f"\nSimulated call: {query}")
+    print("-" * 40)
+
+    result = await rag_tool.ainvoke({"query": query})
+    print(result[:800] + "..." if len(result) > 800 else result)
+
+
+async def main():
+    await demonstrate_full_pipeline()
+    await demonstrate_tool_creation()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
app/rag/tools.py (173 lines)
@@ -2,88 +2,115 @@
 RAG tools module
 
 Wraps retrieval as LangChain Tools for Agent use.
+Uses the fixed pipeline: multi-query rewriting → parallel retrieval → RRF fusion → reranking → return parent documents.
 """
 
 from typing import Optional, Callable
 from langchain_core.tools import tool
-from rag_core import LlamaCppEmbedder, QDRANT_URL, QDRANT_API_KEY
-from .pipeline import RAGPipeline, RAGLevel
+from langchain_core.language_models import BaseLanguageModel
+from langchain_core.retrievers import BaseRetriever
+
+from .pipeline import RAGPipeline
 
 
-@tool
-async def search_knowledge_base(query: str, rag_level: str = "rerank") -> str:
-    """Search the knowledge base for document fragments relevant to the query.
-
-    Suitable for factual questions and background-knowledge lookups.
-
-    Args:
-        query: the query string
-        rag_level: retrieval level; one of basic (basic vector retrieval),
-            rerank (basic retrieval + reranking), fusion (RAG-Fusion)
-
-    Returns:
-        The retrieved document content
-    """
-    # Initialize the embedding model
-    embedder = LlamaCppEmbedder()
-    embeddings = embedder.as_langchain_embeddings()
-
-    # Create the RAG pipeline
-    pipeline = RAGPipeline(
-        embeddings=embeddings,
-        config={
-            "rag_level": rag_level,
-            "collection_name": "rag_documents",
-            "rerank_top_n": 5,
-        }
-    )
-
-    # Run the retrieval
-    try:
-        documents = await pipeline.aretrieve(query)
-        if not documents:
-            return "No relevant information found."
-
-        # Format the results
-        context = pipeline.format_context(documents)
-        return context
-    except Exception as e:
-        return f"An error occurred during retrieval: {str(e)}"
-
-
-@tool
-def search_knowledge_base_sync(query: str, rag_level: str = "rerank") -> str:
-    """Synchronous version of the knowledge-base search tool.
-
-    Suitable for factual questions and background-knowledge lookups.
-
-    Args:
-        query: the query string
-        rag_level: retrieval level; one of basic (basic vector retrieval),
-            rerank (basic retrieval + reranking), fusion (RAG-Fusion)
-
-    Returns:
-        The retrieved document content
-    """
-    # Initialize the embedding model
-    embedder = LlamaCppEmbedder()
-    embeddings = embedder.as_langchain_embeddings()
-
-    # Create the RAG pipeline
-    pipeline = RAGPipeline(
-        embeddings=embeddings,
-        config={
-            "rag_level": rag_level,
-            "collection_name": "rag_documents",
-            "rerank_top_n": 5,
-        }
-    )
-
-    # Run the retrieval
-    try:
-        documents = pipeline.retrieve(query)
-        if not documents:
-            return "No relevant information found."
-
-        # Format the results
-        context = pipeline.format_context(documents)
-        return context
-    except Exception as e:
-        return f"An error occurred during retrieval: {str(e)}"
+def create_rag_tool(
+    retriever: BaseRetriever,
+    llm: BaseLanguageModel,
+    num_queries: int = 3,
+    rerank_top_n: int = 5,
+    collection_name: str = "rag_documents",
+) -> Callable:
+    """
+    Create a configured RAG retrieval tool (async).
+
+    Args:
+        retriever: the base retriever (e.g. a ParentDocumentRetriever instance)
+        llm: the language model used for multi-query rewriting
+        num_queries: number of query variants to generate
+        rerank_top_n: number of documents to return in the end
+        collection_name: collection name (used only for logs/descriptions)
+
+    Returns:
+        A LangChain Tool callable (async)
+    """
+    # Initialize the pipeline once; all components are created here and reused
+    pipeline = RAGPipeline(
+        retriever=retriever,
+        llm=llm,
+        num_queries=num_queries,
+        rerank_top_n=rerank_top_n,
+    )
+
+    @tool
+    async def search_knowledge_base(query: str) -> str:
+        """Search the knowledge base for document fragments relevant to the query.
+
+        This tool will:
+        1. Rewrite the user's question into several queries from different angles
+        2. Retrieve the relevant parent documents for each query in parallel
+        3. Merge the results with Reciprocal Rank Fusion (RRF)
+        4. Pick the most relevant fragments with a Cross-Encoder reranking model
+
+        Suitable for factual questions and background-knowledge lookups that need
+        precise, comprehensive answers.
+
+        Args:
+            query: the user's question or query string
+
+        Returns:
+            The formatted relevant document content, or a notice when nothing is found.
+        """
+        try:
+            documents = await pipeline.aretrieve(query)
+            if not documents:
+                return f"No information related to '{query}' was found in knowledge base '{collection_name}'."
+
+            context = pipeline.format_context(documents)
+            return context
+        except Exception as e:
+            return f"An error occurred during retrieval: {str(e)}"
+
+    return search_knowledge_base
+
+
+def create_rag_tool_sync(
+    retriever: BaseRetriever,
+    llm: BaseLanguageModel,
+    num_queries: int = 3,
+    rerank_top_n: int = 5,
+    collection_name: str = "rag_documents",
+) -> Callable:
+    """
+    Create a configured RAG retrieval tool (synchronous version, for older
+    Agents without async support).
+
+    Same parameters as create_rag_tool.
+    """
+    pipeline = RAGPipeline(
+        retriever=retriever,
+        llm=llm,
+        num_queries=num_queries,
+        rerank_top_n=rerank_top_n,
+    )
+
+    @tool
+    def search_knowledge_base_sync(query: str) -> str:
+        """Search the knowledge base for relevant document fragments (synchronous).
+
+        Same behaviour as the async version: multi-query rewriting → RRF fusion →
+        reranking → return parent documents.
+
+        Args:
+            query: the user's question or query string
+
+        Returns:
+            The formatted relevant document content.
+        """
+        try:
+            documents = pipeline.retrieve(query)  # runs the async method internally
+            if not documents:
+                return f"No information related to '{query}' was found in knowledge base '{collection_name}'."
+
+            context = pipeline.format_context(documents)
+            return context
+        except Exception as e:
+            return f"An error occurred during retrieval: {str(e)}"
+
+    return search_knowledge_base_sync
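
Because create_rag_tool closes over a fully configured pipeline, the returned tool binds to a chat model like any other LangChain tool. A hedged sketch, assuming an existing retriever and an OpenAI-compatible chat model named llm:

    from app.rag.tools import create_rag_tool

    rag_tool = create_rag_tool(retriever=retriever, llm=llm)
    llm_with_tools = llm.bind_tools([rag_tool])  # standard LangChain tool binding

    # Or invoke the tool directly:
    # answer = await rag_tool.ainvoke({"query": "What is RAG?"})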
@@ -7,6 +7,8 @@ RAG Core - shared RAG components package
 from .embedders import LlamaCppEmbedder
 from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
 from .store import PostgresDocStore, create_docstore
+from .retriever_factory import create_parent_retriever
 
 
 __all__ = [
     "LlamaCppEmbedder",
@@ -15,4 +17,5 @@ __all__ = [
     "QDRANT_API_KEY",
     "PostgresDocStore",
     "create_docstore",
+    "create_parent_retriever",
 ]
rag_core/client.py (new file, 24 lines)
@@ -0,0 +1,24 @@
+# rag_core/client.py
+import os
+from typing import Optional
+from qdrant_client import QdrantClient
+
+QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
+QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
+
+def create_qdrant_client(
+    url: Optional[str] = None,
+    api_key: Optional[str] = None,
+    timeout: int = 120,  # index building needs a longer timeout
+) -> QdrantClient:
+    effective_url = url or QDRANT_URL
+    effective_api_key = api_key or QDRANT_API_KEY
+
+    if not effective_url:
+        raise ValueError("No Qdrant URL configured")
+
+    client_kwargs = {"url": effective_url, "timeout": timeout}
+    if effective_api_key:
+        client_kwargs["api_key"] = effective_api_key
+
+    return QdrantClient(**client_kwargs)
rag_core/retriever_factory.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+# rag_core/retriever_factory.py
+from typing import Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.stores import BaseStore
+from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
+from langchain_classic.retrievers import ParentDocumentRetriever
+
+from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
+
+
+def create_parent_retriever(
+    collection_name: str = "rag_documents",
+    embeddings: Optional[Embeddings] = None,
+    parent_splitter: Optional[TextSplitter] = None,
+    child_splitter: Optional[TextSplitter] = None,
+    docstore: Optional[BaseStore] = None,
+    search_k: int = 5,
+    # If no splitters are passed in, defaults are built from these parameters
+    parent_chunk_size: int = 1000,
+    parent_chunk_overlap: int = 100,
+    child_chunk_size: int = 200,
+    child_chunk_overlap: int = 20,
+) -> ParentDocumentRetriever:
+    # Embedding model
+    if embeddings is None:
+        embedder = LlamaCppEmbedder()
+        embeddings = embedder.as_langchain_embeddings()
+
+    # Vector store (read-only)
+    vector_store = QdrantVectorStore(
+        collection_name=collection_name,
+        embeddings=embeddings,
+    )
+
+    # Splitters (defaults are created when none are provided)
+    if parent_splitter is None:
+        parent_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=parent_chunk_size,
+            chunk_overlap=parent_chunk_overlap,
+        )
+    if child_splitter is None:
+        child_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=child_chunk_size,
+            chunk_overlap=child_chunk_overlap,
+        )
+
+    # Document store
+    if docstore is None:
+        docstore, _ = create_docstore()  # reads the connection from environment variables
+
+    return ParentDocumentRetriever(
+        vectorstore=vector_store.get_langchain_vectorstore(),
+        docstore=docstore,
+        child_splitter=child_splitter,
+        parent_splitter=parent_splitter,
+        search_kwargs={"k": search_k},
+    )
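
A hedged usage sketch of the new factory, assuming the Qdrant collection and the Postgres-backed docstore referenced by the environment already exist:

    from rag_core import create_parent_retriever

    retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
    docs = retriever.invoke("打虎英雄是谁?")  # returns parent documents

The splitters only matter when the same retriever instance is reused for indexing; for pure retrieval the child chunks are matched in Qdrant and the corresponding parent documents are fetched from the docstore.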
@@ -10,6 +10,7 @@ from langchain_core.documents import Document
 from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import Distance, VectorParams
+from .client import create_qdrant_client
 
 logger = logging.getLogger(__name__)
 
@@ -44,14 +45,8 @@ class QdrantVectorStore:
         )
 
     def get_client(self) -> QdrantClient:
         """Lazily create the client, making sure the connection is usable on each access."""
         if self._client is None:
-            self._client = QdrantClient(
-                url=QDRANT_URL,
-                api_key=QDRANT_API_KEY,
-                timeout=120,
-                http2=False,
-            )
+            self._client = create_qdrant_client(timeout=120)
         return self._client
 
     def refresh_client(self):
@@ -23,7 +23,7 @@ Offline RAG Indexer module.
     >>> await builder.build_from_file("document.pdf")
 """
 
-from .IndexBuilder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
+from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
 from .loaders import DocumentLoader
 from .splitters import SplitterType, get_splitter
@@ -39,7 +39,7 @@ __version__ = "2.0.0"
 
 __all__ = [
     # Core builder and config
     "IndexBuilder",
     "index_builder",
     "IndexBuilderConfig",
     "DocstoreConfig",
@@ -7,7 +7,7 @@ import logging
 import sys
 from pathlib import Path
 
-from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
+from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
 from rag_indexer.splitters import SplitterType
 
 logging.basicConfig(
@@ -19,7 +19,8 @@ from langchain_classic.retrievers import ParentDocumentRetriever
 
 from .loaders import DocumentLoader
 from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter
-from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
+from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
 
 
 logger = logging.getLogger(__name__)
 
@@ -113,43 +114,40 @@ class IndexBuilder:
         logger.info("Using a single %s splitter", self.config.splitter_type.value)
 
     def _init_parent_child_mode(self) -> None:
         """Parent-child splitting mode: initialize the parent/child splitters, docstore, and retriever."""
         cfg = self.config
 
-        # Parent splitter (always recursive splitting)
+        # Parent splitter (needed for index building; must be kept)
         self.parent_splitter = RecursiveCharacterTextSplitter(
             chunk_size=cfg.parent_chunk_size,
             chunk_overlap=cfg.parent_chunk_overlap,
         )
 
-        # Child splitter
+        # Child splitter (needed for index building)
         if cfg.child_splitter_type == SplitterType.SEMANTIC:
             self.child_splitter = get_splitter(
                 SplitterType.SEMANTIC,
                 embeddings=self.embeddings,
                 **cfg.extra_splitter_kwargs
             )
             logger.info("Child chunks use the semantic splitter")
         else:
             self.child_splitter = RecursiveCharacterTextSplitter(
                 chunk_size=cfg.child_chunk_size,
                 chunk_overlap=cfg.child_chunk_overlap,
             )
             logger.info("Child chunks use the recursive splitter, chunk size=%d, overlap=%d",
                         cfg.child_chunk_size, cfg.child_chunk_overlap)
 
-        # Initialize the document store (for parent chunks)
+        # Document store
         self.docstore = self._create_or_use_docstore()
 
-        # Create the retriever
-        self.retriever = ParentDocumentRetriever(
-            vectorstore=self.vector_store.get_langchain_vectorstore(),
-            docstore=self.docstore,
-            child_splitter=self.child_splitter,  # type: ignore[arg-type]
-            parent_splitter=self.parent_splitter,
-            search_kwargs={"k": cfg.search_k},
-        )
-        logger.info("ParentDocumentRetriever initialized, parent chunk size=%d", cfg.parent_chunk_size)
+        # Build the retriever through the factory function to avoid duplicated code
+        self.retriever = create_parent_retriever(
+            collection_name=cfg.collection_name,
+            embeddings=self.embeddings,
+            parent_splitter=self.parent_splitter,
+            child_splitter=self.child_splitter,
+            docstore=self.docstore,
+            search_k=cfg.search_k,
+        )
+        logger.info("ParentDocumentRetriever initialized")
 
     def _create_or_use_docstore(self) -> BaseStore:
         """Create or reuse the document store instance."""
@@ -10,7 +10,7 @@ import sys
 # Add the project root to the Python path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
 
-from rag_indexer.IndexBuilder import IndexBuilder
+from rag_indexer.index_builder import IndexBuilder
 from rag_indexer.splitters import SplitterType
 
 async def test_index_builder():
 
@@ -129,7 +129,7 @@ async def check_postgres():
 
 async def test_search():
     """Test the search functionality."""
-    from rag_indexer.IndexBuilder import IndexBuilder, IndexBuilderConfig
+    from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
     from rag_indexer.splitters import SplitterType
 
     print("\n" + "=" * 60)