feat: RAG混合检索系统完整实现 + 启动脚本修复
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 5m4s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 5m4s
- 实现了稠密+稀疏混合检索,使用 Qdrant 原生 RRF 融合 - 修复了 retriever.py 的 BaseRetriever 继承和稀疏向量包装问题 - 修复了 pipeline.py 的 Optional 导入问题 - 添加了稀疏 embedder 的缓存配置 - 简化了 vector_store.py,移除不必要的逻辑 - 修复了 start.sh 的 PROJECT_DIR 硬编码和端口配置问题 - 完善了 RAG 检索的测试文件
This commit is contained in:
@@ -9,7 +9,7 @@ RAG 检索流水线模块
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
|
||||||
|
|||||||
@@ -6,19 +6,18 @@ Qdrant 混合检索器模块
|
|||||||
- 父子文档混合检索(先检索子文档,再返回父文档)
|
- 父子文档混合检索(先检索子文档,再返回父文档)
|
||||||
|
|
||||||
核心原理:
|
核心原理:
|
||||||
- 使用 Qdrant 原生 Fusion API (RRF) 做分数融合
|
- 使用 Qdrant Universal Query API (query_points)
|
||||||
- 同时使用稠密向量(语义)和稀疏向量(BM25 关键词)
|
- 使用 Prefetch 并行检索多个源
|
||||||
|
- 使用 RRF 分数融合
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Dict, Any, Optional, List
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient, models
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from qdrant_client.http.models import (
|
|
||||||
SearchRequest, Fusion, FusionProtocol, NamedVector, NamedSparseVector
|
|
||||||
)
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
|
from langchain_core.retrievers import BaseRetriever, RetrieverOutput
|
||||||
|
from pydantic import Field, PrivateAttr
|
||||||
|
|
||||||
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
|
from rag_core import QdrantVectorStore, get_sparse_embedder, create_docstore
|
||||||
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
||||||
@@ -35,8 +34,14 @@ class HybridRetriever(BaseRetriever):
|
|||||||
"""
|
"""
|
||||||
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合
|
混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合
|
||||||
|
|
||||||
直接使用 Qdrant 原生 Fusion API,性能最优。
|
使用 Qdrant Universal Query API (query_points)
|
||||||
"""
|
"""
|
||||||
|
collection_name: str = Field(description="Qdrant 集合名称")
|
||||||
|
search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数")
|
||||||
|
|
||||||
|
_vector_store: Any = PrivateAttr()
|
||||||
|
_client: Any = PrivateAttr()
|
||||||
|
_sparse_embedder: Any = PrivateAttr()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -50,57 +55,59 @@ class HybridRetriever(BaseRetriever):
|
|||||||
vector_store: QdrantVectorStore 实例
|
vector_store: QdrantVectorStore 实例
|
||||||
search_k: 检索返回结果数
|
search_k: 检索返回结果数
|
||||||
"""
|
"""
|
||||||
self.collection_name = collection_name
|
super().__init__(
|
||||||
self.vector_store = vector_store
|
collection_name=collection_name,
|
||||||
self.search_k = search_k
|
search_k=search_k
|
||||||
self.client = vector_store.get_qdrant_client()
|
)
|
||||||
self.sparse_embedder = get_sparse_embedder()
|
self._vector_store = vector_store
|
||||||
|
self._client = vector_store.get_qdrant_client()
|
||||||
|
self._sparse_embedder = get_sparse_embedder()
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: Optional[Any] = None
|
self, query: str, **kwargs
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
同步检索相关文档
|
同步检索相关文档
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 查询字符串
|
query: 查询字符串
|
||||||
run_manager: LangChain 运行管理器(可选)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相关文档列表
|
相关文档列表
|
||||||
"""
|
"""
|
||||||
# 生成双向量
|
# 1. 生成双向量
|
||||||
dense_query = self.vector_store.embeddings.embed_query(query)
|
dense_query = self._vector_store.embeddings.embed_query(query)
|
||||||
sparse_query = self.sparse_embedder.embed_query(query)
|
sparse_query = self._sparse_embedder.embed_query(query)
|
||||||
|
sparse_vec = models.SparseVector(
|
||||||
# 构建双检索请求
|
indices=sparse_query["indices"],
|
||||||
searches = [
|
values=sparse_query["values"]
|
||||||
# 稠密检索
|
|
||||||
SearchRequest(
|
|
||||||
vector=NamedVector(name="dense", vector=dense_query),
|
|
||||||
limit=self.search_k,
|
|
||||||
with_payload=True
|
|
||||||
),
|
|
||||||
# 稀疏检索
|
|
||||||
SearchRequest(
|
|
||||||
vector=NamedSparseVector(name="sparse", vector=sparse_query),
|
|
||||||
limit=self.search_k,
|
|
||||||
with_payload=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# RRF 分数融合
|
|
||||||
fused_results = self.client.fusion(
|
|
||||||
collection_name=self.collection_name,
|
|
||||||
requests=searches,
|
|
||||||
fusion=Fusion(fusion=FusionProtocol.RRF)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 转换为 Document 格式
|
# 2. 使用官方的 query_points API(推荐方式)
|
||||||
|
response = self._client.query_points(
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
prefetch=[ # 并行预取多个检索源
|
||||||
|
models.Prefetch(
|
||||||
|
query=dense_query,
|
||||||
|
using="dense", # 使用稠密向量进行语义搜索
|
||||||
|
limit=self.search_k
|
||||||
|
),
|
||||||
|
models.Prefetch(
|
||||||
|
query=sparse_vec,
|
||||||
|
using="sparse", # 使用稀疏向量进行关键词搜索
|
||||||
|
limit=self.search_k
|
||||||
|
)
|
||||||
|
],
|
||||||
|
query=models.FusionQuery(fusion=models.Fusion.RRF), # 指定融合算法为 RRF
|
||||||
|
limit=self.search_k, # 最终返回的结果数量
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 转换结果
|
||||||
results = []
|
results = []
|
||||||
for point in fused_results.points:
|
for point in response.points:
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=point.payload.pop("text", ""),
|
page_content=point.payload.pop("page_content", point.payload.pop("text", "")),
|
||||||
metadata=point.payload
|
metadata=point.payload
|
||||||
)
|
)
|
||||||
results.append(doc)
|
results.append(doc)
|
||||||
@@ -109,11 +116,11 @@ class HybridRetriever(BaseRetriever):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: Optional[Any] = None
|
self, query: str, **kwargs
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""异步检索(当前调用同步版本)"""
|
"""异步检索(当前调用同步版本)"""
|
||||||
# Qdrant 客户端没有原生 async,这里用同步版本
|
# Qdrant 客户端没有原生 async,这里用同步版本
|
||||||
return self._get_relevant_documents(query, run_manager=run_manager)
|
return self._get_relevant_documents(query, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ParentHybridRetriever(BaseRetriever):
|
class ParentHybridRetriever(BaseRetriever):
|
||||||
@@ -125,6 +132,14 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
3. 去重并返回父文档
|
3. 去重并返回父文档
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
collection_name: str = Field(description="Qdrant 集合名称")
|
||||||
|
search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数")
|
||||||
|
|
||||||
|
_vector_store: Any = PrivateAttr()
|
||||||
|
_client: Any = PrivateAttr()
|
||||||
|
_sparse_embedder: Any = PrivateAttr()
|
||||||
|
_docstore: Any = PrivateAttr()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
collection_name: str,
|
collection_name: str,
|
||||||
@@ -136,58 +151,62 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称
|
collection_name: Qdrant 集合名称
|
||||||
vector_store: QdrantVectorStore 实例
|
vector_store: QdrantVectorStore 实例
|
||||||
search_k: 最终返回的父文档数
|
search_k: 最终返回的父文档数量
|
||||||
docstore: 文档存储(如果父文档在 PostgreSQL),可选
|
docstore: 文档存储(如果父文档在 PostgreSQL),可选
|
||||||
"""
|
"""
|
||||||
self.collection_name = collection_name
|
super().__init__(
|
||||||
self.vector_store = vector_store
|
collection_name=collection_name,
|
||||||
self.search_k = search_k
|
search_k=search_k
|
||||||
self.client = vector_store.get_qdrant_client()
|
)
|
||||||
self.sparse_embedder = get_sparse_embedder()
|
self._vector_store = vector_store
|
||||||
self.docstore = docstore
|
self._client = vector_store.get_qdrant_client()
|
||||||
|
self._sparse_embedder = get_sparse_embedder()
|
||||||
|
self._docstore = docstore
|
||||||
|
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self, query: str, *, run_manager: Optional[Any] = None
|
self, query: str, **kwargs
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
同步检索相关父文档
|
同步检索相关父文档
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 查询字符串
|
query: 查询字符串
|
||||||
run_manager: LangChain 运行管理器(可选)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相关父文档列表
|
相关父文档列表
|
||||||
"""
|
"""
|
||||||
# 1. 生成查询双向量
|
# 1. 生成查询双向量
|
||||||
dense_query = self.vector_store.embeddings.embed_query(query)
|
dense_query = self._vector_store.embeddings.embed_query(query)
|
||||||
sparse_query = self.sparse_embedder.embed_query(query)
|
sparse_query = self._sparse_embedder.embed_query(query)
|
||||||
|
sparse_vec = models.SparseVector(
|
||||||
|
indices=sparse_query["indices"],
|
||||||
|
values=sparse_query["values"]
|
||||||
|
)
|
||||||
|
|
||||||
# 2. 多取一些子文档,避免去重后数量不足
|
# 2. 多取一些子文档,避免去重后数量不足
|
||||||
search_limit = self.search_k * 2
|
search_limit = self.search_k * 2
|
||||||
searches = [
|
|
||||||
# 稠密检索
|
|
||||||
SearchRequest(
|
|
||||||
vector=NamedVector(name="dense", vector=dense_query),
|
|
||||||
limit=search_limit,
|
|
||||||
with_payload=True
|
|
||||||
),
|
|
||||||
# 稀疏检索
|
|
||||||
SearchRequest(
|
|
||||||
vector=NamedSparseVector(name="sparse", vector=sparse_query),
|
|
||||||
limit=search_limit,
|
|
||||||
with_payload=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
# 3. RRF 分数融合,拿到子文档命中结果
|
# 3. 使用 query_points API 进行混合检索
|
||||||
fused_results = self.client.fusion(
|
response = self._client.query_points(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
requests=searches,
|
prefetch=[
|
||||||
fusion=Fusion(fusion=FusionProtocol.RRF)
|
models.Prefetch(
|
||||||
|
query=dense_query,
|
||||||
|
using="dense",
|
||||||
|
limit=search_limit
|
||||||
|
),
|
||||||
|
models.Prefetch(
|
||||||
|
query=sparse_vec,
|
||||||
|
using="sparse",
|
||||||
|
limit=search_limit
|
||||||
|
)
|
||||||
|
],
|
||||||
|
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||||
|
limit=search_limit,
|
||||||
|
with_payload=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if not fused_results.points:
|
if not response.points:
|
||||||
debug("混合检索未找到任何文档")
|
debug("混合检索未找到任何文档")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -196,8 +215,10 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
parent_ids = set()
|
parent_ids = set()
|
||||||
child_point_map = {} # 保存子文档点用于降级
|
child_point_map = {} # 保存子文档点用于降级
|
||||||
|
|
||||||
for point in fused_results.points:
|
for point in response.points:
|
||||||
parent_id = point.payload.get("parent_id", point.id)
|
# 先复制 payload,避免修改原始对象
|
||||||
|
payload_copy = point.payload.copy()
|
||||||
|
parent_id = payload_copy.get("parent_id", point.id)
|
||||||
score = point.score
|
score = point.score
|
||||||
|
|
||||||
# 同一个 parent_id 只保留最高得分
|
# 同一个 parent_id 只保留最高得分
|
||||||
@@ -207,12 +228,12 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
child_point_map[parent_id] = point
|
child_point_map[parent_id] = point
|
||||||
|
|
||||||
# 5. 批量查询父文档
|
# 5. 批量查询父文档
|
||||||
# 首先尝试从 Qdrant 直接查询(因为父文档可能也存在 Qdrant 中)
|
# 首先尝试从 Qdrant 直接查询(因为父文档可能也在 Qdrant 中)
|
||||||
parent_docs = []
|
parent_docs = []
|
||||||
found_parent_ids = set()
|
found_parent_ids = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parent_points = self.client.retrieve(
|
parent_points = self._client.retrieve(
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
ids=list(parent_ids),
|
ids=list(parent_ids),
|
||||||
with_payload=True
|
with_payload=True
|
||||||
@@ -220,9 +241,10 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
|
|
||||||
# 处理找到的父文档
|
# 处理找到的父文档
|
||||||
for point in parent_points:
|
for point in parent_points:
|
||||||
|
payload_copy = point.payload.copy()
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=point.payload.pop("text", ""),
|
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||||
metadata=point.payload
|
metadata=payload_copy
|
||||||
)
|
)
|
||||||
parent_docs.append(doc)
|
parent_docs.append(doc)
|
||||||
found_parent_ids.add(point.id)
|
found_parent_ids.add(point.id)
|
||||||
@@ -231,10 +253,10 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
warning(f"从 Qdrant 查询父文档失败: {e}")
|
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||||
|
|
||||||
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||||
if self.docstore and len(found_parent_ids) < len(parent_ids):
|
if self._docstore and len(found_parent_ids) < len(parent_ids):
|
||||||
missing_parent_ids = parent_ids - found_parent_ids
|
missing_parent_ids = parent_ids - found_parent_ids
|
||||||
try:
|
try:
|
||||||
docstore_docs = self.docstore.mget(missing_parent_ids)
|
docstore_docs = self._docstore.mget(missing_parent_ids)
|
||||||
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
|
for doc_id, doc in zip(missing_parent_ids, docstore_docs):
|
||||||
if doc is not None:
|
if doc is not None:
|
||||||
parent_docs.append(doc)
|
parent_docs.append(doc)
|
||||||
@@ -249,9 +271,10 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
for parent_id in missing_parent_ids:
|
for parent_id in missing_parent_ids:
|
||||||
child_point = child_point_map.get(parent_id)
|
child_point = child_point_map.get(parent_id)
|
||||||
if child_point:
|
if child_point:
|
||||||
|
payload_copy = child_point.payload.copy()
|
||||||
doc = Document(
|
doc = Document(
|
||||||
page_content=child_point.payload.pop("text", ""),
|
page_content=payload_copy.pop("page_content", payload_copy.pop("text", "")),
|
||||||
metadata=child_point.payload
|
metadata=payload_copy
|
||||||
)
|
)
|
||||||
parent_docs.append(doc)
|
parent_docs.append(doc)
|
||||||
|
|
||||||
@@ -268,10 +291,10 @@ class ParentHybridRetriever(BaseRetriever):
|
|||||||
return final_docs
|
return final_docs
|
||||||
|
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self, query: str, *, run_manager: Optional[Any] = None
|
self, query: str, **kwargs
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""异步检索(当前调用同步版本)"""
|
"""异步检索(当前调用同步版本)"""
|
||||||
return self._get_relevant_documents(query, run_manager=run_manager)
|
return self._get_relevant_documents(query, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def create_hybrid_retriever(
|
def create_hybrid_retriever(
|
||||||
@@ -333,7 +356,7 @@ def create_parent_hybrid_retriever(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Qdrant 集合名称
|
collection_name: Qdrant 集合名称
|
||||||
search_k: 最终返回的父文档数
|
search_k: 最终返回的父文档数量
|
||||||
embeddings: 可选的嵌入模型实例
|
embeddings: 可选的嵌入模型实例
|
||||||
use_docstore: 是否使用 PostgreSQL docstore 存储父文档
|
use_docstore: 是否使用 PostgreSQL docstore 存储父文档
|
||||||
|
|
||||||
@@ -375,5 +398,31 @@ def create_parent_hybrid_retriever(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_base_retriever(
|
||||||
|
collection_name: str,
|
||||||
|
search_k: int = DEFAULT_SEARCH_K,
|
||||||
|
embeddings: Optional[Embeddings] = None,
|
||||||
|
) -> BaseRetriever:
|
||||||
|
"""
|
||||||
|
创建基础稠密检索器(向后兼容)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Qdrant 集合名称
|
||||||
|
search_k: 检索返回结果数
|
||||||
|
embeddings: 可选的嵌入模型实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LangChain 的 BaseRetriever 实例
|
||||||
|
"""
|
||||||
|
# 默认使用统一嵌入服务
|
||||||
|
if embeddings is None:
|
||||||
|
embeddings = get_embedding_service()
|
||||||
|
|
||||||
|
vector_store = QdrantVectorStore(collection_name=collection_name, embeddings=embeddings)
|
||||||
|
|
||||||
|
info(f"✅ Qdrant 基础稠密检索器初始化成功(search_k={search_k})")
|
||||||
|
return vector_store.as_langchain_vectorstore().as_retriever(k=search_k)
|
||||||
|
|
||||||
|
|
||||||
# 别名:默认就是父子文档混合检索
|
# 别名:默认就是父子文档混合检索
|
||||||
create_retriever = create_parent_hybrid_retriever
|
create_retriever = create_parent_hybrid_retriever
|
||||||
|
|||||||
@@ -54,3 +54,5 @@ DOCSTORE_URI = _get_str("DOCSTORE_URI") or DB_URI
|
|||||||
|
|
||||||
# ========== 其他配置 ==========
|
# ========== 其他配置 ==========
|
||||||
# 可以在此添加其他 RAG Core 专用的配置项
|
# 可以在此添加其他 RAG Core 专用的配置项
|
||||||
|
# 稀疏模型缓存路径
|
||||||
|
FASTEMBED_CACHE_PATH = _get_str("FASTEMBED_CACHE_PATH") or "./models/fastembed_cache"
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ BM25 稀疏嵌入器
|
|||||||
"""
|
"""
|
||||||
from typing import List
|
from typing import List
|
||||||
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
|
from fastembed.sparse.sparse_text_embedding import SparseTextEmbedding
|
||||||
from app.config import FASTEMBED_CACHE_PATH
|
from .config import FASTEMBED_CACHE_PATH
|
||||||
|
|
||||||
class BM25SparseEmbedder:
|
class BM25SparseEmbedder:
|
||||||
"""BM25 稀疏嵌入包装器,与现有嵌入器风格统一"""
|
"""BM25 稀疏嵌入包装器,与现有嵌入器风格统一"""
|
||||||
|
|||||||
@@ -14,8 +14,7 @@ from langchain_core.embeddings import Embeddings
|
|||||||
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
from langchain_qdrant import QdrantVectorStore as LangchainQdrantVS
|
||||||
from qdrant_client import QdrantClient
|
from qdrant_client import QdrantClient
|
||||||
from qdrant_client.http.models import (
|
from qdrant_client.http.models import (
|
||||||
Distance, VectorParams, SparseVectorParams, SparseIndexParams,
|
Distance, VectorParams, SparseVectorParams, PointStruct
|
||||||
SparseIndexType, PointStruct, NamedSparseVector, NamedVector
|
|
||||||
)
|
)
|
||||||
from httpx import RemoteProtocolError
|
from httpx import RemoteProtocolError
|
||||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||||
@@ -61,6 +60,7 @@ class QdrantVectorStore:
|
|||||||
client=self.get_client(),
|
client=self.get_client(),
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
embedding=self.embeddings,
|
embedding=self.embeddings,
|
||||||
|
vector_name="dense",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_client(self) -> QdrantClient:
|
def get_client(self) -> QdrantClient:
|
||||||
@@ -134,19 +134,13 @@ class QdrantVectorStore:
|
|||||||
vectors_config = {
|
vectors_config = {
|
||||||
"dense": VectorParams(
|
"dense": VectorParams(
|
||||||
size=vector_size,
|
size=vector_size,
|
||||||
distance=Distance.COSINE,
|
distance=Distance.COSINE
|
||||||
optional=True
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
# 稀疏向量配置
|
# 稀疏向量配置(简化版,不使用特殊索引类型)
|
||||||
sparse_vectors_config = {
|
sparse_vectors_config = {
|
||||||
"sparse": SparseVectorParams(
|
"sparse": SparseVectorParams()
|
||||||
index=SparseIndexParams(
|
|
||||||
type=SparseIndexType.MUTABLE
|
|
||||||
),
|
|
||||||
optional=True
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client.create_collection(
|
client.create_collection(
|
||||||
@@ -197,10 +191,7 @@ class QdrantVectorStore:
|
|||||||
# 构造双向量
|
# 构造双向量
|
||||||
named_vectors = {
|
named_vectors = {
|
||||||
"dense": dense_vectors[j],
|
"dense": dense_vectors[j],
|
||||||
"sparse": NamedSparseVector(
|
"sparse": sparse_vectors[j]
|
||||||
name="sparse",
|
|
||||||
vector=sparse_vectors[j]
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
points.append(PointStruct(
|
points.append(PointStruct(
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ def get_input_path() -> Path:
|
|||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
return Path(sys.argv[1])
|
return Path(sys.argv[1])
|
||||||
# 默认测试路径(可按需修改)
|
# 默认测试路径(可按需修改)
|
||||||
return Path("data/user_docs/a.txt")
|
return Path("data/user_docs/doublestory.txt")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class IndexBuilder:
|
|||||||
# 初始化向量存储(自动支持稠密+稀疏混合检索)
|
# 初始化向量存储(自动支持稠密+稀疏混合检索)
|
||||||
self.vector_store = QdrantVectorStore(
|
self.vector_store = QdrantVectorStore(
|
||||||
collection_name=config.collection_name,
|
collection_name=config.collection_name,
|
||||||
embedding=self.embeddings if self._embedder is None else None
|
embeddings=self.embeddings if self._embedder is None else None
|
||||||
)
|
)
|
||||||
logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)")
|
logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)")
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ class IndexBuilder:
|
|||||||
child_splitter=self.child_splitter,
|
child_splitter=self.child_splitter,
|
||||||
docstore=self.docstore,
|
docstore=self.docstore,
|
||||||
search_k=cfg.search_k,
|
search_k=cfg.search_k,
|
||||||
embeddings=self.embeddings if self.embedder is None else None,
|
embeddings=self.embeddings if self._embedder is None else None,
|
||||||
)
|
)
|
||||||
logger.info("ParentDocumentRetriever 初始化完成")
|
logger.info("ParentDocumentRetriever 初始化完成")
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ YELLOW='\033[1;33m'
|
|||||||
NC='\033[0m' # No Color
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
# 项目根目录
|
# 项目根目录
|
||||||
PROJECT_DIR="/root/projects/ailine"
|
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||||
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
echo -e "${BLUE}========================================${NC}"
|
||||||
echo -e "${BLUE} AI Agent - 个人生活助手${NC}"
|
echo -e "${BLUE} AI Agent - 个人生活助手${NC}"
|
||||||
@@ -34,7 +34,7 @@ start_backend() {
|
|||||||
set +a
|
set +a
|
||||||
|
|
||||||
export PYTHONPATH="$PROJECT_DIR/backend"
|
export PYTHONPATH="$PROJECT_DIR/backend"
|
||||||
export BACKEND_PORT=10079
|
export BACKEND_PORT=8079
|
||||||
python -m app.backend &
|
python -m app.backend &
|
||||||
BACKEND_PID=$!
|
BACKEND_PID=$!
|
||||||
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
|
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
|
||||||
@@ -51,7 +51,7 @@ start_frontend() {
|
|||||||
set +a
|
set +a
|
||||||
|
|
||||||
export PYTHONPATH="$PROJECT_DIR/frontend/src"
|
export PYTHONPATH="$PROJECT_DIR/frontend/src"
|
||||||
export API_URL="http://127.0.0.1:10079/chat"
|
export API_URL="http://127.0.0.1:8079/chat"
|
||||||
streamlit run frontend/src/frontend_main.py --server.port 10501 --server.address 0.0.0.0 &
|
streamlit run frontend/src/frontend_main.py --server.port 10501 --server.address 0.0.0.0 &
|
||||||
FRONTEND_PID=$!
|
FRONTEND_PID=$!
|
||||||
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
|
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
测试重构后的 IndexBuilder 和 RAGRetriever
|
测试重构后的 IndexBuilder 和 RAG 检索
|
||||||
|
包括:索引构建、稠密检索、稀疏检索、混合检索、父子文档检索
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -8,15 +9,23 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
# 添加项目根目录到 Python 路径
|
||||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||||
|
sys.path.insert(0, os.path.join(project_root, "backend"))
|
||||||
sys.path.insert(0, project_root)
|
sys.path.insert(0, project_root)
|
||||||
|
|
||||||
from rag_indexer.index_builder import IndexBuilder
|
from rag_indexer.index_builder import IndexBuilder
|
||||||
from rag_indexer.splitters import SplitterType
|
from rag_indexer.splitters import SplitterType
|
||||||
|
|
||||||
|
from rag_core import QdrantVectorStore, get_sparse_embedder
|
||||||
|
from app.model_services import get_embedding_service
|
||||||
|
from qdrant_client import models
|
||||||
|
|
||||||
|
|
||||||
async def test_index_builder():
|
async def test_index_builder():
|
||||||
"""测试索引构建功能"""
|
"""测试索引构建功能"""
|
||||||
print("测试索引构建功能...")
|
print("="*70)
|
||||||
|
print("1. 测试索引构建功能...")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
# 创建 IndexBuilder 实例
|
# 创建 IndexBuilder 实例
|
||||||
builder = IndexBuilder(
|
builder = IndexBuilder(
|
||||||
@@ -27,7 +36,7 @@ async def test_index_builder():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 测试文档路径
|
# 测试文档路径
|
||||||
test_file = os.path.join(os.path.dirname(__file__), "..", "data", "user_docs", "doublestory.txt")
|
test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")
|
||||||
|
|
||||||
if os.path.exists(test_file):
|
if os.path.exists(test_file):
|
||||||
# 构建索引
|
# 构建索引
|
||||||
@@ -43,7 +52,260 @@ async def test_index_builder():
|
|||||||
|
|
||||||
# 关闭资源
|
# 关闭资源
|
||||||
builder.close()
|
builder.close()
|
||||||
print("\n测试完成")
|
print("\n索引构建测试完成")
|
||||||
|
return processed
|
||||||
|
|
||||||
|
|
||||||
|
def test_dense_retrieval():
|
||||||
|
"""测试稠密检索"""
|
||||||
|
print("\n" + "="*70)
|
||||||
|
print("2. 测试稠密检索...")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
# 获取嵌入服务
|
||||||
|
embeddings = get_embedding_service()
|
||||||
|
|
||||||
|
# 创建向量存储
|
||||||
|
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||||
|
|
||||||
|
# 测试查询
|
||||||
|
query = "The Ant and the Grasshopper"
|
||||||
|
print(f"查询: {query}")
|
||||||
|
|
||||||
|
results = vs.similarity_search(query, k=3)
|
||||||
|
|
||||||
|
print(f"\n找到 {len(results)} 个结果:")
|
||||||
|
for i, doc in enumerate(results, 1):
|
||||||
|
print(f"\n{i}. (来源: {doc.metadata.get('source', 'unknown')})")
|
||||||
|
print(f" 元数据: {doc.metadata}")
|
||||||
|
content = doc.page_content.strip()
|
||||||
|
if len(content) > 200:
|
||||||
|
content = content[:200] + "..."
|
||||||
|
print(f" 内容: {content}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_retrieval_simple():
|
||||||
|
"""简单测试稀疏检索"""
|
||||||
|
print("\n" + "="*70)
|
||||||
|
print("3. 测试稀疏检索(BM25)...")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
# 获取嵌入服务和稀疏嵌入器
|
||||||
|
embeddings = get_embedding_service()
|
||||||
|
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||||
|
client = vs.get_qdrant_client()
|
||||||
|
sparse_embedder = get_sparse_embedder()
|
||||||
|
|
||||||
|
# 测试查询 - 用关键词
|
||||||
|
query = "winter work food"
|
||||||
|
print(f"查询关键词: {query}")
|
||||||
|
|
||||||
|
# 生成稀疏查询向量
|
||||||
|
sparse_query = sparse_embedder.embed_query(query)
|
||||||
|
|
||||||
|
# 包装成 SparseVector 对象
|
||||||
|
sparse_vec = models.SparseVector(
|
||||||
|
indices=sparse_query["indices"],
|
||||||
|
values=sparse_query["values"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 直接查询稀疏向量
|
||||||
|
response = client.query_points(
|
||||||
|
collection_name="rag_documents",
|
||||||
|
query=sparse_vec,
|
||||||
|
using="sparse",
|
||||||
|
limit=3,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n找到 {len(response.points)} 个结果:")
|
||||||
|
for i, point in enumerate(response.points, 1):
|
||||||
|
print(f"\n{i}. (分数: {point.score:.4f})")
|
||||||
|
text = point.payload.get("text", "")
|
||||||
|
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||||
|
print(f" 元数据: {metadata}")
|
||||||
|
content = text.strip()
|
||||||
|
if len(content) > 200:
|
||||||
|
content = content[:200] + "..."
|
||||||
|
print(f" 内容: {content}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_retrieval_simple():
|
||||||
|
"""简单测试混合检索(稠密+稀疏 RRF 融合)"""
|
||||||
|
print("\n" + "="*70)
|
||||||
|
print("4. 测试混合检索(稠密+稀疏 RRF 融合)...")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
# 获取嵌入服务和稀疏嵌入器
|
||||||
|
embeddings = get_embedding_service()
|
||||||
|
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||||
|
client = vs.get_qdrant_client()
|
||||||
|
sparse_embedder = get_sparse_embedder()
|
||||||
|
|
||||||
|
# 测试查询
|
||||||
|
query = "Ant and Grasshopper story"
|
||||||
|
print(f"查询: {query}")
|
||||||
|
|
||||||
|
# 生成双向量
|
||||||
|
dense_query = embeddings.embed_query(query)
|
||||||
|
sparse_query = sparse_embedder.embed_query(query)
|
||||||
|
sparse_vec = models.SparseVector(
|
||||||
|
indices=sparse_query["indices"],
|
||||||
|
values=sparse_query["values"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 Qdrant 的 query_points 做混合检索
|
||||||
|
response = client.query_points(
|
||||||
|
collection_name="rag_documents",
|
||||||
|
prefetch=[
|
||||||
|
models.Prefetch(
|
||||||
|
query=dense_query,
|
||||||
|
using="dense",
|
||||||
|
limit=3
|
||||||
|
),
|
||||||
|
models.Prefetch(
|
||||||
|
query=sparse_vec,
|
||||||
|
using="sparse",
|
||||||
|
limit=3
|
||||||
|
)
|
||||||
|
],
|
||||||
|
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||||
|
limit=3,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n找到 {len(response.points)} 个结果:")
|
||||||
|
for i, point in enumerate(response.points, 1):
|
||||||
|
print(f"\n{i}. (RRF 融合分数: {point.score:.4f})")
|
||||||
|
text = point.payload.get("text", "")
|
||||||
|
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||||
|
print(f" 元数据: {metadata}")
|
||||||
|
content = text.strip()
|
||||||
|
if len(content) > 200:
|
||||||
|
content = content[:200] + "..."
|
||||||
|
print(f" 内容: {content}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parent_child_retrieval_simple():
|
||||||
|
"""简单测试父子文档检索"""
|
||||||
|
print("\n" + "="*70)
|
||||||
|
print("5. 测试父子文档混合检索...")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
# 获取嵌入服务和稀疏嵌入器
|
||||||
|
embeddings = get_embedding_service()
|
||||||
|
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
|
||||||
|
client = vs.get_qdrant_client()
|
||||||
|
sparse_embedder = get_sparse_embedder()
|
||||||
|
|
||||||
|
# 测试查询
|
||||||
|
query = "The Ant and the Grasshopper story moral"
|
||||||
|
print(f"查询: {query}")
|
||||||
|
|
||||||
|
# 生成双向量
|
||||||
|
dense_query = embeddings.embed_query(query)
|
||||||
|
sparse_query = sparse_embedder.embed_query(query)
|
||||||
|
sparse_vec = models.SparseVector(
|
||||||
|
indices=sparse_query["indices"],
|
||||||
|
values=sparse_query["values"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 先做混合检索找到子文档
|
||||||
|
response = client.query_points(
|
||||||
|
collection_name="rag_documents",
|
||||||
|
prefetch=[
|
||||||
|
models.Prefetch(
|
||||||
|
query=dense_query,
|
||||||
|
using="dense",
|
||||||
|
limit=5
|
||||||
|
),
|
||||||
|
models.Prefetch(
|
||||||
|
query=sparse_vec,
|
||||||
|
using="sparse",
|
||||||
|
limit=5
|
||||||
|
)
|
||||||
|
],
|
||||||
|
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||||
|
limit=5,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 收集 parent_id
|
||||||
|
parent_score_map = {}
|
||||||
|
child_points = {}
|
||||||
|
for point in response.points:
|
||||||
|
parent_id = point.payload.get("parent_id", point.id)
|
||||||
|
score = point.score
|
||||||
|
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||||
|
parent_score_map[parent_id] = score
|
||||||
|
child_points[parent_id] = point
|
||||||
|
|
||||||
|
parent_ids = list(parent_score_map.keys())
|
||||||
|
|
||||||
|
print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:")
|
||||||
|
|
||||||
|
# 查找父文档
|
||||||
|
if parent_ids:
|
||||||
|
parent_docs = client.retrieve(
|
||||||
|
collection_name="rag_documents",
|
||||||
|
ids=parent_ids,
|
||||||
|
with_payload=True
|
||||||
|
)
|
||||||
|
|
||||||
|
found_parent_ids = {p.id for p in parent_docs}
|
||||||
|
|
||||||
|
# 准备结果列表
|
||||||
|
results = []
|
||||||
|
for p in parent_docs:
|
||||||
|
score = parent_score_map[p.id]
|
||||||
|
results.append((p, score))
|
||||||
|
|
||||||
|
# 处理没找到父文档的情况 - 用子文档代替
|
||||||
|
missing = set(parent_ids) - found_parent_ids
|
||||||
|
for parent_id in missing:
|
||||||
|
child_point = child_points[parent_id]
|
||||||
|
print(f"\n注意: parent_id {parent_id} 未找到,使用子文档代替")
|
||||||
|
results.append((child_point, parent_score_map[parent_id]))
|
||||||
|
|
||||||
|
# 按分数排序
|
||||||
|
results.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 显示
|
||||||
|
print(f"\n共 {len(results)} 个结果(去重后):")
|
||||||
|
for i, (point, score) in enumerate(results[:3], 1):
|
||||||
|
print(f"\n{i}. (分数: {score:.4f})")
|
||||||
|
text = point.payload.get("text", "")
|
||||||
|
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||||
|
print(f" 元数据: {metadata}")
|
||||||
|
content = text.strip()
|
||||||
|
if len(content) > 400:
|
||||||
|
content = content[:400] + "..."
|
||||||
|
print(f" 内容: {content}")
|
||||||
|
else:
|
||||||
|
print("\n未找到结果")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""主测试函数"""
|
||||||
|
# 1. 先构建索引
|
||||||
|
await test_index_builder()
|
||||||
|
|
||||||
|
# 2. 测试稠密检索
|
||||||
|
test_dense_retrieval()
|
||||||
|
|
||||||
|
# 3. 测试稀疏检索
|
||||||
|
test_sparse_retrieval_simple()
|
||||||
|
|
||||||
|
# 4. 测试混合检索
|
||||||
|
test_hybrid_retrieval_simple()
|
||||||
|
|
||||||
|
# 5. 测试父子文档检索
|
||||||
|
test_parent_child_retrieval_simple()
|
||||||
|
|
||||||
|
print("\n" + "="*70)
|
||||||
|
print("所有测试完成!")
|
||||||
|
print("="*70)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(test_index_builder())
|
asyncio.run(main())
|
||||||
|
|||||||
Reference in New Issue
Block a user