refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s
This commit is contained in:
@@ -69,7 +69,7 @@ LOCAL_MODEL_NAME = _get_str("LOCAL_MODEL_NAME") or "gemma-4-E4B-it"
|
||||
# ========== llama.cpp 服务配置(URL + API密钥 配对) ==========
|
||||
# 主 LLM 服务
|
||||
VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
|
||||
LLM_API_KEY = _get_str("LLAMACPP_API_KEY")
|
||||
LLM_API_KEY = _get_str("LLM_API_KEY")
|
||||
|
||||
# Embedding 服务 (用于 Mem0 的向量化)
|
||||
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
|
||||
@@ -78,6 +78,26 @@ LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
|
||||
# Reranker 服务
|
||||
LLAMACPP_RERANKER_URL = _get_str("LLAMACPP_RERANKER_URL")
|
||||
|
||||
# ========== 小模型配置(查询改写、意图分类等简单任务) ==========
|
||||
# 默认复用大模型配置,后续可单独配置
|
||||
# 本地小模型(默认复用 VLLM 配置)
|
||||
SMALL_VLLM_BASE_URL = _get_str("SMALL_VLLM_BASE_URL")
|
||||
SMALL_LLM_API_KEY = _get_str("SMALL_LLM_API_KEY")
|
||||
SMALL_LOCAL_MODEL_NAME = _get_str("SMALL_LOCAL_MODEL_NAME") or LOCAL_MODEL_NAME
|
||||
# 如果小模型没单独配置,用大模型的配置
|
||||
if not SMALL_VLLM_BASE_URL:
|
||||
SMALL_VLLM_BASE_URL = VLLM_BASE_URL
|
||||
if not SMALL_LLM_API_KEY:
|
||||
SMALL_LLM_API_KEY = LLM_API_KEY
|
||||
|
||||
# DeepSeek 小模型(默认复用 DeepSeek 配置)
|
||||
SMALL_DEEPSEEK_API_KEY = _get_str("SMALL_DEEPSEEK_API_KEY")
|
||||
SMALL_DEEPSEEK_MODEL = _get_str("SMALL_DEEPSEEK_MODEL") or "deepseek-chat"
|
||||
SMALL_DEEPSEEK_API_BASE = _get_str("SMALL_DEEPSEEK_API_BASE") or "https://api.deepseek.com"
|
||||
# 如果小模型没单独配置,用大模型的配置
|
||||
if not SMALL_DEEPSEEK_API_KEY:
|
||||
SMALL_DEEPSEEK_API_KEY = DEEPSEEK_API_KEY
|
||||
|
||||
|
||||
# ========== Qdrant 向量数据库配置(URL + API密钥 配对) ==========
|
||||
QDRANT_URL = _get_str("QDRANT_URL")
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional, Dict, Any
|
||||
import sys
|
||||
import os
|
||||
|
||||
from backend.app.model_services.chat_services import get_chat_service
|
||||
from backend.app.model_services.chat_services import get_small_llm_service
|
||||
|
||||
|
||||
class IntentType(Enum):
|
||||
@@ -33,7 +33,7 @@ class IntentClassifier:
|
||||
"""意图分类器"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = get_chat_service()
|
||||
self.llm = get_small_llm_service()
|
||||
self._intent_examples = self._build_examples()
|
||||
|
||||
def _build_examples(self) -> str:
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.main_graph.utils.retry_utils import (
|
||||
)
|
||||
|
||||
# 真正导入和利用已有 RAG 代码
|
||||
from app.rag.tools import create_rag_tool_sync
|
||||
from app.rag.tools import create_rag_tool
|
||||
from app.rag.pipeline import RAGPipeline
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# app/rag_initializer.py
|
||||
from app.rag.tools import create_rag_tool_sync, create_rag_tool_async
|
||||
from rag_core import create_parent_retriever
|
||||
from app.rag.tools import create_rag_tool
|
||||
from app.rag.retriever import create_parent_hybrid_retriever
|
||||
from app.model_services import get_embedding_service
|
||||
from app.logger import info, warning
|
||||
|
||||
@@ -10,17 +10,17 @@ async def init_rag_tool(local_llm_creator):
|
||||
info("🔄 正在初始化 RAG 检索系统...")
|
||||
# 使用统一的嵌入服务获取接口
|
||||
embeddings = get_embedding_service()
|
||||
retriever = create_parent_retriever(
|
||||
retriever = create_parent_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
search_k=5,
|
||||
embeddings=embeddings
|
||||
)
|
||||
rewrite_llm = local_llm_creator()
|
||||
rag_tool = create_rag_tool_async(
|
||||
rag_tool = create_rag_tool(
|
||||
retriever, rewrite_llm,
|
||||
num_queries=3, rerank_top_n=5
|
||||
)
|
||||
info("✅ RAG 检索工具初始化成功(异步版本)")
|
||||
info("✅ RAG 检索工具初始化成功(全异步版本)")
|
||||
return rag_tool
|
||||
except Exception as e:
|
||||
warning(f"⚠️ RAG 检索工具初始化失败: {e}")
|
||||
|
||||
@@ -6,9 +6,11 @@
|
||||
|
||||
from .embedding_services import get_embedding_service
|
||||
from .rerank_services import get_rerank_service, BaseRerankService
|
||||
from .chat_services import get_small_llm_service
|
||||
|
||||
__all__ = [
|
||||
"get_embedding_service",
|
||||
"get_rerank_service",
|
||||
"get_small_llm_service",
|
||||
"BaseRerankService"
|
||||
]
|
||||
|
||||
@@ -219,51 +219,98 @@ class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]):
|
||||
|
||||
# ========== 轻量级模型 Provider ==========
|
||||
|
||||
class ZhipuSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
智谱 AI 轻量级模型服务提供者(用于意图分类等简单任务)
|
||||
使用 glm-5.1-flash 或其他小模型
|
||||
本地轻量级模型服务提供者(用于查询改写、意图分类等简单任务)
|
||||
使用小模型独立配置
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "glm-5.1-flash"):
|
||||
super().__init__("zhipu_small")
|
||||
self._model = model
|
||||
def __init__(self, model: str = None):
|
||||
from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY
|
||||
super().__init__("local_small")
|
||||
self._model = model or SMALL_LOCAL_MODEL_NAME
|
||||
self._base_url = SMALL_VLLM_BASE_URL
|
||||
self._api_key = SMALL_LLM_API_KEY
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查智谱轻量模型服务是否可用"""
|
||||
if not ZHIPUAI_API_KEY:
|
||||
logger.warning("ZHIPUAI_API_KEY 未配置,轻量模型不可用")
|
||||
"""检查本地小模型服务是否可用"""
|
||||
if not self._base_url:
|
||||
logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用")
|
||||
return False
|
||||
logger.info(f"智谱轻量模型配置正确: {self._model}")
|
||||
|
||||
try:
|
||||
# 先测试主机名能否解析
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed_url = urlparse(self._base_url)
|
||||
host = parsed_url.hostname
|
||||
port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443)
|
||||
|
||||
# 测试能否建立 TCP 连接(快速失败)
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(2.0)
|
||||
try:
|
||||
sock.connect((host, port))
|
||||
sock.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}")
|
||||
return False
|
||||
|
||||
# 再尝试调用简单的 API
|
||||
client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0)
|
||||
headers = {}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
try:
|
||||
response = client.get("/models", headers=headers)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"本地小模型服务可用: {self._model}")
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.warning(f"本地小模型服务响应异常")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"本地小模型服务不可用: {e}")
|
||||
return False
|
||||
|
||||
def get_service(self) -> BaseChatModel:
|
||||
"""获取智谱轻量模型服务"""
|
||||
"""获取本地小模型服务"""
|
||||
if self._service_instance is None:
|
||||
from langchain_community.chat_models import ChatZhipuAI
|
||||
self._service_instance = ChatZhipuAI(
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import SecretStr
|
||||
|
||||
self._service_instance = ChatOpenAI(
|
||||
base_url=self._base_url,
|
||||
api_key=SecretStr(self._api_key) if self._api_key else SecretStr(""),
|
||||
model=self._model,
|
||||
api_key=ZHIPUAI_API_KEY,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
timeout=30.0,
|
||||
max_retries=2,
|
||||
streaming=False
|
||||
streaming=False,
|
||||
)
|
||||
return self._service_instance
|
||||
|
||||
|
||||
class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
"""
|
||||
DeepSeek 轻量级模型服务提供者(备选)
|
||||
DeepSeek 轻量级模型服务提供者(用于查询改写、意图分类等简单任务)
|
||||
使用小模型独立配置
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "deepseek-chat"):
|
||||
def __init__(self, model: str = None):
|
||||
from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE
|
||||
super().__init__("deepseek_small")
|
||||
self._model = model
|
||||
self._model = model or SMALL_DEEPSEEK_MODEL
|
||||
self._api_key = SMALL_DEEPSEEK_API_KEY
|
||||
self._api_base = SMALL_DEEPSEEK_API_BASE
|
||||
|
||||
def is_available(self) -> bool:
|
||||
if not DEEPSEEK_API_KEY:
|
||||
logger.warning("DEEPSEEK_API_KEY 未配置")
|
||||
if not self._api_key:
|
||||
logger.warning("SMALL_DEEPSEEK_API_KEY 未配置")
|
||||
return False
|
||||
logger.info(f"DeepSeek 轻量模型配置正确: {self._model}")
|
||||
return True
|
||||
@@ -274,8 +321,8 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]):
|
||||
from pydantic import SecretStr
|
||||
|
||||
self._service_instance = ChatOpenAI(
|
||||
base_url="https://api.deepseek.com",
|
||||
api_key=SecretStr(DEEPSEEK_API_KEY),
|
||||
base_url=self._api_base,
|
||||
api_key=SecretStr(self._api_key),
|
||||
model=self._model,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
@@ -339,20 +386,17 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]:
|
||||
|
||||
def get_small_llm_service() -> BaseChatModel:
|
||||
"""
|
||||
获取轻量级大模型服务(用于意图分类等简单任务)
|
||||
优先顺序: zhipu_small -> deepseek_small -> (降级到 get_chat_service)
|
||||
获取轻量级大模型服务(用于查询改写、意图分类等简单任务)
|
||||
优先顺序: 本地模型 -> DeepSeek 小模型
|
||||
⚠️ 注意:小模型任务不降级到大模型,避免不必要的 token 消耗!
|
||||
|
||||
Returns:
|
||||
BaseChatModel: LangChain 兼容的 ChatModel 实例
|
||||
"""
|
||||
def _create_small_chain():
|
||||
primary = ZhipuSmallModelProvider()
|
||||
primary = LocalSmallModelProvider()
|
||||
fallbacks = [DeepSeekSmallModelProvider()]
|
||||
return FallbackServiceChain(primary, fallbacks)
|
||||
|
||||
try:
|
||||
chain = SingletonServiceManager.get_or_create("small_llm_chain", _create_small_chain)
|
||||
return chain.get_available_service()
|
||||
except Exception as e:
|
||||
logger.warning(f"轻量模型初始化失败,降级到默认大模型: {e}")
|
||||
return get_chat_service()
|
||||
|
||||
@@ -42,7 +42,7 @@ from .rerank import DocumentReranker, create_document_reranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
from .fusion import reciprocal_rank_fusion
|
||||
from .pipeline import RAGPipeline
|
||||
from .tools import create_rag_tool_sync
|
||||
from .tools import create_rag_tool
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -64,5 +64,5 @@ __all__ = [
|
||||
"RAGPipeline",
|
||||
|
||||
# 工具创建(供 Agent 使用)
|
||||
"create_rag_tool_sync",
|
||||
"create_rag_tool",
|
||||
]
|
||||
@@ -13,7 +13,7 @@ from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
|
||||
from app.model_services import get_rerank_service
|
||||
from app.model_services import get_rerank_service, get_small_llm_service
|
||||
from app.rag.rerank import create_document_reranker
|
||||
from app.rag.query_transform import MultiQueryGenerator
|
||||
from app.rag.fusion import reciprocal_rank_fusion
|
||||
@@ -31,7 +31,7 @@ class RAGPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
retriever=None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
@@ -41,6 +41,9 @@ class RAGPipeline:
|
||||
retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。
|
||||
如果不提供,会自动创建默认的父子文档混合检索器。
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量。
|
||||
rerank_top_n: 最终返回的文档数量。
|
||||
collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。
|
||||
@@ -54,12 +57,25 @@ class RAGPipeline:
|
||||
else:
|
||||
self.retriever = retriever
|
||||
|
||||
# 处理 llm 参数
|
||||
if llm == "default_small":
|
||||
try:
|
||||
self.llm = get_small_llm_service()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning(f"小模型初始化失败,将不做查询改写: {e}")
|
||||
self.llm = None
|
||||
elif llm in (None, False):
|
||||
self.llm = None
|
||||
else:
|
||||
self.llm = llm
|
||||
|
||||
self.num_queries = num_queries
|
||||
self.rerank_top_n = rerank_top_n
|
||||
|
||||
# 初始化组件 - 使用统一的重排服务获取接口
|
||||
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None
|
||||
self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None
|
||||
self.reranker = create_document_reranker()
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
@@ -103,10 +119,6 @@ class RAGPipeline:
|
||||
|
||||
return final_docs
|
||||
|
||||
def retrieve(self, query: str) -> List[Document]:
|
||||
"""同步检索入口(内部调用异步方法)"""
|
||||
return asyncio.run(self.aretrieve(query))
|
||||
|
||||
def format_context(self, documents: List[Document]) -> str:
|
||||
"""
|
||||
将文档列表格式化为上下文字符串
|
||||
@@ -129,7 +141,7 @@ class RAGPipeline:
|
||||
|
||||
def create_rag_pipeline(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
) -> RAGPipeline:
|
||||
@@ -138,7 +150,10 @@ def create_rag_pipeline(
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
llm: 用于生成多路查询的语言模型
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
|
||||
|
||||
@@ -63,14 +63,32 @@ class HybridRetriever(BaseRetriever):
|
||||
self._client = vector_store.get_async_qdrant_client()
|
||||
self._sparse_embedder = get_sparse_embedder()
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
同步检索(不推荐使用,仅供兼容性)
|
||||
|
||||
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
|
||||
"""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# 已有事件循环,使用 create_task
|
||||
task = loop.create_task(self._aget_relevant_documents(query))
|
||||
return loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
# 没有事件循环,创建新的
|
||||
return asyncio.run(self._aget_relevant_documents(query))
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
异步混合检索相关文档
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store._aembed_query(query)
|
||||
dense_query = await self._vector_store.aembed_query(query)
|
||||
sparse_query = self._sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
@@ -106,7 +124,7 @@ class HybridRetriever(BaseRetriever):
|
||||
)
|
||||
results.append(doc)
|
||||
|
||||
debug(f"混合检索返回 %d 个文档", len(results))
|
||||
debug(f"混合检索返回 {len(results)} 个文档")
|
||||
return results
|
||||
|
||||
|
||||
@@ -150,14 +168,30 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
self._sparse_embedder = get_sparse_embedder()
|
||||
self._docstore = docstore
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
同步检索(不推荐使用,仅供兼容性)
|
||||
|
||||
注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke
|
||||
"""
|
||||
import asyncio
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
task = loop.create_task(self._aget_relevant_documents(query))
|
||||
return loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
return asyncio.run(self._aget_relevant_documents(query))
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, **kwargs
|
||||
self, query: str, *, run_manager: Any = None
|
||||
) -> List[Document]:
|
||||
"""
|
||||
异步检索相关父文档
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
dense_query = await self._vector_store._aembed_query(query)
|
||||
dense_query = await self._vector_store.aembed_query(query)
|
||||
sparse_query = self._sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
@@ -228,7 +262,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
found_parent_ids.add(point.id)
|
||||
|
||||
except Exception as e:
|
||||
warning(f"从 Qdrant 查询父文档失败: %s", e)
|
||||
warning(f"从 Qdrant 查询父文档失败: {e}")
|
||||
|
||||
# 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档
|
||||
if self._docstore and len(found_parent_ids) < len(parent_ids):
|
||||
@@ -240,12 +274,12 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
parent_docs.append(doc)
|
||||
found_parent_ids.add(doc_id)
|
||||
except Exception as e:
|
||||
warning(f"从 docstore 查询父文档失败: %s", e)
|
||||
warning(f"从 docstore 查询父文档失败: {e}")
|
||||
|
||||
# 7. 降级:对于仍未找到的父文档,用子文档本身代替
|
||||
missing_parent_ids = parent_ids - found_parent_ids
|
||||
if missing_parent_ids:
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: %s", missing_parent_ids)
|
||||
warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}")
|
||||
for parent_id in missing_parent_ids:
|
||||
child_point = child_point_map.get(parent_id)
|
||||
if child_point:
|
||||
@@ -264,7 +298,7 @@ class ParentHybridRetriever(BaseRetriever):
|
||||
parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]]
|
||||
debug(f"父子文档混合检索返回 %d 个父文档", len(final_docs))
|
||||
debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档")
|
||||
|
||||
return final_docs
|
||||
|
||||
@@ -291,7 +325,7 @@ def create_hybrid_retriever(
|
||||
embeddings = get_embedding_service()
|
||||
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name)
|
||||
|
||||
try:
|
||||
vector_store.get_client().get_collection(collection_name)
|
||||
@@ -336,7 +370,7 @@ def create_parent_hybrid_retriever(
|
||||
embeddings = get_embedding_service()
|
||||
info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)")
|
||||
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings)
|
||||
vector_store = QdrantHybridStore(collection_name=collection_name)
|
||||
|
||||
try:
|
||||
vector_store.get_client().get_collection(collection_name)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
RAG 工具模块
|
||||
RAG 工具模块(完全异步)
|
||||
|
||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
||||
@@ -13,78 +13,24 @@ from langchain_core.retrievers import BaseRetriever
|
||||
from app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
|
||||
|
||||
def create_rag_tool_sync(
|
||||
def create_rag_tool(
|
||||
retriever: Optional[BaseRetriever] = None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
llm: Optional[BaseLanguageModel] = "default_small",
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
) -> Callable:
|
||||
"""
|
||||
创建一个配置好的 RAG 检索工具(同步版本)。
|
||||
创建一个配置好的 RAG 检索工具(完全异步)。
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
|
||||
Args:
|
||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
||||
llm: 用于生成多路查询的语言模型(可选)
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
collection_name: Qdrant 集合名称
|
||||
|
||||
Returns:
|
||||
LangChain Tool 函数
|
||||
"""
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
|
||||
@tool
|
||||
def search_knowledge_base_sync(query: str) -> str:
|
||||
"""
|
||||
在知识库中搜索与查询相关的文档片段。
|
||||
|
||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
||||
检索效果最优。
|
||||
|
||||
Args:
|
||||
query: 用户提出的问题或查询字符串
|
||||
|
||||
Returns:
|
||||
格式化后的相关文档内容
|
||||
"""
|
||||
try:
|
||||
documents = pipeline.retrieve(query)
|
||||
if not documents:
|
||||
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
|
||||
|
||||
context = pipeline.format_context(documents)
|
||||
return context
|
||||
except Exception as e:
|
||||
return f"检索过程中发生错误: {str(e)}"
|
||||
|
||||
return search_knowledge_base_sync
|
||||
|
||||
|
||||
def create_rag_tool_async(
|
||||
retriever: Optional[BaseRetriever] = None,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
collection_name: str = "rag_documents",
|
||||
) -> Callable:
|
||||
"""
|
||||
创建一个配置好的 RAG 检索工具(异步版本)。
|
||||
|
||||
默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。
|
||||
|
||||
Args:
|
||||
retriever: 基础检索器对象(可选,不提供则自动创建)
|
||||
llm: 用于生成多路查询的语言模型(可选)
|
||||
llm: 用于生成多路查询的语言模型。
|
||||
- "default_small": (默认) 使用小模型(本地 + DeepSeek)
|
||||
- None / False: 不做查询改写
|
||||
- BaseLanguageModel 实例: 自定义模型
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
collection_name: Qdrant 集合名称
|
||||
@@ -101,9 +47,9 @@ def create_rag_tool_async(
|
||||
)
|
||||
|
||||
@tool
|
||||
async def search_knowledge_base_async(query: str) -> str:
|
||||
async def search_knowledge_base(query: str) -> str:
|
||||
"""
|
||||
在知识库中搜索与查询相关的文档片段(异步版本)。
|
||||
在知识库中搜索与查询相关的文档片段(完全异步)。
|
||||
|
||||
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
|
||||
检索效果最优。
|
||||
@@ -124,30 +70,4 @@ def create_rag_tool_async(
|
||||
except Exception as e:
|
||||
return f"检索过程中发生错误: {str(e)}"
|
||||
|
||||
return search_knowledge_base_async
|
||||
|
||||
|
||||
def create_rag_tool(
|
||||
collection_name: str = "rag_documents",
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
) -> Callable:
|
||||
"""
|
||||
创建 RAG 检索工具的便捷函数(同步版本)。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称
|
||||
llm: 用于生成多路查询的语言模型(可选)
|
||||
num_queries: 生成的查询变体数量
|
||||
rerank_top_n: 最终返回的文档数量
|
||||
|
||||
Returns:
|
||||
LangChain Tool 函数
|
||||
"""
|
||||
return create_rag_tool_sync(
|
||||
collection_name=collection_name,
|
||||
llm=llm,
|
||||
num_queries=num_queries,
|
||||
rerank_top_n=rerank_top_n,
|
||||
)
|
||||
return search_knowledge_base
|
||||
|
||||
@@ -6,8 +6,13 @@ RAG Core - 公共 RAG 组件包
|
||||
from .embedders import get_embeddings, get_embedding_dimension
|
||||
from .vector_store import QdrantHybridStore
|
||||
from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .client import create_qdrant_client, create_async_qdrant_client
|
||||
from .doc_store import PostgresDocStore
|
||||
from .client import (
|
||||
create_qdrant_client,
|
||||
create_async_qdrant_client,
|
||||
create_docstore,
|
||||
get_docstore_uri
|
||||
)
|
||||
from .config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
@@ -24,14 +29,15 @@ __all__ = [
|
||||
"QdrantHybridStore",
|
||||
"BM25SparseEmbedder",
|
||||
"get_sparse_embedder",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"get_docstore_uri",
|
||||
"create_qdrant_client",
|
||||
"create_async_qdrant_client",
|
||||
"QDRANT_URL",
|
||||
"QDRANT_API_KEY",
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
"LLAMACPP_API_KEY",
|
||||
"DB_URI",
|
||||
"DOCSTORE_URI",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"create_qdrant_client",
|
||||
"create_async_qdrant_client",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# rag_core/client.py
|
||||
import os
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY, DOCSTORE_URI
|
||||
from qdrant_client import QdrantClient, AsyncQdrantClient
|
||||
from typing import Tuple
|
||||
from langchain_core.stores import BaseStore
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
@@ -54,3 +59,47 @@ def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient:
|
||||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||
|
||||
return AsyncQdrantClient(**client_kwargs)
|
||||
|
||||
|
||||
def get_docstore_uri() -> str:
|
||||
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
|
||||
return DOCSTORE_URI
|
||||
|
||||
|
||||
def create_docstore(
|
||||
table_name: str = "parent_documents",
|
||||
pool_config: dict | None = None,
|
||||
max_concurrency: int | None = None
|
||||
) -> Tuple[BaseStore, str]:
|
||||
"""
|
||||
工厂函数,创建 PostgreSQL 文档存储。
|
||||
|
||||
Args:
|
||||
table_name: PostgreSQL 表名(默认:parent_documents)
|
||||
pool_config: 连接池配置
|
||||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||||
|
||||
Returns:
|
||||
元组 (存储实例, 连接字符串)
|
||||
|
||||
Raises:
|
||||
ImportError: 缺少必要的依赖
|
||||
|
||||
Example:
|
||||
>>> # 创建 PostgreSQL 存储
|
||||
>>> store, conn = create_docstore(
|
||||
... table_name="parent_docs",
|
||||
... max_concurrency=10
|
||||
... )
|
||||
"""
|
||||
from .doc_store import PostgresDocStore
|
||||
|
||||
conn_str = get_docstore_uri()
|
||||
store = PostgresDocStore(
|
||||
connection_string=conn_str,
|
||||
table_name=table_name,
|
||||
pool_config=pool_config,
|
||||
max_concurrency=max_concurrency
|
||||
)
|
||||
logger.info(f"PostgreSQL docstore 已创建: {table_name}")
|
||||
return store, conn_str
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
异步 PostgreSQL 存储实现 - 用于生产环境。
|
||||
异步 PostgreSQL 文档存储
|
||||
|
||||
使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。
|
||||
用于 ParentDocumentRetriever 的父文档存储,支持高并发访问。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -16,6 +16,7 @@ import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgresDocStore(BaseStore[str, Any]):
|
||||
"""
|
||||
异步 PostgreSQL 文档存储实现。
|
||||
@@ -49,7 +50,7 @@ class PostgresDocStore(BaseStore[str, Any]):
|
||||
|
||||
Args:
|
||||
connection_string: PostgreSQL 连接 URL,格式:
|
||||
"postgresql://user:password@host:port/database?sslmode=disable"
|
||||
"postgresql://user:***@host:port/database?sslmode=disable"
|
||||
table_name: 存储表名,默认为 "parent_documents"
|
||||
pool_config: 连接池配置字典,包含:
|
||||
- min_size: 最小连接数(默认 2)
|
||||
@@ -57,18 +58,17 @@ class PostgresDocStore(BaseStore[str, Any]):
|
||||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||||
|
||||
Raises:
|
||||
ImportError: 未安装 asyncpg 时抛出
|
||||
ImportError: 缺少必要的依赖
|
||||
|
||||
Example:
|
||||
>>> store = PostgresDocStore(
|
||||
... "postgresql://user:pass@localhost:5432/mydb",
|
||||
... "postgresql://user:***@localhost:5432/mydb",
|
||||
... table_name="parent_docs",
|
||||
... pool_config={"min_size": 5, "max_size": 20},
|
||||
... max_concurrency=10
|
||||
... )
|
||||
"""
|
||||
|
||||
|
||||
self.dsn = connection_string
|
||||
self.table_name = table_name
|
||||
self._pool: Optional["asyncpg.Pool"] = None
|
||||
@@ -244,3 +244,4 @@ class PostgresDocStore(BaseStore[str, Any]):
|
||||
注意:在异步环境中,请使用 aclose 方法。
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。
|
||||
|
||||
提供 PostgreSQL 存储后端:
|
||||
- PostgresDocStore: PostgreSQL 数据库存储(生产环境)
|
||||
|
||||
示例用法:
|
||||
>>> from rag_core.store import create_docstore
|
||||
|
||||
>>> # 创建 PostgreSQL 存储
|
||||
>>> store, conn = create_docstore(
|
||||
... table_name="parent_docs"
|
||||
... )
|
||||
"""
|
||||
|
||||
|
||||
from .postgres import PostgresDocStore
|
||||
from .factory import create_docstore, get_docstore_uri
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
||||
__all__ = [
|
||||
# 具体实现
|
||||
"PostgresDocStore",
|
||||
|
||||
# 工厂函数
|
||||
"create_docstore",
|
||||
"get_docstore_uri",
|
||||
]
|
||||
@@ -1,56 +0,0 @@
|
||||
"""
|
||||
文档存储工厂 - 创建不同类型的存储实例。
|
||||
|
||||
提供统一的接口来创建本地文件存储或 PostgreSQL 存储。
|
||||
"""
|
||||
|
||||
import os
|
||||
from ..config import DOCSTORE_URI
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
from .postgres import PostgresDocStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_docstore_uri() -> str:
|
||||
"""获取 docstore 专用的数据库连接字符串(可与主库相同)"""
|
||||
return DOCSTORE_URI
|
||||
|
||||
|
||||
def create_docstore(
|
||||
table_name: str = "parent_documents",
|
||||
pool_config: dict | None = None,
|
||||
max_concurrency: int | None = None
|
||||
) -> Tuple[BaseStore, str]:
|
||||
"""
|
||||
工厂函数,创建 PostgreSQL 文档存储。
|
||||
|
||||
Args:
|
||||
table_name: PostgreSQL 表名(默认:parent_documents)
|
||||
pool_config: 连接池配置
|
||||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||||
|
||||
Returns:
|
||||
元组 (存储实例, 连接字符串)
|
||||
|
||||
Raises:
|
||||
ImportError: 缺少必要的依赖
|
||||
|
||||
Example:
|
||||
>>> # 创建 PostgreSQL 存储
|
||||
>>> store, conn = create_docstore(
|
||||
... table_name="parent_docs",
|
||||
... max_concurrency=10
|
||||
... )
|
||||
"""
|
||||
conn_str = get_docstore_uri()
|
||||
store = PostgresDocStore(
|
||||
connection_string=conn_str,
|
||||
table_name=table_name,
|
||||
pool_config=pool_config,
|
||||
max_concurrency=max_concurrency
|
||||
)
|
||||
return store, conn_str
|
||||
@@ -33,8 +33,6 @@ class QdrantHybridStore:
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
sparse_embedder: Optional[BM25SparseEmbedder] = None,
|
||||
):
|
||||
self.collection_name = collection_name
|
||||
self._client: Optional[QdrantClient] = None
|
||||
@@ -43,13 +41,10 @@ class QdrantHybridStore:
|
||||
self._last_connection_time: Optional[float] = None
|
||||
|
||||
# 稠密嵌入模型
|
||||
if embeddings is None:
|
||||
self.embeddings = get_embeddings()
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
|
||||
# 稀疏嵌入模型
|
||||
self.sparse_embedder = sparse_embedder or get_sparse_embedder()
|
||||
self.sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 集合初始化
|
||||
self.create_collection()
|
||||
@@ -176,7 +171,7 @@ class QdrantHybridStore:
|
||||
texts = [doc.page_content for doc in documents]
|
||||
|
||||
# 生成稠密向量
|
||||
dense_vectors = await self._aembed_texts(texts)
|
||||
dense_vectors = await self.aembed_documents(texts)
|
||||
|
||||
# 生成稀疏向量
|
||||
sparse_vectors = self.sparse_embedder.embed_documents(texts)
|
||||
@@ -210,14 +205,18 @@ class QdrantHybridStore:
|
||||
|
||||
return [p.id for p in points]
|
||||
|
||||
async def _aembed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""异步生成稠密向量(适配同步 Embeddings 接口)"""
|
||||
# 注意:LangChain 的 Embeddings 接口目前主要是同步的
|
||||
# 使用线程池或直接调用(如果 embedding 内部有异步支持)
|
||||
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""异步生成文本列表的稠密向量"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_documents, texts)
|
||||
|
||||
async def aembed_query(self, text: str) -> List[float]:
|
||||
"""异步生成查询的稠密向量"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_query, text)
|
||||
|
||||
# ---------- 异步检索方法 ----------
|
||||
async def asimilarity_search(self, query: str, k: int = 5) -> List[Document]:
|
||||
"""
|
||||
@@ -227,7 +226,7 @@ class QdrantHybridStore:
|
||||
client = self.get_async_client()
|
||||
|
||||
# 生成查询向量
|
||||
dense_query = await self._aembed_query(query)
|
||||
dense_query = await self.aembed_query(query)
|
||||
sparse_query = self.sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
@@ -264,12 +263,6 @@ class QdrantHybridStore:
|
||||
logger.debug("混合检索返回 %d 个文档", len(results))
|
||||
return results
|
||||
|
||||
async def _aembed_query(self, text: str) -> List[float]:
|
||||
"""异步生成查询稠密向量"""
|
||||
import asyncio
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, self.embeddings.embed_query, text)
|
||||
|
||||
# ---------- 同步管理方法(保留,用于初始化和管理) ----------
|
||||
def delete_collection(self):
|
||||
self.get_client().delete_collection(self.collection_name)
|
||||
|
||||
@@ -6,10 +6,6 @@ import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
@@ -38,7 +34,7 @@ def get_input_path() -> Path:
|
||||
if len(sys.argv) > 1:
|
||||
return Path(sys.argv[1])
|
||||
# 默认测试路径(可按需修改)
|
||||
return Path("data/user_docs/doublestory.txt")
|
||||
return Path("data/corpus/三国演义.txt")
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@@ -45,6 +45,11 @@ class IndexBuilderConfig:
|
||||
child_chunk_size: int = 200
|
||||
child_chunk_overlap: int = 20
|
||||
child_splitter_type: SplitterType = SplitterType.SEMANTIC # 子块默认语义切分
|
||||
# 子块语义切分参数
|
||||
child_buffer_size: int = 1
|
||||
child_breakpoint_threshold_type: str = "percentile"
|
||||
child_breakpoint_threshold_amount: float = 90 # 降低阈值,让切分更激进
|
||||
child_min_chunk_size: int = 50 # 降低最小块大小
|
||||
|
||||
# 检索参数
|
||||
search_k: int = 5
|
||||
@@ -86,7 +91,6 @@ class IndexBuilder:
|
||||
# 初始化向量存储(自动支持稠密+稀疏混合检索)
|
||||
self.vector_store = QdrantHybridStore(
|
||||
collection_name=config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)")
|
||||
|
||||
@@ -125,6 +129,10 @@ class IndexBuilder:
|
||||
self.child_splitter = get_splitter(
|
||||
SplitterType.SEMANTIC,
|
||||
embeddings=self.embeddings,
|
||||
buffer_size=cfg.child_buffer_size,
|
||||
breakpoint_threshold_type=cfg.child_breakpoint_threshold_type,
|
||||
breakpoint_threshold_amount=cfg.child_breakpoint_threshold_amount,
|
||||
min_chunk_size=cfg.child_min_chunk_size,
|
||||
**cfg.extra_splitter_kwargs
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -8,7 +8,6 @@ import os
|
||||
import sys
|
||||
|
||||
from backend.rag_core import QdrantHybridStore
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
async def delete_and_recreate():
|
||||
@@ -17,8 +16,7 @@ async def delete_and_recreate():
|
||||
print("删除旧集合并重新创建...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
vs = QdrantHybridStore(collection_name="rag_documents")
|
||||
|
||||
# 删除旧集合
|
||||
try:
|
||||
18
tools/run.py
Normal file
18
tools/run.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
"""统一入口:设置路径后运行 RAG 索引构建 CLI"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 路径设置
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
sys.path.insert(0, str(project_root / "backend"))
|
||||
load_dotenv(project_root / ".env")
|
||||
|
||||
if __name__ == "__main__":
|
||||
from rag_indexer.cli import main
|
||||
#from tools.test.test_rag_indexer_result import main
|
||||
#from tools.test.test_rag_pipeline import main
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
@@ -1,75 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
检查 Qdrant 集合里的数据结构
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from backend.rag_core import QdrantHybridStore
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
def check_qdrant_data():
|
||||
"""检查 Qdrant 中的数据结构"""
|
||||
print("="*70)
|
||||
print("检查 Qdrant 中的数据结构...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
|
||||
# 先获取几个点看看 payload 结构
|
||||
print("\n获取 5 个随机文档:")
|
||||
results = client.scroll(
|
||||
collection_name="rag_documents",
|
||||
limit=5,
|
||||
with_payload=True,
|
||||
with_vectors=True
|
||||
)
|
||||
|
||||
for i, point in enumerate(results[0], 1):
|
||||
print(f"\n{i}. ID: {point.id}")
|
||||
print(f" Payload: {point.payload}")
|
||||
print(f" Payload 键: {list(point.payload.keys())}")
|
||||
if "text" in point.payload:
|
||||
text = point.payload["text"]
|
||||
print(f" Text 长度: {len(text)}")
|
||||
print(f" Text 预览: {text[:150]}...")
|
||||
if "page_content" in point.payload:
|
||||
print(f" page_content: {point.payload['page_content'][:150]}...")
|
||||
|
||||
# 看看向量
|
||||
if point.vector:
|
||||
print(f" 向量存在: {type(point.vector)}")
|
||||
if isinstance(point.vector, dict):
|
||||
print(f" 向量键: {list(point.vector.keys())}")
|
||||
|
||||
|
||||
def check_sparse_embedder():
|
||||
"""检查稀疏嵌入器"""
|
||||
from backend.rag_core import get_sparse_embedder
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("检查稀疏嵌入器...")
|
||||
print("="*70)
|
||||
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
print(f"\n稀疏嵌入器: {sparse_embedder}")
|
||||
print(f"Vocabulary 大小: {len(sparse_embedder.model.vocab)}")
|
||||
print(f"示例查询: '冬天 食物'")
|
||||
|
||||
# 用中文试试
|
||||
sparse_vec = sparse_embedder.embed_query("冬天 食物")
|
||||
print(f"\n生成的稀疏向量:")
|
||||
print(f" 索引数量: {len(sparse_vec['indices'])}")
|
||||
print(f" 索引: {sparse_vec['indices'][:10]}")
|
||||
print(f" 值: {sparse_vec['values'][:10]}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_qdrant_data()
|
||||
check_sparse_embedder()
|
||||
@@ -1,37 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单测试脚本:测试文档里真正有的内容
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from qdrant_client import models
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
def test_dense_retrieval():
|
||||
"""测试稠密检索"""
|
||||
print("="*70)
|
||||
print("测试稠密检索...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
query = "黄双银" # 用文档里真正有的名字查询
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
results = vs.similarity_search(query, k=3)
|
||||
|
||||
print(f"\n找到 {len(results)} 个结果\n")
|
||||
for i, doc in enumerate(results):
|
||||
print(f"--- 结果 {i+1} ---")
|
||||
print(doc.page_content[:200])
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_dense_retrieval()
|
||||
@@ -1,27 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单删除 Qdrant 集合
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
from backend.rag_core.client import create_qdrant_client
|
||||
|
||||
|
||||
def delete_collection():
|
||||
print("="*70)
|
||||
print("删除 rag_documents 集合...")
|
||||
print("="*70)
|
||||
|
||||
client = create_qdrant_client()
|
||||
|
||||
try:
|
||||
client.delete_collection("rag_documents")
|
||||
print("✅ 删除成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 删除失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
delete_collection()
|
||||
@@ -1,150 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
简单测试脚本:检查 Qdrant 内容,测试各种检索方式
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from qdrant_client import models
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
|
||||
|
||||
def check_qdrant_content():
|
||||
"""检查 Qdrant 里的内容"""
|
||||
print("="*70)
|
||||
print("检查 Qdrant 内容...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
|
||||
# 滚动获取前 5 个点
|
||||
points, _ = client.scroll(
|
||||
collection_name="rag_documents",
|
||||
limit=5,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(points)} 个文档\n")
|
||||
for i, point in enumerate(points):
|
||||
print(f"--- 文档 {i+1} ---")
|
||||
print(f"ID: {point.id}")
|
||||
print(f"Payload 键: {list(point.payload.keys())}")
|
||||
|
||||
# 打印完整 payload
|
||||
for k, v in point.payload.items():
|
||||
if isinstance(v, str) and len(v) > 150:
|
||||
v = v[:150] + "..."
|
||||
print(f" {k}: {v}")
|
||||
print()
|
||||
|
||||
|
||||
def test_dense_retrieval():
|
||||
"""测试稠密检索"""
|
||||
print("="*70)
|
||||
print("测试稠密检索...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
query = "蚂蚁" # 用中文查询
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
results = vs.similarity_search(query, k=3)
|
||||
|
||||
print(f"\n找到 {len(results)} 个结果\n")
|
||||
for i, doc in enumerate(results):
|
||||
print(f"--- 结果 {i+1} ---")
|
||||
print(doc.page_content[:200])
|
||||
print()
|
||||
|
||||
|
||||
def test_sparse_retrieval():
|
||||
"""测试稀疏检索"""
|
||||
print("="*70)
|
||||
print("测试稀疏检索(BM25)...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
query = "冬天"
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
response = client.query_points(
|
||||
collection_name="rag_documents",
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=3,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(response.points)} 个结果\n")
|
||||
for i, point in enumerate(response.points):
|
||||
print(f"--- 结果 {i+1} ---")
|
||||
print(f"分数: {point.score:.4f}")
|
||||
text = point.payload.get("page_content", point.payload.get("text", ""))
|
||||
print(text[:200])
|
||||
print()
|
||||
|
||||
|
||||
def test_hybrid_retrieval():
|
||||
"""测试混合检索"""
|
||||
print("="*70)
|
||||
print("测试混合检索(稠密+稀疏 RRF 融合)...")
|
||||
print("="*70)
|
||||
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
query = "蚂蚁和蚱蜢"
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
dense_query = embeddings.embed_query(query)
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
response = client.query_points(
|
||||
collection_name="rag_documents",
|
||||
prefetch=[
|
||||
models.Prefetch(query=dense_query, using="dense", limit=3),
|
||||
models.Prefetch(query=sparse_vec, using="sparse", limit=3)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=3,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(response.points)} 个结果\n")
|
||||
for i, point in enumerate(response.points):
|
||||
print(f"--- 结果 {i+1} ---")
|
||||
print(f"分数: {point.score:.4f}")
|
||||
text = point.payload.get("page_content", point.payload.get("text", ""))
|
||||
print(text[:200])
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
check_qdrant_content()
|
||||
test_dense_retrieval()
|
||||
test_sparse_retrieval()
|
||||
test_hybrid_retrieval()
|
||||
@@ -1,303 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整后端测试 - 验证 Agent 所有功能
|
||||
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
load_dotenv(os.path.join(project_root, ".env"))
|
||||
|
||||
from backend.app.config import DB_URI
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from backend.app.agent.agent_service import AIAgentService
|
||||
from backend.app.agent.history import ThreadHistoryService
|
||||
from backend.app.logger import info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串
|
||||
|
||||
async def print_section(title):
|
||||
"""打印测试区块标题"""
|
||||
print("\n" + "=" * 70)
|
||||
print(f" {title}")
|
||||
print("=" * 70)
|
||||
|
||||
async def test_short_term_memory(agent_service):
|
||||
"""测试短期记忆(同一 thread_id 继续对话)"""
|
||||
await print_section("测试 1: 短期记忆(Short-term Memory)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_memory"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
# 第一轮对话
|
||||
print("\n[第一轮] 发送消息: '我叫张三,今年28岁'")
|
||||
result1 = await agent_service.process_message(
|
||||
"我叫张三,今年28岁", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result1['reply'][:100]}...")
|
||||
|
||||
# 第二轮对话 - 测试记忆
|
||||
print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
|
||||
result2 = await agent_service.process_message(
|
||||
"我叫什么名字?今年多大?", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result2['reply']}")
|
||||
|
||||
# 验证记忆是否存在
|
||||
if "张三" in result2['reply'] or "28" in result2['reply']:
|
||||
print("\n✅ 短期记忆测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ 短期记忆测试失败!")
|
||||
return False
|
||||
|
||||
async def test_tool_calling(agent_service):
|
||||
"""测试工具调用(RAG 搜索)"""
|
||||
await print_section("测试 2: 工具调用(Tool Calling)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_tools"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
# 发送需要 RAG 搜索的问题
|
||||
print("\n发送消息: '请告诉我,黄双银在魔王大陆的故事?'")
|
||||
result = await agent_service.process_message(
|
||||
"请告诉我,黄双银在魔王大陆的故事?", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result['reply'][:200]}...")
|
||||
|
||||
# 检查是否调用了 RAG 工具(回复中会有黄双银相关内容)
|
||||
if "黄双银" in result['reply']:
|
||||
print("\n✅ 工具调用测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
|
||||
return None
|
||||
|
||||
async def test_streaming(agent_service):
|
||||
"""测试流式对话"""
|
||||
await print_section("测试 3: 流式对话(Streaming)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_stream"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
|
||||
print("流式输出: ", end="", flush=True)
|
||||
|
||||
full_reply = ""
|
||||
chunk_count = 0
|
||||
|
||||
try:
|
||||
async for chunk in agent_service.process_message_stream(
|
||||
"用100字介绍一下AI人工智能", thread_id, "local", user_id
|
||||
):
|
||||
chunk_count += 1
|
||||
if chunk.get("type") == "llm_token":
|
||||
token = chunk.get("token", "")
|
||||
print(token, end="", flush=True)
|
||||
full_reply += token
|
||||
elif chunk.get("type") == "state_update":
|
||||
pass # 状态更新不显示
|
||||
|
||||
print(f"\n\n共收到 {chunk_count} 个 chunk")
|
||||
print(f"完整回复长度: {len(full_reply)} 字")
|
||||
|
||||
if chunk_count > 0 and len(full_reply) > 10:
|
||||
print("\n✅ 流式对话测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ 流式对话测试失败!")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 流式对话异常: {e}")
|
||||
return False
|
||||
|
||||
async def test_history_service(agent_service, history_service):
|
||||
"""测试历史查询服务"""
|
||||
await print_section("测试 4: 历史查询服务(History Service)")
|
||||
|
||||
user_id = "test_user_history"
|
||||
|
||||
# 先创建几个对话
|
||||
print(f"\n为 user_id={user_id} 创建测试对话...")
|
||||
|
||||
thread_ids = []
|
||||
for i in range(3):
|
||||
thread_id = str(uuid.uuid4())
|
||||
thread_ids.append(thread_id)
|
||||
|
||||
await agent_service.process_message(
|
||||
f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
|
||||
)
|
||||
print(f" 创建线程 {i+1}: {thread_id[:8]}...")
|
||||
|
||||
# 1. 测试获取用户线程列表
|
||||
print("\n[4.1] 测试获取用户线程列表...")
|
||||
threads = await history_service.get_user_threads(user_id, limit=10)
|
||||
print(f" 找到 {len(threads)} 个线程")
|
||||
|
||||
if len(threads) >= 3:
|
||||
print(" ✅ 线程列表查询通过")
|
||||
else:
|
||||
print(" ⚠️ 线程数量少于预期")
|
||||
|
||||
# 2. 测试获取单个线程的消息历史
|
||||
if thread_ids:
|
||||
test_thread_id = thread_ids[0]
|
||||
print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
|
||||
messages = await history_service.get_thread_messages(test_thread_id)
|
||||
print(f" 找到 {len(messages)} 条消息")
|
||||
|
||||
if len(messages) >= 2: # 至少有一问一答
|
||||
print(" ✅ 消息历史查询通过")
|
||||
else:
|
||||
print(" ⚠️ 消息数量少于预期")
|
||||
|
||||
# 3. 测试获取线程摘要
|
||||
print(f"\n[4.3] 测试获取线程摘要...")
|
||||
summary = await history_service.get_thread_summary(test_thread_id)
|
||||
print(f" 摘要: {summary.get('summary', '')[:50]}...")
|
||||
print(f" 消息数: {summary.get('message_count', 0)}")
|
||||
|
||||
if summary.get('message_count', 0) > 0:
|
||||
print(" ✅ 线程摘要查询通过")
|
||||
else:
|
||||
print(" ⚠️ 摘要查询结果不确定")
|
||||
|
||||
return len(threads) >= 3
|
||||
|
||||
async def test_long_term_memory(agent_service):
|
||||
"""测试长期记忆(mem0)"""
|
||||
await print_section("测试 5: 长期记忆(Long-term Memory - mem0)")
|
||||
|
||||
thread_id1 = str(uuid.uuid4())
|
||||
thread_id2 = str(uuid.uuid4()) # 不同的线程
|
||||
user_id = "test_user_longterm"
|
||||
|
||||
print(f"\n使用 user_id: {user_id}")
|
||||
print(f"线程 1: {thread_id1[:8]}...")
|
||||
print(f"线程 2: {thread_id2[:8]}...")
|
||||
|
||||
# 在第一个线程中保存信息
|
||||
print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
|
||||
result1 = await agent_service.process_message(
|
||||
"记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
|
||||
)
|
||||
print(f"回复: {result1['reply'][:100]}...")
|
||||
|
||||
# 等待一下,让 mem0 保存
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 在第二个线程中询问(不同的 thread_id)
|
||||
print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
|
||||
result2 = await agent_service.process_message(
|
||||
"我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
|
||||
)
|
||||
print(f"回复: {result2['reply']}")
|
||||
|
||||
# 验证长期记忆
|
||||
if "小白" in result2['reply'] or "猫" in result2['reply']:
|
||||
print("\n✅ 长期记忆测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
|
||||
return None
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 后端完整功能测试")
|
||||
print("=" * 70)
|
||||
|
||||
results = {}
|
||||
|
||||
try:
|
||||
# 创建数据库连接和服务
|
||||
print("\n正在初始化数据库连接...")
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
print("✅ 数据库连接成功")
|
||||
|
||||
# 创建服务实例
|
||||
print("\n正在初始化 Agent 服务...")
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
await agent_service.initialize()
|
||||
print("✅ Agent 服务初始化成功")
|
||||
|
||||
history_service = ThreadHistoryService(checkpointer)
|
||||
print("✅ 历史服务初始化成功")
|
||||
|
||||
print(f"\n可用模型: {list(agent_service.graphs.keys())}")
|
||||
|
||||
# 运行测试
|
||||
results["短期记忆"] = await test_short_term_memory(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["工具调用"] = await test_tool_calling(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["流式对话"] = await test_streaming(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["历史查询"] = await test_history_service(agent_service, history_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["长期记忆"] = await test_long_term_memory(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 打印总结
|
||||
await print_section("测试总结")
|
||||
print("\n测试结果:")
|
||||
print("-" * 40)
|
||||
|
||||
pass_count = 0
|
||||
fail_count = 0
|
||||
skip_count = 0
|
||||
|
||||
for test_name, result in results.items():
|
||||
if result is True:
|
||||
status = "✅ 通过"
|
||||
pass_count += 1
|
||||
elif result is False:
|
||||
status = "❌ 失败"
|
||||
fail_count += 1
|
||||
else:
|
||||
status = "⚠️ 待验证"
|
||||
skip_count += 1
|
||||
print(f" {test_name:12s}: {status}")
|
||||
|
||||
print("-" * 40)
|
||||
print(f"总计: {len(results)} 个测试")
|
||||
print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")
|
||||
|
||||
if fail_count == 0:
|
||||
print("\n🎉 所有核心测试通过!")
|
||||
else:
|
||||
print(f"\n⚠️ 有 {fail_count} 个测试失败")
|
||||
|
||||
except Exception as e:
|
||||
error(f"\n❌ 测试运行异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0 if fail_count == 0 else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
@@ -1,63 +0,0 @@
|
||||
"""检查 Qdrant 中存储的向量质量。"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
# 加载环境变量
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
load_dotenv(os.path.join(project_root, ".env"))
|
||||
|
||||
from backend.rag_core import LlamaCppEmbedder
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
||||
COLLECTION_NAME = "rag_documents"
|
||||
|
||||
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
||||
embedder = LlamaCppEmbedder()
|
||||
|
||||
# 获取样本
|
||||
points, _ = client.scroll(
|
||||
collection_name=COLLECTION_NAME,
|
||||
limit=1,
|
||||
with_vectors=True,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
if not points:
|
||||
print(f"集合 '{COLLECTION_NAME}' 为空")
|
||||
exit()
|
||||
|
||||
sample = points[0]
|
||||
raw_vec = sample.vector
|
||||
if isinstance(raw_vec, dict):
|
||||
stored_vec = list(raw_vec.values())[0]
|
||||
elif isinstance(raw_vec, list):
|
||||
stored_vec = raw_vec
|
||||
else:
|
||||
stored_vec = []
|
||||
|
||||
stored_payload = sample.payload or {}
|
||||
stored_text = str(stored_payload.get("page_content", ""))[:200]
|
||||
|
||||
print(f"内容预览:\n{stored_text}...\n")
|
||||
print(f"向量维度: {len(stored_vec)}") # type: ignore
|
||||
print(f"前5个值: {stored_vec[:5]}") # type: ignore
|
||||
print(f"是否全零: {all(v == 0.0 for v in stored_vec)}") # type: ignore
|
||||
|
||||
# 重新编码对比
|
||||
if stored_text:
|
||||
new_vec = embedder.embed_query(stored_text)
|
||||
similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec)) # type: ignore
|
||||
print(f"\n重新编码前5个值: {new_vec[:5]}")
|
||||
print(f"余弦相似度: {similarity:.4f}")
|
||||
|
||||
if similarity < 0.8:
|
||||
print("\n⚠️ 相似度过低,建议删除集合并重建索引")
|
||||
else:
|
||||
print("\n✅ 向量一致")
|
||||
else:
|
||||
print("\n⚠️ 样本无文本内容")
|
||||
@@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
前端快速测试脚本
|
||||
验证前端导入是否正常工作
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
print("=" * 60)
|
||||
print("前端导入测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试 1: 直接导入前端模块
|
||||
print("\n[测试 1] 直接导入前端模块...")
|
||||
try:
|
||||
from frontend.src.frontend_main import main
|
||||
print("✅ frontend_main 导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 测试 2: 导入配置
|
||||
print("\n[测试 2] 导入配置...")
|
||||
try:
|
||||
from frontend.src.config import config
|
||||
print(f"✅ config 导入成功: page_title={config.page_title}")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
# 测试 3: 导入状态管理
|
||||
print("\n[测试 3] 导入状态管理...")
|
||||
try:
|
||||
from frontend.src.state import AppState
|
||||
print("✅ AppState 导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
# 测试 4: 导入 API 客户端
|
||||
print("\n[测试 4] 导入 API 客户端...")
|
||||
try:
|
||||
from frontend.src.api_client import api_client
|
||||
print("✅ api_client 导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
# 测试 5: 导入组件
|
||||
print("\n[测试 5] 导入组件...")
|
||||
try:
|
||||
from frontend.src.components.sidebar import render_sidebar
|
||||
from frontend.src.components.chat_area import render_chat_area
|
||||
from frontend.src.components.info_panel import render_info_panel
|
||||
print("✅ 所有组件导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 所有前端导入测试通过!")
|
||||
print("=" * 60)
|
||||
print("\n现在可以使用 ./scripts/start.sh both 启动完整服务")
|
||||
@@ -1,141 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RAG 系统使用示例(重构版)
|
||||
|
||||
演示:
|
||||
1. 使用 IndexBuilder 获取父子块检索器
|
||||
2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档)
|
||||
3. 将流水线封装为 LangChain 工具,供 Agent 调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量(Qdrant URL、PostgreSQL 连接等)
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
load_dotenv(os.path.join(project_root, ".env"))
|
||||
|
||||
from pydantic import SecretStr
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag_indexer.index_builder import IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
from backend.app.rag.pipeline import RAGPipeline
|
||||
from backend.app.rag.tools import create_rag_tool_sync
|
||||
from backend.rag_core.retriever_factory import create_parent_retriever
|
||||
|
||||
def create_llm():
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
vllm_base_url = os.getenv(
|
||||
"VLLM_BASE_URL",
|
||||
"http://127.0.0.1:8081/v1"
|
||||
)
|
||||
|
||||
return ChatOpenAI(
|
||||
base_url=vllm_base_url,
|
||||
api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
|
||||
model="gemma-4-E2B-it",
|
||||
timeout=60.0, # 请求超时时间(秒)
|
||||
max_retries=2, # 失败后自动重试次数
|
||||
streaming=True, # 确保开启流式输出
|
||||
)
|
||||
|
||||
async def demonstrate_full_pipeline():
|
||||
"""
|
||||
完整流水线演示:
|
||||
- 从 IndexBuilder 获取 ParentDocumentRetriever
|
||||
- 创建 RAGPipeline
|
||||
- 执行检索并打印结果
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
|
||||
print("=" * 60)
|
||||
|
||||
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
|
||||
|
||||
if retriever is None:
|
||||
print("错误:检索器未初始化,请确保索引已构建。")
|
||||
return
|
||||
|
||||
# 3. 创建 LLM 用于查询改写
|
||||
llm = create_llm()
|
||||
|
||||
# 4. 创建 RAGPipeline(固定流程)
|
||||
pipeline = RAGPipeline(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=3, # 生成 3 个查询变体
|
||||
rerank_top_n=5, # 最终返回 5 个父文档
|
||||
)
|
||||
|
||||
# 5. 执行检索
|
||||
query = "打虎英雄是谁?"
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
documents = await pipeline.aretrieve(query)
|
||||
print(f"返回 {len(documents)} 个父文档\n")
|
||||
|
||||
# 打印结果预览
|
||||
for i, doc in enumerate(documents, 1):
|
||||
content_preview = doc.page_content.replace("\n", " ")[:150]
|
||||
source = doc.metadata.get("source", "未知来源")
|
||||
print(f"{i}. 【来源:{source}】")
|
||||
print(f" {content_preview}...\n")
|
||||
|
||||
# 可选:格式化完整上下文
|
||||
# context = pipeline.format_context(documents)
|
||||
# print(context)
|
||||
|
||||
except Exception as e:
|
||||
print(f"检索失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
async def demonstrate_tool_creation():
|
||||
"""
|
||||
演示创建 RAG 工具(供 Agent 使用)
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("演示:创建 RAG 工具(供 LangGraph Agent 调用)")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取检索器(同上)
|
||||
config = IndexBuilderConfig(
|
||||
collection_name="rag_documents",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
)
|
||||
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
|
||||
|
||||
# 2. 创建 LLM
|
||||
llm = create_llm()
|
||||
|
||||
# 3. 创建工具
|
||||
rag_tool = create_rag_tool_sync(
|
||||
retriever=retriever,
|
||||
llm=llm,
|
||||
num_queries=3,
|
||||
rerank_top_n=5,
|
||||
collection_name="rag_documents",
|
||||
)
|
||||
|
||||
print(f"工具名称: {rag_tool.name}")
|
||||
print(f"工具描述: {rag_tool.description[:100]}...")
|
||||
|
||||
# 4. 模拟 Agent 调用工具
|
||||
query = "请告诉我 打虎英雄是谁?"
|
||||
print(f"\n模拟调用: {query}")
|
||||
print("-" * 40)
|
||||
|
||||
result = await rag_tool.ainvoke({"query": query})
|
||||
print(result[:800] + "..." if len(result) > 800 else result)
|
||||
|
||||
async def main():
|
||||
await demonstrate_full_pipeline()
|
||||
await demonstrate_tool_creation()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,305 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试重构后的 IndexBuilder 和 RAG 检索
|
||||
包括:索引构建、稠密检索、稀疏检索、混合检索、父子文档检索
|
||||
简单的 RAG 检索测试
|
||||
使用 app/rag/retriever 提供的功能
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from rag_indexer.index_builder import IndexBuilder
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
|
||||
from backend.app.model_services import get_embedding_service
|
||||
from qdrant_client import models
|
||||
|
||||
|
||||
async def test_index_builder():
|
||||
"""测试索引构建功能"""
|
||||
print("="*70)
|
||||
print("1. 测试索引构建功能...")
|
||||
print("="*70)
|
||||
|
||||
# 创建 IndexBuilder 实例
|
||||
builder = IndexBuilder(
|
||||
collection_name="rag_documents",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200
|
||||
from backend.app.rag.retriever import (
|
||||
create_parent_hybrid_retriever,
|
||||
create_hybrid_retriever
|
||||
)
|
||||
from backend.rag_core import QdrantHybridStore
|
||||
|
||||
# 测试文档路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
|
||||
test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")
|
||||
|
||||
if os.path.exists(test_file):
|
||||
# 构建索引
|
||||
print(f"正在为文件 {test_file} 构建索引...")
|
||||
processed = await builder.build_from_file(test_file)
|
||||
print(f"索引构建完成,处理了 {processed} 个文档")
|
||||
# 统一的测试查询列表
|
||||
TEST_QUERIES = [
|
||||
"黄双银",
|
||||
]
|
||||
|
||||
# 获取集合信息
|
||||
info = builder.get_collection_info()
|
||||
print(f"集合信息: {info}")
|
||||
|
||||
async def test_simple_vector_store_search():
|
||||
"""测试:直接使用 QdrantHybridStore 的 asimilarity_search"""
|
||||
print("="*80)
|
||||
print("测试 1: QdrantHybridStore.asimilarity_search")
|
||||
print("="*80)
|
||||
|
||||
vs = QdrantHybridStore(collection_name="rag_documents")
|
||||
|
||||
for query in TEST_QUERIES:
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
docs = await vs.asimilarity_search(query, k=10)
|
||||
|
||||
if docs:
|
||||
print(f"✓ 找到 {len(docs)} 个文档")
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
preview = doc.page_content[:120].strip()
|
||||
if len(doc.page_content) > 120:
|
||||
preview += "..."
|
||||
print(f" 内容: {preview}")
|
||||
else:
|
||||
print(f"测试文件不存在: {test_file}")
|
||||
print("✗ 未找到结果")
|
||||
|
||||
# 关闭资源
|
||||
builder.close()
|
||||
print("\n索引构建测试完成")
|
||||
return processed
|
||||
await vs.close_async_client()
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
def test_dense_retrieval():
|
||||
"""测试稠密检索"""
|
||||
print("\n" + "="*70)
|
||||
print("2. 测试稠密检索...")
|
||||
print("="*70)
|
||||
async def test_hybrid_retriever():
|
||||
"""测试:HybridRetriever(子文档检索)"""
|
||||
print("\n" + "="*80)
|
||||
print("测试 2: HybridRetriever (子文档混合检索)")
|
||||
print("="*80)
|
||||
|
||||
# 获取嵌入服务
|
||||
embeddings = get_embedding_service()
|
||||
|
||||
# 创建向量存储
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
|
||||
# 测试查询
|
||||
query = "The Ant and the Grasshopper"
|
||||
print(f"查询: {query}")
|
||||
|
||||
results = vs.similarity_search(query, k=3)
|
||||
|
||||
print(f"\n找到 {len(results)} 个结果:")
|
||||
for i, doc in enumerate(results, 1):
|
||||
print(f"\n{i}. (来源: {doc.metadata.get('source', 'unknown')})")
|
||||
print(f" 元数据: {doc.metadata}")
|
||||
content = doc.page_content.strip()
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
print(f" 内容: {content}")
|
||||
|
||||
|
||||
def test_sparse_retrieval_simple():
|
||||
"""简单测试稀疏检索"""
|
||||
print("\n" + "="*70)
|
||||
print("3. 测试稀疏检索(BM25)...")
|
||||
print("="*70)
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 测试查询 - 用关键词
|
||||
query = "winter work food"
|
||||
print(f"查询关键词: {query}")
|
||||
|
||||
# 生成稀疏查询向量
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
|
||||
# 包装成 SparseVector 对象
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 直接查询稀疏向量
|
||||
response = client.query_points(
|
||||
retriever = create_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=3,
|
||||
with_payload=True
|
||||
search_k=10
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(response.points)} 个结果:")
|
||||
for i, point in enumerate(response.points, 1):
|
||||
print(f"\n{i}. (分数: {point.score:.4f})")
|
||||
text = point.payload.get("text", "")
|
||||
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||
print(f" 元数据: {metadata}")
|
||||
content = text.strip()
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
print(f" 内容: {content}")
|
||||
for query in TEST_QUERIES:
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
docs = await retriever.ainvoke(query)
|
||||
|
||||
def test_hybrid_retrieval_simple():
|
||||
"""简单测试混合检索(稠密+稀疏 RRF 融合)"""
|
||||
print("\n" + "="*70)
|
||||
print("4. 测试混合检索(稠密+稀疏 RRF 融合)...")
|
||||
print("="*70)
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 测试查询
|
||||
query = "Ant and Grasshopper story"
|
||||
print(f"查询: {query}")
|
||||
|
||||
# 生成双向量
|
||||
dense_query = embeddings.embed_query(query)
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 使用 Qdrant 的 query_points 做混合检索
|
||||
response = client.query_points(
|
||||
collection_name="rag_documents",
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
query=dense_query,
|
||||
using="dense",
|
||||
limit=3
|
||||
),
|
||||
models.Prefetch(
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=3
|
||||
)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=3,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(response.points)} 个结果:")
|
||||
for i, point in enumerate(response.points, 1):
|
||||
print(f"\n{i}. (RRF 融合分数: {point.score:.4f})")
|
||||
text = point.payload.get("text", "")
|
||||
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||
print(f" 元数据: {metadata}")
|
||||
content = text.strip()
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
print(f" 内容: {content}")
|
||||
|
||||
|
||||
def test_parent_child_retrieval_simple():
|
||||
"""简单测试父子文档检索"""
|
||||
print("\n" + "="*70)
|
||||
print("5. 测试父子文档混合检索...")
|
||||
print("="*70)
|
||||
|
||||
# 获取嵌入服务和稀疏嵌入器
|
||||
embeddings = get_embedding_service()
|
||||
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
|
||||
client = vs.get_qdrant_client()
|
||||
sparse_embedder = get_sparse_embedder()
|
||||
|
||||
# 测试查询
|
||||
query = "The Ant and the Grasshopper story moral"
|
||||
print(f"查询: {query}")
|
||||
|
||||
# 生成双向量
|
||||
dense_query = embeddings.embed_query(query)
|
||||
sparse_query = sparse_embedder.embed_query(query)
|
||||
sparse_vec = models.SparseVector(
|
||||
indices=sparse_query["indices"],
|
||||
values=sparse_query["values"]
|
||||
)
|
||||
|
||||
# 先做混合检索找到子文档
|
||||
response = client.query_points(
|
||||
collection_name="rag_documents",
|
||||
prefetch=[
|
||||
models.Prefetch(
|
||||
query=dense_query,
|
||||
using="dense",
|
||||
limit=5
|
||||
),
|
||||
models.Prefetch(
|
||||
query=sparse_vec,
|
||||
using="sparse",
|
||||
limit=5
|
||||
)
|
||||
],
|
||||
query=models.FusionQuery(fusion=models.Fusion.RRF),
|
||||
limit=5,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
# 收集 parent_id
|
||||
parent_score_map = {}
|
||||
child_points = {}
|
||||
for point in response.points:
|
||||
parent_id = point.payload.get("parent_id", point.id)
|
||||
score = point.score
|
||||
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
|
||||
parent_score_map[parent_id] = score
|
||||
child_points[parent_id] = point
|
||||
|
||||
parent_ids = list(parent_score_map.keys())
|
||||
|
||||
print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:")
|
||||
|
||||
# 查找父文档
|
||||
if parent_ids:
|
||||
parent_docs = client.retrieve(
|
||||
collection_name="rag_documents",
|
||||
ids=parent_ids,
|
||||
with_payload=True
|
||||
)
|
||||
|
||||
found_parent_ids = {p.id for p in parent_docs}
|
||||
|
||||
# 准备结果列表
|
||||
results = []
|
||||
for p in parent_docs:
|
||||
score = parent_score_map[p.id]
|
||||
results.append((p, score))
|
||||
|
||||
# 处理没找到父文档的情况 - 用子文档代替
|
||||
missing = set(parent_ids) - found_parent_ids
|
||||
for parent_id in missing:
|
||||
child_point = child_points[parent_id]
|
||||
print(f"\n注意: parent_id {parent_id} 未找到,使用子文档代替")
|
||||
results.append((child_point, parent_score_map[parent_id]))
|
||||
|
||||
# 按分数排序
|
||||
results.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 显示
|
||||
print(f"\n共 {len(results)} 个结果(去重后):")
|
||||
for i, (point, score) in enumerate(results[:3], 1):
|
||||
print(f"\n{i}. (分数: {score:.4f})")
|
||||
text = point.payload.get("text", "")
|
||||
metadata = {k: v for k, v in point.payload.items() if k != "text"}
|
||||
print(f" 元数据: {metadata}")
|
||||
content = text.strip()
|
||||
if len(content) > 400:
|
||||
content = content[:400] + "..."
|
||||
print(f" 内容: {content}")
|
||||
if docs:
|
||||
print(f"✓ 找到 {len(docs)} 个子文档")
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n {i}. parent_id: {doc.metadata.get('parent_id', 'none')}")
|
||||
preview = doc.page_content[:100].strip()
|
||||
if len(doc.page_content) > 100:
|
||||
preview += "..."
|
||||
print(f" 内容: {preview}")
|
||||
else:
|
||||
print("\n未找到结果")
|
||||
print("✗ 未找到结果")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def test_parent_hybrid_retriever():
|
||||
"""测试:ParentHybridRetriever(父子文档混合检索)"""
|
||||
print("\n" + "="*80)
|
||||
print("测试 3: ParentHybridRetriever (父子文档混合检索)")
|
||||
print("="*80)
|
||||
|
||||
retriever = create_parent_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
search_k=10
|
||||
)
|
||||
|
||||
for query in TEST_QUERIES:
|
||||
print(f"\n查询: {query}")
|
||||
print("-" * 60)
|
||||
|
||||
docs = await retriever.ainvoke(query)
|
||||
|
||||
if docs:
|
||||
print(f"✓ 找到 {len(docs)} 个父文档")
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
preview = doc.page_content[:150].strip()
|
||||
if len(doc.page_content) > 150:
|
||||
preview += "..."
|
||||
print(f" 内容:\n {preview}")
|
||||
else:
|
||||
print("✗ 未找到结果")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
# 1. 先构建索引
|
||||
await test_index_builder()
|
||||
print("\n" + "="*80)
|
||||
print("RAG 检索功能测试")
|
||||
print("="*80)
|
||||
|
||||
# 2. 测试稠密检索
|
||||
test_dense_retrieval()
|
||||
# 测试 1: 直接使用 vector store
|
||||
await test_simple_vector_store_search()
|
||||
|
||||
# 3. 测试稀疏检索
|
||||
test_sparse_retrieval_simple()
|
||||
# 测试 2: HybridRetriever
|
||||
await test_hybrid_retriever()
|
||||
|
||||
# 4. 测试混合检索
|
||||
test_hybrid_retrieval_simple()
|
||||
# 测试 3: ParentHybridRetriever
|
||||
await test_parent_hybrid_retriever()
|
||||
|
||||
# 5. 测试父子文档检索
|
||||
test_parent_child_retrieval_simple()
|
||||
|
||||
print("\n" + "="*70)
|
||||
print("所有测试完成!")
|
||||
print("="*70)
|
||||
print("\n🎉 所有测试完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
145
tools/test/test_rag_pipeline.py
Normal file
145
tools/test/test_rag_pipeline.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整的 RAG Pipeline 测试
|
||||
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
|
||||
from backend.app.rag.tools import create_rag_tool
|
||||
|
||||
|
||||
async def test_rag_pipeline_direct():
|
||||
"""测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)"""
|
||||
print("="*80)
|
||||
print("测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)")
|
||||
print("="*80)
|
||||
|
||||
# 创建 pipeline(默认用小模型)
|
||||
pipeline = create_rag_pipeline(
|
||||
collection_name="rag_documents",
|
||||
num_queries=3,
|
||||
rerank_top_n=5
|
||||
)
|
||||
|
||||
query = "黄双银的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
print("-" * 80)
|
||||
|
||||
# 执行检索
|
||||
docs = await pipeline.aretrieve(query)
|
||||
|
||||
if docs:
|
||||
print(f"\n✓ 找到 {len(docs)} 个相关文档")
|
||||
print("-" * 80)
|
||||
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
print(f" 内容:\n{doc.page_content}")
|
||||
print("-" * 80)
|
||||
|
||||
# 格式化输出
|
||||
print("\n" + "="*80)
|
||||
print("格式化后的上下文:")
|
||||
print("="*80)
|
||||
formatted_context = pipeline.format_context(docs)
|
||||
print(formatted_context)
|
||||
else:
|
||||
print("\n✗ 未找到相关文档")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def test_rag_tool():
|
||||
"""测试 2: 使用 RAG Tool(默认用小模型做查询改写)"""
|
||||
print("\n"+"="*80)
|
||||
print("测试 2: 使用 RAG Tool(默认用小模型做查询改写)")
|
||||
print("="*80)
|
||||
|
||||
# 创建 tool(默认用小模型)
|
||||
rag_tool = create_rag_tool(
|
||||
collection_name="rag_documents",
|
||||
num_queries=3,
|
||||
rerank_top_n=5
|
||||
)
|
||||
|
||||
query = "黄双银的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
print("-" * 80)
|
||||
|
||||
# 使用 tool (异步调用 ainvoke)
|
||||
result = await rag_tool.ainvoke(query)
|
||||
|
||||
print("\nTool 返回结果:")
|
||||
print("="*80)
|
||||
print(result)
|
||||
print("="*80)
|
||||
|
||||
|
||||
async def test_custom_pipeline():
|
||||
"""测试 3: 自定义参数的 RAGPipeline(默认用小模型)"""
|
||||
print("\n"+"="*80)
|
||||
print("测试 3: 自定义参数的 RAGPipeline(默认用小模型)")
|
||||
print("="*80)
|
||||
|
||||
# 自定义参数(默认用小模型)
|
||||
pipeline = RAGPipeline(
|
||||
collection_name="rag_documents",
|
||||
num_queries=2, # 只生成 2 个查询变体
|
||||
rerank_top_n=3 # 只返回前 3 个最相关文档
|
||||
)
|
||||
|
||||
query = "黄双银的经历"
|
||||
|
||||
print(f"\n用户查询: {query}")
|
||||
print(f"配置: num_queries=2, rerank_top_n=3")
|
||||
print("-" * 80)
|
||||
|
||||
docs = await pipeline.aretrieve(query)
|
||||
|
||||
if docs:
|
||||
print(f"\n✓ 找到 {len(docs)} 个相关文档")
|
||||
print("-" * 80)
|
||||
|
||||
for i, doc in enumerate(docs, 1):
|
||||
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
|
||||
preview = doc.page_content[:200].strip()
|
||||
if len(doc.page_content) > 200:
|
||||
preview += "..."
|
||||
print(f" 内容预览: {preview}")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("格式化后的上下文:")
|
||||
print("="*80)
|
||||
print(pipeline.format_context(docs))
|
||||
else:
|
||||
print("\n✗ 未找到相关文档")
|
||||
|
||||
print("\n" + "="*80)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("\n" + "="*80)
|
||||
print("完整 RAG Pipeline 测试")
|
||||
print("查询: '黄双银的经历'")
|
||||
print("="*80)
|
||||
|
||||
# 测试 1: 直接使用 pipeline
|
||||
await test_rag_pipeline_direct()
|
||||
|
||||
# 测试 2: 使用 tool
|
||||
await test_rag_tool()
|
||||
|
||||
# 测试 3: 自定义参数
|
||||
await test_custom_pipeline()
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("🎉 所有 RAG Pipeline 测试完成!")
|
||||
print("="*80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,51 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试 app/rag/retriever.py 里的混合检索函数
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from backend.app.rag.retriever import create_hybrid_retriever, create_parent_hybrid_retriever
|
||||
|
||||
|
||||
def test_hybrid_retriever():
|
||||
"""测试混合检索器"""
|
||||
print("="*70)
|
||||
print("测试 HybridRetriever...")
|
||||
print("="*70)
|
||||
|
||||
retriever = create_hybrid_retriever(collection_name="rag_documents", search_k=3)
|
||||
results = retriever.invoke("黄双银")
|
||||
|
||||
print(f"\n找到 {len(results)} 个结果\n")
|
||||
for i, doc in enumerate(results):
|
||||
print(f"--- 结果 {i+1} ---")
|
||||
print(doc.page_content[:200])
|
||||
print()
|
||||
|
||||
|
||||
def test_parent_hybrid_retriever():
|
||||
"""测试父子混合检索器"""
|
||||
print("\n" + "="*70)
|
||||
print("测试 ParentHybridRetriever...")
|
||||
print("="*70)
|
||||
|
||||
retriever = create_parent_hybrid_retriever(
|
||||
collection_name="rag_documents",
|
||||
search_k=3,
|
||||
use_docstore=False
|
||||
)
|
||||
results = retriever.invoke("黄双银")
|
||||
|
||||
print(f"\n找到 {len(results)} 个结果\n")
|
||||
for i, doc in enumerate(results):
|
||||
print(f"--- 结果 {i+1} ---")
|
||||
print(doc.page_content[:300])
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_hybrid_retriever()
|
||||
test_parent_hybrid_retriever()
|
||||
Reference in New Issue
Block a user