diff --git a/backend/app/config.py b/backend/app/config.py index 7f7334f..9b46560 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -37,9 +37,9 @@ def _get_bool(key: str) -> bool | None: # ========== 第三方 API 密钥 ========== -ZHIPUAI_API_KEY=_get_str("ZHIPUAI_API_KEY") -DEEPSEEK_API_KEY=_get_str("DEEPSEEK_API_KEY") -SILICONFLOW_API_KEY=_get_str("SILICONFLOW_API_KEY") +ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY") +DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY") +SILICONFLOW_API_KEY = _get_str("SILICONFLOW_API_KEY") # ========== 智谱 API 配置 ========== @@ -69,7 +69,7 @@ LOCAL_MODEL_NAME = _get_str("LOCAL_MODEL_NAME") or "gemma-4-E4B-it" # ========== llama.cpp 服务配置(URL + API密钥 配对) ========== # 主 LLM 服务 VLLM_BASE_URL = _get_str("VLLM_BASE_URL") -LLM_API_KEY = _get_str("LLAMACPP_API_KEY") +LLM_API_KEY = _get_str("LLM_API_KEY") # Embedding 服务 (用于 Mem0 的向量化) LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL") @@ -78,6 +78,26 @@ LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY") # Reranker 服务 LLAMACPP_RERANKER_URL = _get_str("LLAMACPP_RERANKER_URL") +# ========== 小模型配置(查询改写、意图分类等简单任务) ========== +# 默认复用大模型配置,后续可单独配置 +# 本地小模型(默认复用 VLLM 配置) +SMALL_VLLM_BASE_URL = _get_str("SMALL_VLLM_BASE_URL") +SMALL_LLM_API_KEY = _get_str("SMALL_LLM_API_KEY") +SMALL_LOCAL_MODEL_NAME = _get_str("SMALL_LOCAL_MODEL_NAME") or LOCAL_MODEL_NAME +# 如果小模型没单独配置,用大模型的配置 +if not SMALL_VLLM_BASE_URL: + SMALL_VLLM_BASE_URL = VLLM_BASE_URL +if not SMALL_LLM_API_KEY: + SMALL_LLM_API_KEY = LLM_API_KEY + +# DeepSeek 小模型(默认复用 DeepSeek 配置) +SMALL_DEEPSEEK_API_KEY = _get_str("SMALL_DEEPSEEK_API_KEY") +SMALL_DEEPSEEK_MODEL = _get_str("SMALL_DEEPSEEK_MODEL") or "deepseek-chat" +SMALL_DEEPSEEK_API_BASE = _get_str("SMALL_DEEPSEEK_API_BASE") or "https://api.deepseek.com" +# 如果小模型没单独配置,用大模型的配置 +if not SMALL_DEEPSEEK_API_KEY: + SMALL_DEEPSEEK_API_KEY = DEEPSEEK_API_KEY + # ========== Qdrant 向量数据库配置(URL + API密钥 配对) ========== QDRANT_URL = _get_str("QDRANT_URL") @@ -114,4 +134,4 @@ ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE") # ========== 日志配置 ========== LOG_LEVEL = _get_str("LOG_LEVEL") -DEBUG = _get_bool("DEBUG") +DEBUG = _get_bool("DEBUG") \ No newline at end of file diff --git a/backend/app/core/intent_classifier.py b/backend/app/core/intent_classifier.py index 8984fbf..b6c0493 100644 --- a/backend/app/core/intent_classifier.py +++ b/backend/app/core/intent_classifier.py @@ -6,7 +6,7 @@ from typing import Optional, Dict, Any import sys import os -from backend.app.model_services.chat_services import get_chat_service +from backend.app.model_services.chat_services import get_small_llm_service class IntentType(Enum): @@ -33,7 +33,7 @@ class IntentClassifier: """意图分类器""" def __init__(self): - self.llm = get_chat_service() + self.llm = get_small_llm_service() self._intent_examples = self._build_examples() def _build_examples(self) -> str: diff --git a/backend/app/main_graph/nodes/rag_nodes.py b/backend/app/main_graph/nodes/rag_nodes.py index ea3889c..2a4c749 100644 --- a/backend/app/main_graph/nodes/rag_nodes.py +++ b/backend/app/main_graph/nodes/rag_nodes.py @@ -19,7 +19,7 @@ from app.main_graph.utils.retry_utils import ( ) # 真正导入和利用已有 RAG 代码 -from app.rag.tools import create_rag_tool_sync +from app.rag.tools import create_rag_tool from app.rag.pipeline import RAGPipeline diff --git a/backend/app/main_graph/utils/rag_initializer.py b/backend/app/main_graph/utils/rag_initializer.py index f83ccca..6707a6e 100644 --- a/backend/app/main_graph/utils/rag_initializer.py +++ b/backend/app/main_graph/utils/rag_initializer.py @@ -1,6 +1,6 @@ # app/rag_initializer.py -from app.rag.tools import create_rag_tool_sync, create_rag_tool_async -from rag_core import create_parent_retriever +from app.rag.tools import create_rag_tool +from app.rag.retriever import create_parent_hybrid_retriever from app.model_services import get_embedding_service from app.logger import info, warning @@ -10,18 +10,18 @@ async def init_rag_tool(local_llm_creator): info("🔄 正在初始化 RAG 检索系统...") # 使用统一的嵌入服务获取接口 embeddings = get_embedding_service() - retriever = create_parent_retriever( + retriever = create_parent_hybrid_retriever( collection_name="rag_documents", search_k=5, embeddings=embeddings ) rewrite_llm = local_llm_creator() - rag_tool = create_rag_tool_async( + rag_tool = create_rag_tool( retriever, rewrite_llm, num_queries=3, rerank_top_n=5 ) - info("✅ RAG 检索工具初始化成功(异步版本)") + info("✅ RAG 检索工具初始化成功(全异步版本)") return rag_tool except Exception as e: warning(f"⚠️ RAG 检索工具初始化失败: {e}") - return None \ No newline at end of file + return None diff --git a/backend/app/model_services/__init__.py b/backend/app/model_services/__init__.py index d0d8e85..5d7f173 100644 --- a/backend/app/model_services/__init__.py +++ b/backend/app/model_services/__init__.py @@ -6,9 +6,11 @@ from .embedding_services import get_embedding_service from .rerank_services import get_rerank_service, BaseRerankService +from .chat_services import get_small_llm_service __all__ = [ "get_embedding_service", "get_rerank_service", + "get_small_llm_service", "BaseRerankService" ] diff --git a/backend/app/model_services/chat_services.py b/backend/app/model_services/chat_services.py index 4f51075..f75dd7c 100644 --- a/backend/app/model_services/chat_services.py +++ b/backend/app/model_services/chat_services.py @@ -219,51 +219,98 @@ class DeepSeekChatProvider(BaseServiceProvider[BaseChatModel]): # ========== 轻量级模型 Provider ========== -class ZhipuSmallModelProvider(BaseServiceProvider[BaseChatModel]): +class LocalSmallModelProvider(BaseServiceProvider[BaseChatModel]): """ - 智谱 AI 轻量级模型服务提供者(用于意图分类等简单任务) - 使用 glm-5.1-flash 或其他小模型 + 本地轻量级模型服务提供者(用于查询改写、意图分类等简单任务) + 使用小模型独立配置 """ - def __init__(self, model: str = "glm-5.1-flash"): - super().__init__("zhipu_small") - self._model = model + def __init__(self, model: str = None): + from app.config import SMALL_LOCAL_MODEL_NAME, SMALL_VLLM_BASE_URL, SMALL_LLM_API_KEY + super().__init__("local_small") + self._model = model or SMALL_LOCAL_MODEL_NAME + self._base_url = SMALL_VLLM_BASE_URL + self._api_key = SMALL_LLM_API_KEY def is_available(self) -> bool: - """检查智谱轻量模型服务是否可用""" - if not ZHIPUAI_API_KEY: - logger.warning("ZHIPUAI_API_KEY 未配置,轻量模型不可用") + """检查本地小模型服务是否可用""" + if not self._base_url: + logger.warning("SMALL_VLLM_BASE_URL 未配置,本地小模型不可用") + return False + + try: + # 先测试主机名能否解析 + import httpx + from urllib.parse import urlparse + + parsed_url = urlparse(self._base_url) + host = parsed_url.hostname + port = parsed_url.port or (80 if parsed_url.scheme == 'http' else 443) + + # 测试能否建立 TCP 连接(快速失败) + import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(2.0) + try: + sock.connect((host, port)) + sock.close() + except Exception as e: + logger.warning(f"本地小模型服务无法连接: {host}:{port} - {e}") + return False + + # 再尝试调用简单的 API + client = httpx.Client(base_url=self._base_url.rstrip('/'), timeout=5.0) + headers = {} + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + + try: + response = client.get("/models", headers=headers) + if response.status_code == 200: + logger.info(f"本地小模型服务可用: {self._model}") + return True + except Exception: + pass + + logger.warning(f"本地小模型服务响应异常") + return False + except Exception as e: + logger.warning(f"本地小模型服务不可用: {e}") return False - logger.info(f"智谱轻量模型配置正确: {self._model}") - return True def get_service(self) -> BaseChatModel: - """获取智谱轻量模型服务""" + """获取本地小模型服务""" if self._service_instance is None: - from langchain_community.chat_models import ChatZhipuAI - self._service_instance = ChatZhipuAI( + from langchain_openai import ChatOpenAI + from pydantic import SecretStr + + self._service_instance = ChatOpenAI( + base_url=self._base_url, + api_key=SecretStr(self._api_key) if self._api_key else SecretStr(""), model=self._model, - api_key=ZHIPUAI_API_KEY, - temperature=0.1, - max_tokens=2048, timeout=30.0, max_retries=2, - streaming=False + streaming=False, ) return self._service_instance + class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]): """ - DeepSeek 轻量级模型服务提供者(备选) + DeepSeek 轻量级模型服务提供者(用于查询改写、意图分类等简单任务) + 使用小模型独立配置 """ - def __init__(self, model: str = "deepseek-chat"): + def __init__(self, model: str = None): + from app.config import SMALL_DEEPSEEK_MODEL, SMALL_DEEPSEEK_API_KEY, SMALL_DEEPSEEK_API_BASE super().__init__("deepseek_small") - self._model = model + self._model = model or SMALL_DEEPSEEK_MODEL + self._api_key = SMALL_DEEPSEEK_API_KEY + self._api_base = SMALL_DEEPSEEK_API_BASE def is_available(self) -> bool: - if not DEEPSEEK_API_KEY: - logger.warning("DEEPSEEK_API_KEY 未配置") + if not self._api_key: + logger.warning("SMALL_DEEPSEEK_API_KEY 未配置") return False logger.info(f"DeepSeek 轻量模型配置正确: {self._model}") return True @@ -274,8 +321,8 @@ class DeepSeekSmallModelProvider(BaseServiceProvider[BaseChatModel]): from pydantic import SecretStr self._service_instance = ChatOpenAI( - base_url="https://api.deepseek.com", - api_key=SecretStr(DEEPSEEK_API_KEY), + base_url=self._api_base, + api_key=SecretStr(self._api_key), model=self._model, temperature=0.1, max_tokens=2048, @@ -339,20 +386,17 @@ def get_all_chat_services() -> Dict[str, BaseChatModel]: def get_small_llm_service() -> BaseChatModel: """ - 获取轻量级大模型服务(用于意图分类等简单任务) - 优先顺序: zhipu_small -> deepseek_small -> (降级到 get_chat_service) + 获取轻量级大模型服务(用于查询改写、意图分类等简单任务) + 优先顺序: 本地模型 -> DeepSeek 小模型 + ⚠️ 注意:小模型任务不降级到大模型,避免不必要的 token 消耗! Returns: BaseChatModel: LangChain 兼容的 ChatModel 实例 """ def _create_small_chain(): - primary = ZhipuSmallModelProvider() + primary = LocalSmallModelProvider() fallbacks = [DeepSeekSmallModelProvider()] return FallbackServiceChain(primary, fallbacks) - try: - chain = SingletonServiceManager.get_or_create("small_llm_chain", _create_small_chain) - return chain.get_available_service() - except Exception as e: - logger.warning(f"轻量模型初始化失败,降级到默认大模型: {e}") - return get_chat_service() + chain = SingletonServiceManager.get_or_create("small_llm_chain", _create_small_chain) + return chain.get_available_service() diff --git a/backend/app/rag/__init__.py b/backend/app/rag/__init__.py index ca1911f..04e9462 100644 --- a/backend/app/rag/__init__.py +++ b/backend/app/rag/__init__.py @@ -42,7 +42,7 @@ from .rerank import DocumentReranker, create_document_reranker from .query_transform import MultiQueryGenerator from .fusion import reciprocal_rank_fusion from .pipeline import RAGPipeline -from .tools import create_rag_tool_sync +from .tools import create_rag_tool __all__ = [ @@ -64,5 +64,5 @@ __all__ = [ "RAGPipeline", # 工具创建(供 Agent 使用) - "create_rag_tool_sync", + "create_rag_tool", ] \ No newline at end of file diff --git a/backend/app/rag/pipeline.py b/backend/app/rag/pipeline.py index a66eb78..4853000 100644 --- a/backend/app/rag/pipeline.py +++ b/backend/app/rag/pipeline.py @@ -13,7 +13,7 @@ from typing import List, Optional from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from app.model_services import get_rerank_service +from app.model_services import get_rerank_service, get_small_llm_service from app.rag.rerank import create_document_reranker from app.rag.query_transform import MultiQueryGenerator from app.rag.fusion import reciprocal_rank_fusion @@ -31,7 +31,7 @@ class RAGPipeline: def __init__( self, retriever=None, - llm: Optional[BaseLanguageModel] = None, + llm: Optional[BaseLanguageModel] = "default_small", num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", @@ -41,6 +41,9 @@ class RAGPipeline: retriever: 基础检索器对象,需实现 ainvoke(query) 异步方法。 如果不提供,会自动创建默认的父子文档混合检索器。 llm: 用于生成多路查询的语言模型。 + - "default_small": (默认) 使用小模型(本地 + DeepSeek) + - None / False: 不做查询改写 + - BaseLanguageModel 实例: 自定义模型 num_queries: 生成的查询变体数量。 rerank_top_n: 最终返回的文档数量。 collection_name: Qdrant 集合名称(仅当 retriever 未提供时使用)。 @@ -53,13 +56,26 @@ class RAGPipeline: ) else: self.retriever = retriever + + # 处理 llm 参数 + if llm == "default_small": + try: + self.llm = get_small_llm_service() + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.warning(f"小模型初始化失败,将不做查询改写: {e}") + self.llm = None + elif llm in (None, False): + self.llm = None + else: + self.llm = llm - self.llm = llm self.num_queries = num_queries self.rerank_top_n = rerank_top_n # 初始化组件 - 使用统一的重排服务获取接口 - self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries) if llm else None + self.query_generator = MultiQueryGenerator(llm=self.llm, num_queries=num_queries) if self.llm else None self.reranker = create_document_reranker() async def aretrieve(self, query: str) -> List[Document]: @@ -102,11 +118,7 @@ class RAGPipeline: final_docs = fused_docs[:self.rerank_top_n] return final_docs - - def retrieve(self, query: str) -> List[Document]: - """同步检索入口(内部调用异步方法)""" - return asyncio.run(self.aretrieve(query)) - + def format_context(self, documents: List[Document]) -> str: """ 将文档列表格式化为上下文字符串 @@ -129,7 +141,7 @@ class RAGPipeline: def create_rag_pipeline( collection_name: str = "rag_documents", - llm: Optional[BaseLanguageModel] = None, + llm: Optional[BaseLanguageModel] = "default_small", num_queries: int = 3, rerank_top_n: int = 5, ) -> RAGPipeline: @@ -138,7 +150,10 @@ def create_rag_pipeline( Args: collection_name: Qdrant 集合名称 - llm: 用于生成多路查询的语言模型 + llm: 用于生成多路查询的语言模型。 + - "default_small": (默认) 使用小模型(本地 + DeepSeek) + - None / False: 不做查询改写 + - BaseLanguageModel 实例: 自定义模型 num_queries: 生成的查询变体数量 rerank_top_n: 最终返回的文档数量 diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 4644580..a288970 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -33,16 +33,16 @@ DEFAULT_PARENT_SEARCH_K = 5 class HybridRetriever(BaseRetriever): """ 混合检索器:稠密向量 + BM25 稀疏向量 RRF 分数融合(异步) - + 使用 Qdrant Universal Query API (query_points) """ collection_name: str = Field(description="Qdrant 集合名称") search_k: int = Field(default=DEFAULT_SEARCH_K, description="检索返回结果数") - + _vector_store: Any = PrivateAttr() _client: Any = PrivateAttr() _sparse_embedder: Any = PrivateAttr() - + def __init__( self, collection_name: str, @@ -62,21 +62,39 @@ class HybridRetriever(BaseRetriever): self._vector_store = vector_store self._client = vector_store.get_async_qdrant_client() self._sparse_embedder = get_sparse_embedder() - + + def _get_relevant_documents( + self, query: str, *, run_manager: Any = None + ) -> List[Document]: + """ + 同步检索(不推荐使用,仅供兼容性) + + 注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke + """ + import asyncio + try: + loop = asyncio.get_running_loop() + # 已有事件循环,使用 create_task + task = loop.create_task(self._aget_relevant_documents(query)) + return loop.run_until_complete(task) + except RuntimeError: + # 没有事件循环,创建新的 + return asyncio.run(self._aget_relevant_documents(query)) + async def _aget_relevant_documents( - self, query: str, **kwargs + self, query: str, *, run_manager: Any = None ) -> List[Document]: """ 异步混合检索相关文档 """ # 1. 生成查询向量 - dense_query = await self._vector_store._aembed_query(query) + dense_query = await self._vector_store.aembed_query(query) sparse_query = self._sparse_embedder.embed_query(query) sparse_vec = models.SparseVector( indices=sparse_query["indices"], values=sparse_query["values"] ) - + # 2. 使用 Qdrant 的 query_points API response = await self._client.query_points( collection_name=self.collection_name, @@ -96,7 +114,7 @@ class HybridRetriever(BaseRetriever): limit=self.search_k, with_payload=True ) - + # 3. 转换结果 results = [] for point in response.points: @@ -105,28 +123,28 @@ class HybridRetriever(BaseRetriever): metadata=point.payload ) results.append(doc) - - debug(f"混合检索返回 %d 个文档", len(results)) + + debug(f"混合检索返回 {len(results)} 个文档") return results class ParentHybridRetriever(BaseRetriever): """ 父子文档混合检索器(异步): - + 1. 先用混合检索找到相关子文档 2. 根据子文档的 parent_id 找到对应的父文档 3. 去重并返回父文档 """ - + collection_name: str = Field(description="Qdrant 集合名称") search_k: int = Field(default=DEFAULT_PARENT_SEARCH_K, description="检索返回结果数") - + _vector_store: Any = PrivateAttr() _client: Any = PrivateAttr() _sparse_embedder: Any = PrivateAttr() _docstore: Any = PrivateAttr() - + def __init__( self, collection_name: str, @@ -149,24 +167,40 @@ class ParentHybridRetriever(BaseRetriever): self._client = vector_store.get_async_qdrant_client() self._sparse_embedder = get_sparse_embedder() self._docstore = docstore - + + def _get_relevant_documents( + self, query: str, *, run_manager: Any = None + ) -> List[Document]: + """ + 同步检索(不推荐使用,仅供兼容性) + + 注意:在异步环境中请使用 _aget_relevant_documents 或 ainvoke + """ + import asyncio + try: + loop = asyncio.get_running_loop() + task = loop.create_task(self._aget_relevant_documents(query)) + return loop.run_until_complete(task) + except RuntimeError: + return asyncio.run(self._aget_relevant_documents(query)) + async def _aget_relevant_documents( - self, query: str, **kwargs + self, query: str, *, run_manager: Any = None ) -> List[Document]: """ 异步检索相关父文档 """ # 1. 生成查询向量 - dense_query = await self._vector_store._aembed_query(query) + dense_query = await self._vector_store.aembed_query(query) sparse_query = self._sparse_embedder.embed_query(query) sparse_vec = models.SparseVector( indices=sparse_query["indices"], values=sparse_query["values"] ) - + # 2. 多取一些子文档,避免去重后数量不足 search_limit = self.search_k * 2 - + # 3. 使用 query_points API 进行混合检索 response = await self._client.query_points( collection_name=self.collection_name, @@ -186,30 +220,30 @@ class ParentHybridRetriever(BaseRetriever): limit=search_limit, with_payload=True ) - + if not response.points: debug("混合检索未找到任何文档") return [] - + # 4. 收集 parent_id 和对应最高得分 parent_score_map = {} parent_ids = set() child_point_map = {} # 保存子文档点用于降级 - + for point in response.points: payload_copy = point.payload.copy() parent_id = payload_copy.get("parent_id", point.id) score = point.score - + if parent_id not in parent_score_map or score > parent_score_map[parent_id]: parent_score_map[parent_id] = score parent_ids.add(parent_id) child_point_map[parent_id] = point - + # 5. 批量查询父文档 parent_docs = [] found_parent_ids = set() - + # 先尝试从 Qdrant 直接查询(如果父文档也在 Qdrant 中) try: parent_points = await self._client.retrieve( @@ -217,7 +251,7 @@ class ParentHybridRetriever(BaseRetriever): ids=list(parent_ids), with_payload=True ) - + for point in parent_points: payload_copy = point.payload.copy() doc = Document( @@ -226,10 +260,10 @@ class ParentHybridRetriever(BaseRetriever): ) parent_docs.append(doc) found_parent_ids.add(point.id) - + except Exception as e: - warning(f"从 Qdrant 查询父文档失败: %s", e) - + warning(f"从 Qdrant 查询父文档失败: {e}") + # 6. 如果有 docstore,尝试从 docstore 查询剩余的父文档 if self._docstore and len(found_parent_ids) < len(parent_ids): missing_parent_ids = parent_ids - found_parent_ids @@ -240,12 +274,12 @@ class ParentHybridRetriever(BaseRetriever): parent_docs.append(doc) found_parent_ids.add(doc_id) except Exception as e: - warning(f"从 docstore 查询父文档失败: %s", e) - + warning(f"从 docstore 查询父文档失败: {e}") + # 7. 降级:对于仍未找到的父文档,用子文档本身代替 missing_parent_ids = parent_ids - found_parent_ids if missing_parent_ids: - warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: %s", missing_parent_ids) + warning(f"以下 parent_id 未找到对应的父文档,将返回子文档本身: {missing_parent_ids}") for parent_id in missing_parent_ids: child_point = child_point_map.get(parent_id) if child_point: @@ -255,17 +289,17 @@ class ParentHybridRetriever(BaseRetriever): metadata=payload_copy ) parent_docs.append(doc) - + # 8. 按照得分降序排序,返回前 k 个 parent_docs_with_scores = [ (doc, parent_score_map.get(doc.metadata.get("id", doc.id if hasattr(doc, "id") else ""), 0.0)) for doc in parent_docs ] parent_docs_with_scores.sort(key=lambda x: x[1], reverse=True) - + final_docs = [doc for doc, _ in parent_docs_with_scores[:self.search_k]] - debug(f"父子文档混合检索返回 %d 个父文档", len(final_docs)) - + debug(f"父子文档混合检索返回 {len(final_docs)} 个父文档") + return final_docs @@ -291,7 +325,7 @@ def create_hybrid_retriever( embeddings = get_embedding_service() info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)") - vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings) + vector_store = QdrantHybridStore(collection_name=collection_name) try: vector_store.get_client().get_collection(collection_name) @@ -336,7 +370,7 @@ def create_parent_hybrid_retriever( embeddings = get_embedding_service() info("使用统一嵌入服务(本地 llama.cpp → 智谱 API 自动降级)") - vector_store = QdrantHybridStore(collection_name=collection_name, embeddings=embeddings) + vector_store = QdrantHybridStore(collection_name=collection_name) try: vector_store.get_client().get_collection(collection_name) diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index ee1b0d4..9a069ad 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -1,5 +1,5 @@ """ -RAG 工具模块 +RAG 工具模块(完全异步) 将检索功能封装为 LangChain Tool,供 Agent 调用。 采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 @@ -13,78 +13,24 @@ from langchain_core.retrievers import BaseRetriever from app.rag.pipeline import RAGPipeline, create_rag_pipeline -def create_rag_tool_sync( +def create_rag_tool( retriever: Optional[BaseRetriever] = None, - llm: Optional[BaseLanguageModel] = None, + llm: Optional[BaseLanguageModel] = "default_small", num_queries: int = 3, rerank_top_n: int = 5, collection_name: str = "rag_documents", ) -> Callable: """ - 创建一个配置好的 RAG 检索工具(同步版本)。 + 创建一个配置好的 RAG 检索工具(完全异步)。 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 Args: retriever: 基础检索器对象(可选,不提供则自动创建) - llm: 用于生成多路查询的语言模型(可选) - num_queries: 生成的查询变体数量 - rerank_top_n: 最终返回的文档数量 - collection_name: Qdrant 集合名称 - - Returns: - LangChain Tool 函数 - """ - pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=num_queries, - rerank_top_n=rerank_top_n, - collection_name=collection_name, - ) - - @tool - def search_knowledge_base_sync(query: str) -> str: - """ - 在知识库中搜索与查询相关的文档片段。 - - 使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式, - 检索效果最优。 - - Args: - query: 用户提出的问题或查询字符串 - - Returns: - 格式化后的相关文档内容 - """ - try: - documents = pipeline.retrieve(query) - if not documents: - return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" - - context = pipeline.format_context(documents) - return context - except Exception as e: - return f"检索过程中发生错误: {str(e)}" - - return search_knowledge_base_sync - - -def create_rag_tool_async( - retriever: Optional[BaseRetriever] = None, - llm: Optional[BaseLanguageModel] = None, - num_queries: int = 3, - rerank_top_n: int = 5, - collection_name: str = "rag_documents", -) -> Callable: - """ - 创建一个配置好的 RAG 检索工具(异步版本)。 - - 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 - - Args: - retriever: 基础检索器对象(可选,不提供则自动创建) - llm: 用于生成多路查询的语言模型(可选) + llm: 用于生成多路查询的语言模型。 + - "default_small": (默认) 使用小模型(本地 + DeepSeek) + - None / False: 不做查询改写 + - BaseLanguageModel 实例: 自定义模型 num_queries: 生成的查询变体数量 rerank_top_n: 最终返回的文档数量 collection_name: Qdrant 集合名称 @@ -101,9 +47,9 @@ def create_rag_tool_async( ) @tool - async def search_knowledge_base_async(query: str) -> str: + async def search_knowledge_base(query: str) -> str: """ - 在知识库中搜索与查询相关的文档片段(异步版本)。 + 在知识库中搜索与查询相关的文档片段(完全异步)。 使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式, 检索效果最优。 @@ -124,30 +70,4 @@ def create_rag_tool_async( except Exception as e: return f"检索过程中发生错误: {str(e)}" - return search_knowledge_base_async - - -def create_rag_tool( - collection_name: str = "rag_documents", - llm: Optional[BaseLanguageModel] = None, - num_queries: int = 3, - rerank_top_n: int = 5, -) -> Callable: - """ - 创建 RAG 检索工具的便捷函数(同步版本)。 - - Args: - collection_name: Qdrant 集合名称 - llm: 用于生成多路查询的语言模型(可选) - num_queries: 生成的查询变体数量 - rerank_top_n: 最终返回的文档数量 - - Returns: - LangChain Tool 函数 - """ - return create_rag_tool_sync( - collection_name=collection_name, - llm=llm, - num_queries=num_queries, - rerank_top_n=rerank_top_n, - ) + return search_knowledge_base diff --git a/backend/rag_core/__init__.py b/backend/rag_core/__init__.py index f15a879..8d3512c 100644 --- a/backend/rag_core/__init__.py +++ b/backend/rag_core/__init__.py @@ -6,8 +6,13 @@ RAG Core - 公共 RAG 组件包 from .embedders import get_embeddings, get_embedding_dimension from .vector_store import QdrantHybridStore from .sparse_embedder import BM25SparseEmbedder, get_sparse_embedder -from .store import PostgresDocStore, create_docstore -from .client import create_qdrant_client, create_async_qdrant_client +from .doc_store import PostgresDocStore +from .client import ( + create_qdrant_client, + create_async_qdrant_client, + create_docstore, + get_docstore_uri +) from .config import ( QDRANT_URL, QDRANT_API_KEY, @@ -24,14 +29,15 @@ __all__ = [ "QdrantHybridStore", "BM25SparseEmbedder", "get_sparse_embedder", + "PostgresDocStore", + "create_docstore", + "get_docstore_uri", + "create_qdrant_client", + "create_async_qdrant_client", "QDRANT_URL", "QDRANT_API_KEY", "LLAMACPP_EMBEDDING_URL", "LLAMACPP_API_KEY", "DB_URI", "DOCSTORE_URI", - "PostgresDocStore", - "create_docstore", - "create_qdrant_client", - "create_async_qdrant_client", ] diff --git a/backend/rag_core/client.py b/backend/rag_core/client.py index 4931d04..0f2175d 100644 --- a/backend/rag_core/client.py +++ b/backend/rag_core/client.py @@ -1,7 +1,12 @@ # rag_core/client.py import os -from .config import QDRANT_URL, QDRANT_API_KEY +from .config import QDRANT_URL, QDRANT_API_KEY, DOCSTORE_URI from qdrant_client import QdrantClient, AsyncQdrantClient +from typing import Tuple +from langchain_core.stores import BaseStore +import logging + +logger = logging.getLogger(__name__) def create_qdrant_client(timeout: int = 300) -> QdrantClient: @@ -54,3 +59,47 @@ def create_async_qdrant_client(timeout: int = 300) -> AsyncQdrantClient: client_kwargs["api_key"] = QDRANT_API_KEY return AsyncQdrantClient(**client_kwargs) + + +def get_docstore_uri() -> str: + """获取 docstore 专用的数据库连接字符串(可与主库相同)""" + return DOCSTORE_URI + + +def create_docstore( + table_name: str = "parent_documents", + pool_config: dict | None = None, + max_concurrency: int | None = None +) -> Tuple[BaseStore, str]: + """ + 工厂函数,创建 PostgreSQL 文档存储。 + + Args: + table_name: PostgreSQL 表名(默认:parent_documents) + pool_config: 连接池配置 + max_concurrency: 最大并发操作数,如果为 None 则不限制 + + Returns: + 元组 (存储实例, 连接字符串) + + Raises: + ImportError: 缺少必要的依赖 + + Example: + >>> # 创建 PostgreSQL 存储 + >>> store, conn = create_docstore( + ... table_name="parent_docs", + ... max_concurrency=10 + ... ) + """ + from .doc_store import PostgresDocStore + + conn_str = get_docstore_uri() + store = PostgresDocStore( + connection_string=conn_str, + table_name=table_name, + pool_config=pool_config, + max_concurrency=max_concurrency + ) + logger.info(f"PostgreSQL docstore 已创建: {table_name}") + return store, conn_str diff --git a/backend/rag_core/store/postgres.py b/backend/rag_core/doc_store.py similarity index 96% rename from backend/rag_core/store/postgres.py rename to backend/rag_core/doc_store.py index 23b7153..ed7e9dd 100644 --- a/backend/rag_core/store/postgres.py +++ b/backend/rag_core/doc_store.py @@ -1,7 +1,7 @@ """ -异步 PostgreSQL 存储实现 - 用于生产环境。 +异步 PostgreSQL 文档存储 -使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。 +用于 ParentDocumentRetriever 的父文档存储,支持高并发访问。 """ import asyncio @@ -16,6 +16,7 @@ import asyncpg logger = logging.getLogger(__name__) + class PostgresDocStore(BaseStore[str, Any]): """ 异步 PostgreSQL 文档存储实现。 @@ -49,7 +50,7 @@ class PostgresDocStore(BaseStore[str, Any]): Args: connection_string: PostgreSQL 连接 URL,格式: - "postgresql://user:password@host:port/database?sslmode=disable" + "postgresql://user:***@host:port/database?sslmode=disable" table_name: 存储表名,默认为 "parent_documents" pool_config: 连接池配置字典,包含: - min_size: 最小连接数(默认 2) @@ -57,17 +58,16 @@ class PostgresDocStore(BaseStore[str, Any]): max_concurrency: 最大并发操作数,如果为 None 则不限制 Raises: - ImportError: 未安装 asyncpg 时抛出 + ImportError: 缺少必要的依赖 Example: >>> store = PostgresDocStore( - ... "postgresql://user:pass@localhost:5432/mydb", + ... "postgresql://user:***@localhost:5432/mydb", ... table_name="parent_docs", ... pool_config={"min_size": 5, "max_size": 20}, ... max_concurrency=10 ... ) """ - self.dsn = connection_string self.table_name = table_name @@ -244,3 +244,4 @@ class PostgresDocStore(BaseStore[str, Any]): 注意:在异步环境中,请使用 aclose 方法。 """ pass + diff --git a/backend/rag_core/store/__init__.py b/backend/rag_core/store/__init__.py deleted file mode 100644 index 476dd6b..0000000 --- a/backend/rag_core/store/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -文档存储模块 - 用于 ParentDocumentRetriever 的父文档存储。 - -提供 PostgreSQL 存储后端: -- PostgresDocStore: PostgreSQL 数据库存储(生产环境) - -示例用法: - >>> from rag_core.store import create_docstore - - >>> # 创建 PostgreSQL 存储 - >>> store, conn = create_docstore( - ... table_name="parent_docs" - ... ) -""" - - -from .postgres import PostgresDocStore -from .factory import create_docstore, get_docstore_uri - -__version__ = "2.0.0" - -__all__ = [ - # 具体实现 - "PostgresDocStore", - - # 工厂函数 - "create_docstore", - "get_docstore_uri", -] diff --git a/backend/rag_core/store/factory.py b/backend/rag_core/store/factory.py deleted file mode 100644 index bf9ade7..0000000 --- a/backend/rag_core/store/factory.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -文档存储工厂 - 创建不同类型的存储实例。 - -提供统一的接口来创建本地文件存储或 PostgreSQL 存储。 -""" - -import os -from ..config import DOCSTORE_URI -import logging -from typing import Tuple - -from langchain_core.stores import BaseStore -from .postgres import PostgresDocStore - -logger = logging.getLogger(__name__) - - -def get_docstore_uri() -> str: - """获取 docstore 专用的数据库连接字符串(可与主库相同)""" - return DOCSTORE_URI - - -def create_docstore( - table_name: str = "parent_documents", - pool_config: dict | None = None, - max_concurrency: int | None = None -) -> Tuple[BaseStore, str]: - """ - 工厂函数,创建 PostgreSQL 文档存储。 - - Args: - table_name: PostgreSQL 表名(默认:parent_documents) - pool_config: 连接池配置 - max_concurrency: 最大并发操作数,如果为 None 则不限制 - - Returns: - 元组 (存储实例, 连接字符串) - - Raises: - ImportError: 缺少必要的依赖 - - Example: - >>> # 创建 PostgreSQL 存储 - >>> store, conn = create_docstore( - ... table_name="parent_docs", - ... max_concurrency=10 - ... ) - """ - conn_str = get_docstore_uri() - store = PostgresDocStore( - connection_string=conn_str, - table_name=table_name, - pool_config=pool_config, - max_concurrency=max_concurrency - ) - return store, conn_str diff --git a/backend/rag_core/vector_store.py b/backend/rag_core/vector_store.py index 6388b15..21e7e19 100644 --- a/backend/rag_core/vector_store.py +++ b/backend/rag_core/vector_store.py @@ -33,8 +33,6 @@ class QdrantHybridStore: def __init__( self, collection_name: str, - embeddings: Optional[Embeddings] = None, - sparse_embedder: Optional[BM25SparseEmbedder] = None, ): self.collection_name = collection_name self._client: Optional[QdrantClient] = None @@ -43,13 +41,10 @@ class QdrantHybridStore: self._last_connection_time: Optional[float] = None # 稠密嵌入模型 - if embeddings is None: - self.embeddings = get_embeddings() - else: - self.embeddings = embeddings + self.embeddings = get_embeddings() # 稀疏嵌入模型 - self.sparse_embedder = sparse_embedder or get_sparse_embedder() + self.sparse_embedder = get_sparse_embedder() # 集合初始化 self.create_collection() @@ -176,7 +171,7 @@ class QdrantHybridStore: texts = [doc.page_content for doc in documents] # 生成稠密向量 - dense_vectors = await self._aembed_texts(texts) + dense_vectors = await self.aembed_documents(texts) # 生成稀疏向量 sparse_vectors = self.sparse_embedder.embed_documents(texts) @@ -210,14 +205,18 @@ class QdrantHybridStore: return [p.id for p in points] - async def _aembed_texts(self, texts: List[str]) -> List[List[float]]: - """异步生成稠密向量(适配同步 Embeddings 接口)""" - # 注意:LangChain 的 Embeddings 接口目前主要是同步的 - # 使用线程池或直接调用(如果 embedding 内部有异步支持) + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """异步生成文本列表的稠密向量""" import asyncio loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self.embeddings.embed_documents, texts) + async def aembed_query(self, text: str) -> List[float]: + """异步生成查询的稠密向量""" + import asyncio + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.embeddings.embed_query, text) + # ---------- 异步检索方法 ---------- async def asimilarity_search(self, query: str, k: int = 5) -> List[Document]: """ @@ -227,7 +226,7 @@ class QdrantHybridStore: client = self.get_async_client() # 生成查询向量 - dense_query = await self._aembed_query(query) + dense_query = await self.aembed_query(query) sparse_query = self.sparse_embedder.embed_query(query) sparse_vec = models.SparseVector( indices=sparse_query["indices"], @@ -264,12 +263,6 @@ class QdrantHybridStore: logger.debug("混合检索返回 %d 个文档", len(results)) return results - async def _aembed_query(self, text: str) -> List[float]: - """异步生成查询稠密向量""" - import asyncio - loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self.embeddings.embed_query, text) - # ---------- 同步管理方法(保留,用于初始化和管理) ---------- def delete_collection(self): self.get_client().delete_collection(self.collection_name) diff --git a/rag_indexer/cli.py b/rag_indexer/cli.py index 6b6a4fd..5aa6076 100755 --- a/rag_indexer/cli.py +++ b/rag_indexer/cli.py @@ -6,10 +6,6 @@ import asyncio import logging import sys from pathlib import Path -from dotenv import load_dotenv - -# 加载 .env 文件 -load_dotenv() from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig from rag_indexer.splitters import SplitterType @@ -38,7 +34,7 @@ def get_input_path() -> Path: if len(sys.argv) > 1: return Path(sys.argv[1]) # 默认测试路径(可按需修改) - return Path("data/user_docs/doublestory.txt") + return Path("data/corpus/三国演义.txt") async def main(): diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index 75fcf48..e9eef0e 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -45,6 +45,11 @@ class IndexBuilderConfig: child_chunk_size: int = 200 child_chunk_overlap: int = 20 child_splitter_type: SplitterType = SplitterType.SEMANTIC # 子块默认语义切分 + # 子块语义切分参数 + child_buffer_size: int = 1 + child_breakpoint_threshold_type: str = "percentile" + child_breakpoint_threshold_amount: float = 90 # 降低阈值,让切分更激进 + child_min_chunk_size: int = 50 # 降低最小块大小 # 检索参数 search_k: int = 5 @@ -86,7 +91,6 @@ class IndexBuilder: # 初始化向量存储(自动支持稠密+稀疏混合检索) self.vector_store = QdrantHybridStore( collection_name=config.collection_name, - embeddings=self.embeddings, ) logger.info("✅ 混合检索向量存储初始化成功(稠密+BM25稀疏)") @@ -125,6 +129,10 @@ class IndexBuilder: self.child_splitter = get_splitter( SplitterType.SEMANTIC, embeddings=self.embeddings, + buffer_size=cfg.child_buffer_size, + breakpoint_threshold_type=cfg.child_breakpoint_threshold_type, + breakpoint_threshold_amount=cfg.child_breakpoint_threshold_amount, + min_chunk_size=cfg.child_min_chunk_size, **cfg.extra_splitter_kwargs ) else: diff --git a/tools/test/reset_qdrant.py b/rag_indexer/reset_qdrant.py similarity index 78% rename from tools/test/reset_qdrant.py rename to rag_indexer/reset_qdrant.py index 64cd6ed..305e11f 100644 --- a/tools/test/reset_qdrant.py +++ b/rag_indexer/reset_qdrant.py @@ -8,7 +8,6 @@ import os import sys from backend.rag_core import QdrantHybridStore -from backend.app.model_services import get_embedding_service async def delete_and_recreate(): @@ -17,8 +16,7 @@ async def delete_and_recreate(): print("删除旧集合并重新创建...") print("="*70) - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) + vs = QdrantHybridStore(collection_name="rag_documents") # 删除旧集合 try: diff --git a/tools/run.py b/tools/run.py new file mode 100644 index 0000000..f3acdb9 --- /dev/null +++ b/tools/run.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +"""统一入口:设置路径后运行 RAG 索引构建 CLI""" +import sys +from pathlib import Path +from dotenv import load_dotenv + +# 路径设置 +project_root = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(project_root)) +sys.path.insert(0, str(project_root / "backend")) +load_dotenv(project_root / ".env") + +if __name__ == "__main__": + from rag_indexer.cli import main + #from tools.test.test_rag_indexer_result import main + #from tools.test.test_rag_pipeline import main + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/tools/test/check_qdrant.py b/tools/test/check_qdrant.py deleted file mode 100644 index 4f3bddf..0000000 --- a/tools/test/check_qdrant.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python3 -""" -检查 Qdrant 集合里的数据结构 -""" - -import asyncio -import os -import sys - -from backend.rag_core import QdrantHybridStore -from backend.app.model_services import get_embedding_service - - -def check_qdrant_data(): - """检查 Qdrant 中的数据结构""" - print("="*70) - print("检查 Qdrant 中的数据结构...") - print("="*70) - - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - - # 先获取几个点看看 payload 结构 - print("\n获取 5 个随机文档:") - results = client.scroll( - collection_name="rag_documents", - limit=5, - with_payload=True, - with_vectors=True - ) - - for i, point in enumerate(results[0], 1): - print(f"\n{i}. ID: {point.id}") - print(f" Payload: {point.payload}") - print(f" Payload 键: {list(point.payload.keys())}") - if "text" in point.payload: - text = point.payload["text"] - print(f" Text 长度: {len(text)}") - print(f" Text 预览: {text[:150]}...") - if "page_content" in point.payload: - print(f" page_content: {point.payload['page_content'][:150]}...") - - # 看看向量 - if point.vector: - print(f" 向量存在: {type(point.vector)}") - if isinstance(point.vector, dict): - print(f" 向量键: {list(point.vector.keys())}") - - -def check_sparse_embedder(): - """检查稀疏嵌入器""" - from backend.rag_core import get_sparse_embedder - - print("\n" + "="*70) - print("检查稀疏嵌入器...") - print("="*70) - - sparse_embedder = get_sparse_embedder() - - print(f"\n稀疏嵌入器: {sparse_embedder}") - print(f"Vocabulary 大小: {len(sparse_embedder.model.vocab)}") - print(f"示例查询: '冬天 食物'") - - # 用中文试试 - sparse_vec = sparse_embedder.embed_query("冬天 食物") - print(f"\n生成的稀疏向量:") - print(f" 索引数量: {len(sparse_vec['indices'])}") - print(f" 索引: {sparse_vec['indices'][:10]}") - print(f" 值: {sparse_vec['values'][:10]}") - - -if __name__ == "__main__": - check_qdrant_data() - check_sparse_embedder() diff --git a/tools/test/quick_test.py b/tools/test/quick_test.py deleted file mode 100644 index 7c374bb..0000000 --- a/tools/test/quick_test.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python3 -""" -简单测试脚本:测试文档里真正有的内容 -""" - -import asyncio -import os -import sys - -from qdrant_client import models -from backend.rag_core import QdrantHybridStore, get_sparse_embedder -from backend.app.model_services import get_embedding_service - - -def test_dense_retrieval(): - """测试稠密检索""" - print("="*70) - print("测试稠密检索...") - print("="*70) - - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - - query = "黄双银" # 用文档里真正有的名字查询 - print(f"\n查询: {query}") - - results = vs.similarity_search(query, k=3) - - print(f"\n找到 {len(results)} 个结果\n") - for i, doc in enumerate(results): - print(f"--- 结果 {i+1} ---") - print(doc.page_content[:200]) - print() - - -if __name__ == "__main__": - test_dense_retrieval() diff --git a/tools/test/simple_delete.py b/tools/test/simple_delete.py deleted file mode 100644 index afdd498..0000000 --- a/tools/test/simple_delete.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python3 -""" -简单删除 Qdrant 集合 -""" - -import sys -import os - -from backend.rag_core.client import create_qdrant_client - - -def delete_collection(): - print("="*70) - print("删除 rag_documents 集合...") - print("="*70) - - client = create_qdrant_client() - - try: - client.delete_collection("rag_documents") - print("✅ 删除成功") - except Exception as e: - print(f"⚠️ 删除失败: {e}") - - -if __name__ == "__main__": - delete_collection() diff --git a/tools/test/simple_test.py b/tools/test/simple_test.py deleted file mode 100644 index 7510776..0000000 --- a/tools/test/simple_test.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -""" -简单测试脚本:检查 Qdrant 内容,测试各种检索方式 -""" - -import asyncio -import os -import sys - -from qdrant_client import models -from backend.rag_core import QdrantHybridStore, get_sparse_embedder -from backend.app.model_services import get_embedding_service - - -def check_qdrant_content(): - """检查 Qdrant 里的内容""" - print("="*70) - print("检查 Qdrant 内容...") - print("="*70) - - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - - # 滚动获取前 5 个点 - points, _ = client.scroll( - collection_name="rag_documents", - limit=5, - with_payload=True, - with_vectors=False - ) - - print(f"\n找到 {len(points)} 个文档\n") - for i, point in enumerate(points): - print(f"--- 文档 {i+1} ---") - print(f"ID: {point.id}") - print(f"Payload 键: {list(point.payload.keys())}") - - # 打印完整 payload - for k, v in point.payload.items(): - if isinstance(v, str) and len(v) > 150: - v = v[:150] + "..." - print(f" {k}: {v}") - print() - - -def test_dense_retrieval(): - """测试稠密检索""" - print("="*70) - print("测试稠密检索...") - print("="*70) - - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - - query = "蚂蚁" # 用中文查询 - print(f"\n查询: {query}") - - results = vs.similarity_search(query, k=3) - - print(f"\n找到 {len(results)} 个结果\n") - for i, doc in enumerate(results): - print(f"--- 结果 {i+1} ---") - print(doc.page_content[:200]) - print() - - -def test_sparse_retrieval(): - """测试稀疏检索""" - print("="*70) - print("测试稀疏检索(BM25)...") - print("="*70) - - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - sparse_embedder = get_sparse_embedder() - - query = "冬天" - print(f"\n查询: {query}") - - sparse_query = sparse_embedder.embed_query(query) - sparse_vec = models.SparseVector( - indices=sparse_query["indices"], - values=sparse_query["values"] - ) - - response = client.query_points( - collection_name="rag_documents", - query=sparse_vec, - using="sparse", - limit=3, - with_payload=True - ) - - print(f"\n找到 {len(response.points)} 个结果\n") - for i, point in enumerate(response.points): - print(f"--- 结果 {i+1} ---") - print(f"分数: {point.score:.4f}") - text = point.payload.get("page_content", point.payload.get("text", "")) - print(text[:200]) - print() - - -def test_hybrid_retrieval(): - """测试混合检索""" - print("="*70) - print("测试混合检索(稠密+稀疏 RRF 融合)...") - print("="*70) - - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - sparse_embedder = get_sparse_embedder() - - query = "蚂蚁和蚱蜢" - print(f"\n查询: {query}") - - dense_query = embeddings.embed_query(query) - sparse_query = sparse_embedder.embed_query(query) - sparse_vec = models.SparseVector( - indices=sparse_query["indices"], - values=sparse_query["values"] - ) - - response = client.query_points( - collection_name="rag_documents", - prefetch=[ - models.Prefetch(query=dense_query, using="dense", limit=3), - models.Prefetch(query=sparse_vec, using="sparse", limit=3) - ], - query=models.FusionQuery(fusion=models.Fusion.RRF), - limit=3, - with_payload=True - ) - - print(f"\n找到 {len(response.points)} 个结果\n") - for i, point in enumerate(response.points): - print(f"--- 结果 {i+1} ---") - print(f"分数: {point.score:.4f}") - text = point.payload.get("page_content", point.payload.get("text", "")) - print(text[:200]) - print() - - -if __name__ == "__main__": - check_qdrant_content() - test_dense_retrieval() - test_sparse_retrieval() - test_hybrid_retrieval() diff --git a/tools/test/test_backend.py b/tools/test/test_backend.py deleted file mode 100644 index 21f6f88..0000000 --- a/tools/test/test_backend.py +++ /dev/null @@ -1,303 +0,0 @@ -#!/usr/bin/env python3 -""" -完整后端测试 - 验证 Agent 所有功能 -包括:短期记忆、长期记忆、工具调用、流式对话、历史查询 -""" - -import asyncio -import os -import sys -import uuid -from dotenv import load_dotenv - -# 加载环境变量 -project_root = os.path.join(os.path.dirname(__file__), "..") -load_dotenv(os.path.join(project_root, ".env")) - -from backend.app.config import DB_URI -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver -from backend.app.agent.agent_service import AIAgentService -from backend.app.agent.history import ThreadHistoryService -from backend.app.logger import info, warning, error - -# PostgreSQL 连接字符串 - -async def print_section(title): - """打印测试区块标题""" - print("\n" + "=" * 70) - print(f" {title}") - print("=" * 70) - -async def test_short_term_memory(agent_service): - """测试短期记忆(同一 thread_id 继续对话)""" - await print_section("测试 1: 短期记忆(Short-term Memory)") - - thread_id = str(uuid.uuid4()) - user_id = "test_user_memory" - - print(f"\n使用 thread_id: {thread_id[:8]}...") - print(f"使用 user_id: {user_id}") - - # 第一轮对话 - print("\n[第一轮] 发送消息: '我叫张三,今年28岁'") - result1 = await agent_service.process_message( - "我叫张三,今年28岁", thread_id, "local", user_id - ) - print(f"回复: {result1['reply'][:100]}...") - - # 第二轮对话 - 测试记忆 - print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'") - result2 = await agent_service.process_message( - "我叫什么名字?今年多大?", thread_id, "local", user_id - ) - print(f"回复: {result2['reply']}") - - # 验证记忆是否存在 - if "张三" in result2['reply'] or "28" in result2['reply']: - print("\n✅ 短期记忆测试通过!") - return True - else: - print("\n❌ 短期记忆测试失败!") - return False - -async def test_tool_calling(agent_service): - """测试工具调用(RAG 搜索)""" - await print_section("测试 2: 工具调用(Tool Calling)") - - thread_id = str(uuid.uuid4()) - user_id = "test_user_tools" - - print(f"\n使用 thread_id: {thread_id[:8]}...") - print(f"使用 user_id: {user_id}") - - # 发送需要 RAG 搜索的问题 - print("\n发送消息: '请告诉我,黄双银在魔王大陆的故事?'") - result = await agent_service.process_message( - "请告诉我,黄双银在魔王大陆的故事?", thread_id, "local", user_id - ) - print(f"回复: {result['reply'][:200]}...") - - # 检查是否调用了 RAG 工具(回复中会有黄双银相关内容) - if "黄双银" in result['reply']: - print("\n✅ 工具调用测试通过!") - return True - else: - print("\n⚠️ 工具调用测试结果不确定,需要手动验证") - return None - -async def test_streaming(agent_service): - """测试流式对话""" - await print_section("测试 3: 流式对话(Streaming)") - - thread_id = str(uuid.uuid4()) - user_id = "test_user_stream" - - print(f"\n使用 thread_id: {thread_id[:8]}...") - print(f"使用 user_id: {user_id}") - - print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...") - print("流式输出: ", end="", flush=True) - - full_reply = "" - chunk_count = 0 - - try: - async for chunk in agent_service.process_message_stream( - "用100字介绍一下AI人工智能", thread_id, "local", user_id - ): - chunk_count += 1 - if chunk.get("type") == "llm_token": - token = chunk.get("token", "") - print(token, end="", flush=True) - full_reply += token - elif chunk.get("type") == "state_update": - pass # 状态更新不显示 - - print(f"\n\n共收到 {chunk_count} 个 chunk") - print(f"完整回复长度: {len(full_reply)} 字") - - if chunk_count > 0 and len(full_reply) > 10: - print("\n✅ 流式对话测试通过!") - return True - else: - print("\n❌ 流式对话测试失败!") - return False - - except Exception as e: - print(f"\n❌ 流式对话异常: {e}") - return False - -async def test_history_service(agent_service, history_service): - """测试历史查询服务""" - await print_section("测试 4: 历史查询服务(History Service)") - - user_id = "test_user_history" - - # 先创建几个对话 - print(f"\n为 user_id={user_id} 创建测试对话...") - - thread_ids = [] - for i in range(3): - thread_id = str(uuid.uuid4()) - thread_ids.append(thread_id) - - await agent_service.process_message( - f"这是第 {i+1} 个测试对话", thread_id, "local", user_id - ) - print(f" 创建线程 {i+1}: {thread_id[:8]}...") - - # 1. 测试获取用户线程列表 - print("\n[4.1] 测试获取用户线程列表...") - threads = await history_service.get_user_threads(user_id, limit=10) - print(f" 找到 {len(threads)} 个线程") - - if len(threads) >= 3: - print(" ✅ 线程列表查询通过") - else: - print(" ⚠️ 线程数量少于预期") - - # 2. 测试获取单个线程的消息历史 - if thread_ids: - test_thread_id = thread_ids[0] - print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)") - messages = await history_service.get_thread_messages(test_thread_id) - print(f" 找到 {len(messages)} 条消息") - - if len(messages) >= 2: # 至少有一问一答 - print(" ✅ 消息历史查询通过") - else: - print(" ⚠️ 消息数量少于预期") - - # 3. 测试获取线程摘要 - print(f"\n[4.3] 测试获取线程摘要...") - summary = await history_service.get_thread_summary(test_thread_id) - print(f" 摘要: {summary.get('summary', '')[:50]}...") - print(f" 消息数: {summary.get('message_count', 0)}") - - if summary.get('message_count', 0) > 0: - print(" ✅ 线程摘要查询通过") - else: - print(" ⚠️ 摘要查询结果不确定") - - return len(threads) >= 3 - -async def test_long_term_memory(agent_service): - """测试长期记忆(mem0)""" - await print_section("测试 5: 长期记忆(Long-term Memory - mem0)") - - thread_id1 = str(uuid.uuid4()) - thread_id2 = str(uuid.uuid4()) # 不同的线程 - user_id = "test_user_longterm" - - print(f"\n使用 user_id: {user_id}") - print(f"线程 1: {thread_id1[:8]}...") - print(f"线程 2: {thread_id2[:8]}...") - - # 在第一个线程中保存信息 - print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'") - result1 = await agent_service.process_message( - "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id - ) - print(f"回复: {result1['reply'][:100]}...") - - # 等待一下,让 mem0 保存 - await asyncio.sleep(1) - - # 在第二个线程中询问(不同的 thread_id) - print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'") - result2 = await agent_service.process_message( - "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id - ) - print(f"回复: {result2['reply']}") - - # 验证长期记忆 - if "小白" in result2['reply'] or "猫" in result2['reply']: - print("\n✅ 长期记忆测试通过!") - return True - else: - print("\n⚠️ 长期记忆可能未启用,或需要手动验证") - return None - -async def main(): - """主测试函数""" - print("\n" + "=" * 70) - print(" 后端完整功能测试") - print("=" * 70) - - results = {} - - try: - # 创建数据库连接和服务 - print("\n正在初始化数据库连接...") - async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: - await checkpointer.setup() - print("✅ 数据库连接成功") - - # 创建服务实例 - print("\n正在初始化 Agent 服务...") - agent_service = AIAgentService(checkpointer) - await agent_service.initialize() - print("✅ Agent 服务初始化成功") - - history_service = ThreadHistoryService(checkpointer) - print("✅ 历史服务初始化成功") - - print(f"\n可用模型: {list(agent_service.graphs.keys())}") - - # 运行测试 - results["短期记忆"] = await test_short_term_memory(agent_service) - await asyncio.sleep(1) - - results["工具调用"] = await test_tool_calling(agent_service) - await asyncio.sleep(1) - - results["流式对话"] = await test_streaming(agent_service) - await asyncio.sleep(1) - - results["历史查询"] = await test_history_service(agent_service, history_service) - await asyncio.sleep(1) - - results["长期记忆"] = await test_long_term_memory(agent_service) - await asyncio.sleep(1) - - # 打印总结 - await print_section("测试总结") - print("\n测试结果:") - print("-" * 40) - - pass_count = 0 - fail_count = 0 - skip_count = 0 - - for test_name, result in results.items(): - if result is True: - status = "✅ 通过" - pass_count += 1 - elif result is False: - status = "❌ 失败" - fail_count += 1 - else: - status = "⚠️ 待验证" - skip_count += 1 - print(f" {test_name:12s}: {status}") - - print("-" * 40) - print(f"总计: {len(results)} 个测试") - print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}") - - if fail_count == 0: - print("\n🎉 所有核心测试通过!") - else: - print(f"\n⚠️ 有 {fail_count} 个测试失败") - - except Exception as e: - error(f"\n❌ 测试运行异常: {e}") - import traceback - traceback.print_exc() - return 1 - - return 0 if fail_count == 0 else 1 - -if __name__ == "__main__": - exit_code = asyncio.run(main()) - sys.exit(exit_code) diff --git a/tools/test/test_dqrant.py b/tools/test/test_dqrant.py deleted file mode 100644 index 385bc58..0000000 --- a/tools/test/test_dqrant.py +++ /dev/null @@ -1,63 +0,0 @@ -"""检查 Qdrant 中存储的向量质量。""" - -import os -import sys -import numpy as np -from dotenv import load_dotenv -from qdrant_client import QdrantClient - -# 加载环境变量 -project_root = os.path.join(os.path.dirname(__file__), "..") -load_dotenv(os.path.join(project_root, ".env")) - -from backend.rag_core import LlamaCppEmbedder - -QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") -QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") -COLLECTION_NAME = "rag_documents" - -client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) -embedder = LlamaCppEmbedder() - -# 获取样本 -points, _ = client.scroll( - collection_name=COLLECTION_NAME, - limit=1, - with_vectors=True, - with_payload=True, -) - -if not points: - print(f"集合 '{COLLECTION_NAME}' 为空") - exit() - -sample = points[0] -raw_vec = sample.vector -if isinstance(raw_vec, dict): - stored_vec = list(raw_vec.values())[0] -elif isinstance(raw_vec, list): - stored_vec = raw_vec -else: - stored_vec = [] - -stored_payload = sample.payload or {} -stored_text = str(stored_payload.get("page_content", ""))[:200] - -print(f"内容预览:\n{stored_text}...\n") -print(f"向量维度: {len(stored_vec)}") # type: ignore -print(f"前5个值: {stored_vec[:5]}") # type: ignore -print(f"是否全零: {all(v == 0.0 for v in stored_vec)}") # type: ignore - -# 重新编码对比 -if stored_text: - new_vec = embedder.embed_query(stored_text) - similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec)) # type: ignore - print(f"\n重新编码前5个值: {new_vec[:5]}") - print(f"余弦相似度: {similarity:.4f}") - - if similarity < 0.8: - print("\n⚠️ 相似度过低,建议删除集合并重建索引") - else: - print("\n✅ 向量一致") -else: - print("\n⚠️ 样本无文本内容") diff --git a/tools/test/test_frontend.py b/tools/test/test_frontend.py deleted file mode 100644 index ea9958f..0000000 --- a/tools/test/test_frontend.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 -""" -前端快速测试脚本 -验证前端导入是否正常工作 -""" - -import sys -import os - -print("=" * 60) -print("前端导入测试") -print("=" * 60) - -# 测试 1: 直接导入前端模块 -print("\n[测试 1] 直接导入前端模块...") -try: - from frontend.src.frontend_main import main - print("✅ frontend_main 导入成功") -except Exception as e: - print(f"❌ 导入失败: {e}") - sys.exit(1) - -# 测试 2: 导入配置 -print("\n[测试 2] 导入配置...") -try: - from frontend.src.config import config - print(f"✅ config 导入成功: page_title={config.page_title}") -except Exception as e: - print(f"❌ 导入失败: {e}") - -# 测试 3: 导入状态管理 -print("\n[测试 3] 导入状态管理...") -try: - from frontend.src.state import AppState - print("✅ AppState 导入成功") -except Exception as e: - print(f"❌ 导入失败: {e}") - -# 测试 4: 导入 API 客户端 -print("\n[测试 4] 导入 API 客户端...") -try: - from frontend.src.api_client import api_client - print("✅ api_client 导入成功") -except Exception as e: - print(f"❌ 导入失败: {e}") - -# 测试 5: 导入组件 -print("\n[测试 5] 导入组件...") -try: - from frontend.src.components.sidebar import render_sidebar - from frontend.src.components.chat_area import render_chat_area - from frontend.src.components.info_panel import render_info_panel - print("✅ 所有组件导入成功") -except Exception as e: - print(f"❌ 导入失败: {e}") - -print("\n" + "=" * 60) -print("🎉 所有前端导入测试通过!") -print("=" * 60) -print("\n现在可以使用 ./scripts/start.sh both 启动完整服务") diff --git a/tools/test/test_rag.py b/tools/test/test_rag.py deleted file mode 100644 index 39caf5d..0000000 --- a/tools/test/test_rag.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python3 -""" -RAG 系统使用示例(重构版) - -演示: -1. 使用 IndexBuilder 获取父子块检索器 -2. 创建固定流程的 RAGPipeline(多路改写 → RRF融合 → 重排序 → 返回父文档) -3. 将流水线封装为 LangChain 工具,供 Agent 调用 -""" - -import asyncio -import sys -import os - -from dotenv import load_dotenv - -# 加载环境变量(Qdrant URL、PostgreSQL 连接等) -project_root = os.path.join(os.path.dirname(__file__), "..") -load_dotenv(os.path.join(project_root, ".env")) - -from pydantic import SecretStr -from langchain_openai import ChatOpenAI -from rag_indexer.index_builder import IndexBuilderConfig -from rag_indexer.splitters import SplitterType -from backend.app.rag.pipeline import RAGPipeline -from backend.app.rag.tools import create_rag_tool_sync -from backend.rag_core.retriever_factory import create_parent_retriever - -def create_llm(): - """创建本地 vLLM 服务 LLM""" - vllm_base_url = os.getenv( - "VLLM_BASE_URL", - "http://127.0.0.1:8081/v1" - ) - - return ChatOpenAI( - base_url=vllm_base_url, - api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")), - model="gemma-4-E2B-it", - timeout=60.0, # 请求超时时间(秒) - max_retries=2, # 失败后自动重试次数 - streaming=True, # 确保开启流式输出 - ) - -async def demonstrate_full_pipeline(): - """ - 完整流水线演示: - - 从 IndexBuilder 获取 ParentDocumentRetriever - - 创建 RAGPipeline - - 执行检索并打印结果 - """ - print("=" * 60) - print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)") - print("=" * 60) - - retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) - - if retriever is None: - print("错误:检索器未初始化,请确保索引已构建。") - return - - # 3. 创建 LLM 用于查询改写 - llm = create_llm() - - # 4. 创建 RAGPipeline(固定流程) - pipeline = RAGPipeline( - retriever=retriever, - llm=llm, - num_queries=3, # 生成 3 个查询变体 - rerank_top_n=5, # 最终返回 5 个父文档 - ) - - # 5. 执行检索 - query = "打虎英雄是谁?" - print(f"\n查询: {query}") - print("-" * 40) - - try: - documents = await pipeline.aretrieve(query) - print(f"返回 {len(documents)} 个父文档\n") - - # 打印结果预览 - for i, doc in enumerate(documents, 1): - content_preview = doc.page_content.replace("\n", " ")[:150] - source = doc.metadata.get("source", "未知来源") - print(f"{i}. 【来源:{source}】") - print(f" {content_preview}...\n") - - # 可选:格式化完整上下文 - # context = pipeline.format_context(documents) - # print(context) - - except Exception as e: - print(f"检索失败: {e}") - import traceback - traceback.print_exc() - -async def demonstrate_tool_creation(): - """ - 演示创建 RAG 工具(供 Agent 使用) - """ - print("\n" + "=" * 60) - print("演示:创建 RAG 工具(供 LangGraph Agent 调用)") - print("=" * 60) - - # 1. 获取检索器(同上) - config = IndexBuilderConfig( - collection_name="rag_documents", - splitter_type=SplitterType.PARENT_CHILD, - ) - retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) - - # 2. 创建 LLM - llm = create_llm() - - # 3. 创建工具 - rag_tool = create_rag_tool_sync( - retriever=retriever, - llm=llm, - num_queries=3, - rerank_top_n=5, - collection_name="rag_documents", - ) - - print(f"工具名称: {rag_tool.name}") - print(f"工具描述: {rag_tool.description[:100]}...") - - # 4. 模拟 Agent 调用工具 - query = "请告诉我 打虎英雄是谁?" - print(f"\n模拟调用: {query}") - print("-" * 40) - - result = await rag_tool.ainvoke({"query": query}) - print(result[:800] + "..." if len(result) > 800 else result) - -async def main(): - await demonstrate_full_pipeline() - await demonstrate_tool_creation() - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tools/test/test_rag_indexer_result.py b/tools/test/test_rag_indexer_result.py index 9e78f4c..af43350 100644 --- a/tools/test/test_rag_indexer_result.py +++ b/tools/test/test_rag_indexer_result.py @@ -1,305 +1,130 @@ #!/usr/bin/env python3 """ -测试重构后的 IndexBuilder 和 RAG 检索 -包括:索引构建、稠密检索、稀疏检索、混合检索、父子文档检索 +简单的 RAG 检索测试 +使用 app/rag/retriever 提供的功能 """ import asyncio -import os - -from rag_indexer.index_builder import IndexBuilder -from rag_indexer.splitters import SplitterType - -from backend.rag_core import QdrantHybridStore, get_sparse_embedder -from backend.app.model_services import get_embedding_service -from qdrant_client import models +from backend.app.rag.retriever import ( + create_parent_hybrid_retriever, + create_hybrid_retriever +) +from backend.rag_core import QdrantHybridStore -async def test_index_builder(): - """测试索引构建功能""" - print("="*70) - print("1. 测试索引构建功能...") - print("="*70) +# 统一的测试查询列表 +TEST_QUERIES = [ + "黄双银", +] + + +async def test_simple_vector_store_search(): + """测试:直接使用 QdrantHybridStore 的 asimilarity_search""" + print("="*80) + print("测试 1: QdrantHybridStore.asimilarity_search") + print("="*80) - # 创建 IndexBuilder 实例 - builder = IndexBuilder( + vs = QdrantHybridStore(collection_name="rag_documents") + + for query in TEST_QUERIES: + print(f"\n查询: {query}") + print("-" * 60) + + docs = await vs.asimilarity_search(query, k=10) + + if docs: + print(f"✓ 找到 {len(docs)} 个文档") + for i, doc in enumerate(docs, 1): + print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}") + preview = doc.page_content[:120].strip() + if len(doc.page_content) > 120: + preview += "..." + print(f" 内容: {preview}") + else: + print("✗ 未找到结果") + + await vs.close_async_client() + print("\n" + "="*80) + + +async def test_hybrid_retriever(): + """测试:HybridRetriever(子文档检索)""" + print("\n" + "="*80) + print("测试 2: HybridRetriever (子文档混合检索)") + print("="*80) + + retriever = create_hybrid_retriever( collection_name="rag_documents", - splitter_type=SplitterType.PARENT_CHILD, - parent_chunk_size=1000, - child_chunk_size=200 + search_k=10 ) - # 测试文档路径 - project_root = os.path.join(os.path.dirname(__file__), "..", "..") - test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt") - - if os.path.exists(test_file): - # 构建索引 - print(f"正在为文件 {test_file} 构建索引...") - processed = await builder.build_from_file(test_file) - print(f"索引构建完成,处理了 {processed} 个文档") + for query in TEST_QUERIES: + print(f"\n查询: {query}") + print("-" * 60) - # 获取集合信息 - info = builder.get_collection_info() - print(f"集合信息: {info}") - else: - print(f"测试文件不存在: {test_file}") + docs = await retriever.ainvoke(query) + + if docs: + print(f"✓ 找到 {len(docs)} 个子文档") + for i, doc in enumerate(docs, 1): + print(f"\n {i}. parent_id: {doc.metadata.get('parent_id', 'none')}") + preview = doc.page_content[:100].strip() + if len(doc.page_content) > 100: + preview += "..." + print(f" 内容: {preview}") + else: + print("✗ 未找到结果") - # 关闭资源 - builder.close() - print("\n索引构建测试完成") - return processed + print("\n" + "="*80) -def test_dense_retrieval(): - """测试稠密检索""" - print("\n" + "="*70) - print("2. 测试稠密检索...") - print("="*70) +async def test_parent_hybrid_retriever(): + """测试:ParentHybridRetriever(父子文档混合检索)""" + print("\n" + "="*80) + print("测试 3: ParentHybridRetriever (父子文档混合检索)") + print("="*80) - # 获取嵌入服务 - embeddings = get_embedding_service() - - # 创建向量存储 - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - - # 测试查询 - query = "The Ant and the Grasshopper" - print(f"查询: {query}") - - results = vs.similarity_search(query, k=3) - - print(f"\n找到 {len(results)} 个结果:") - for i, doc in enumerate(results, 1): - print(f"\n{i}. (来源: {doc.metadata.get('source', 'unknown')})") - print(f" 元数据: {doc.metadata}") - content = doc.page_content.strip() - if len(content) > 200: - content = content[:200] + "..." - print(f" 内容: {content}") - - -def test_sparse_retrieval_simple(): - """简单测试稀疏检索""" - print("\n" + "="*70) - print("3. 测试稀疏检索(BM25)...") - print("="*70) - - # 获取嵌入服务和稀疏嵌入器 - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - sparse_embedder = get_sparse_embedder() - - # 测试查询 - 用关键词 - query = "winter work food" - print(f"查询关键词: {query}") - - # 生成稀疏查询向量 - sparse_query = sparse_embedder.embed_query(query) - - # 包装成 SparseVector 对象 - sparse_vec = models.SparseVector( - indices=sparse_query["indices"], - values=sparse_query["values"] - ) - - # 直接查询稀疏向量 - response = client.query_points( + retriever = create_parent_hybrid_retriever( collection_name="rag_documents", - query=sparse_vec, - using="sparse", - limit=3, - with_payload=True + search_k=10 ) - print(f"\n找到 {len(response.points)} 个结果:") - for i, point in enumerate(response.points, 1): - print(f"\n{i}. (分数: {point.score:.4f})") - text = point.payload.get("text", "") - metadata = {k: v for k, v in point.payload.items() if k != "text"} - print(f" 元数据: {metadata}") - content = text.strip() - if len(content) > 200: - content = content[:200] + "..." - print(f" 内容: {content}") - - -def test_hybrid_retrieval_simple(): - """简单测试混合检索(稠密+稀疏 RRF 融合)""" - print("\n" + "="*70) - print("4. 测试混合检索(稠密+稀疏 RRF 融合)...") - print("="*70) - - # 获取嵌入服务和稀疏嵌入器 - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - sparse_embedder = get_sparse_embedder() - - # 测试查询 - query = "Ant and Grasshopper story" - print(f"查询: {query}") - - # 生成双向量 - dense_query = embeddings.embed_query(query) - sparse_query = sparse_embedder.embed_query(query) - sparse_vec = models.SparseVector( - indices=sparse_query["indices"], - values=sparse_query["values"] - ) - - # 使用 Qdrant 的 query_points 做混合检索 - response = client.query_points( - collection_name="rag_documents", - prefetch=[ - models.Prefetch( - query=dense_query, - using="dense", - limit=3 - ), - models.Prefetch( - query=sparse_vec, - using="sparse", - limit=3 - ) - ], - query=models.FusionQuery(fusion=models.Fusion.RRF), - limit=3, - with_payload=True - ) - - print(f"\n找到 {len(response.points)} 个结果:") - for i, point in enumerate(response.points, 1): - print(f"\n{i}. (RRF 融合分数: {point.score:.4f})") - text = point.payload.get("text", "") - metadata = {k: v for k, v in point.payload.items() if k != "text"} - print(f" 元数据: {metadata}") - content = text.strip() - if len(content) > 200: - content = content[:200] + "..." - print(f" 内容: {content}") - - -def test_parent_child_retrieval_simple(): - """简单测试父子文档检索""" - print("\n" + "="*70) - print("5. 测试父子文档混合检索...") - print("="*70) - - # 获取嵌入服务和稀疏嵌入器 - embeddings = get_embedding_service() - vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings) - client = vs.get_qdrant_client() - sparse_embedder = get_sparse_embedder() - - # 测试查询 - query = "The Ant and the Grasshopper story moral" - print(f"查询: {query}") - - # 生成双向量 - dense_query = embeddings.embed_query(query) - sparse_query = sparse_embedder.embed_query(query) - sparse_vec = models.SparseVector( - indices=sparse_query["indices"], - values=sparse_query["values"] - ) - - # 先做混合检索找到子文档 - response = client.query_points( - collection_name="rag_documents", - prefetch=[ - models.Prefetch( - query=dense_query, - using="dense", - limit=5 - ), - models.Prefetch( - query=sparse_vec, - using="sparse", - limit=5 - ) - ], - query=models.FusionQuery(fusion=models.Fusion.RRF), - limit=5, - with_payload=True - ) - - # 收集 parent_id - parent_score_map = {} - child_points = {} - for point in response.points: - parent_id = point.payload.get("parent_id", point.id) - score = point.score - if parent_id not in parent_score_map or score > parent_score_map[parent_id]: - parent_score_map[parent_id] = score - child_points[parent_id] = point - - parent_ids = list(parent_score_map.keys()) - - print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:") - - # 查找父文档 - if parent_ids: - parent_docs = client.retrieve( - collection_name="rag_documents", - ids=parent_ids, - with_payload=True - ) + for query in TEST_QUERIES: + print(f"\n查询: {query}") + print("-" * 60) - found_parent_ids = {p.id for p in parent_docs} + docs = await retriever.ainvoke(query) - # 准备结果列表 - results = [] - for p in parent_docs: - score = parent_score_map[p.id] - results.append((p, score)) - - # 处理没找到父文档的情况 - 用子文档代替 - missing = set(parent_ids) - found_parent_ids - for parent_id in missing: - child_point = child_points[parent_id] - print(f"\n注意: parent_id {parent_id} 未找到,使用子文档代替") - results.append((child_point, parent_score_map[parent_id])) - - # 按分数排序 - results.sort(key=lambda x: x[1], reverse=True) - - # 显示 - print(f"\n共 {len(results)} 个结果(去重后):") - for i, (point, score) in enumerate(results[:3], 1): - print(f"\n{i}. (分数: {score:.4f})") - text = point.payload.get("text", "") - metadata = {k: v for k, v in point.payload.items() if k != "text"} - print(f" 元数据: {metadata}") - content = text.strip() - if len(content) > 400: - content = content[:400] + "..." - print(f" 内容: {content}") - else: - print("\n未找到结果") + if docs: + print(f"✓ 找到 {len(docs)} 个父文档") + for i, doc in enumerate(docs, 1): + print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}") + preview = doc.page_content[:150].strip() + if len(doc.page_content) > 150: + preview += "..." + print(f" 内容:\n {preview}") + else: + print("✗ 未找到结果") + + print("\n" + "="*80) async def main(): """主测试函数""" - # 1. 先构建索引 - await test_index_builder() + print("\n" + "="*80) + print("RAG 检索功能测试") + print("="*80) - # 2. 测试稠密检索 - test_dense_retrieval() + # 测试 1: 直接使用 vector store + await test_simple_vector_store_search() - # 3. 测试稀疏检索 - test_sparse_retrieval_simple() + # 测试 2: HybridRetriever + await test_hybrid_retriever() - # 4. 测试混合检索 - test_hybrid_retrieval_simple() + # 测试 3: ParentHybridRetriever + await test_parent_hybrid_retriever() - # 5. 测试父子文档检索 - test_parent_child_retrieval_simple() - - print("\n" + "="*70) - print("所有测试完成!") - print("="*70) + print("\n🎉 所有测试完成!") if __name__ == "__main__": diff --git a/tools/test/test_rag_pipeline.py b/tools/test/test_rag_pipeline.py new file mode 100644 index 0000000..cbe60a6 --- /dev/null +++ b/tools/test/test_rag_pipeline.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +完整的 RAG Pipeline 测试 +测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程 +""" + +import asyncio +from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline +from backend.app.rag.tools import create_rag_tool + + +async def test_rag_pipeline_direct(): + """测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)""" + print("="*80) + print("测试 1: 直接使用 RAGPipeline(默认用小模型做查询改写)") + print("="*80) + + # 创建 pipeline(默认用小模型) + pipeline = create_rag_pipeline( + collection_name="rag_documents", + num_queries=3, + rerank_top_n=5 + ) + + query = "黄双银的经历" + + print(f"\n用户查询: {query}") + print("-" * 80) + + # 执行检索 + docs = await pipeline.aretrieve(query) + + if docs: + print(f"\n✓ 找到 {len(docs)} 个相关文档") + print("-" * 80) + + for i, doc in enumerate(docs, 1): + print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}") + print(f" 内容:\n{doc.page_content}") + print("-" * 80) + + # 格式化输出 + print("\n" + "="*80) + print("格式化后的上下文:") + print("="*80) + formatted_context = pipeline.format_context(docs) + print(formatted_context) + else: + print("\n✗ 未找到相关文档") + + print("\n" + "="*80) + + +async def test_rag_tool(): + """测试 2: 使用 RAG Tool(默认用小模型做查询改写)""" + print("\n"+"="*80) + print("测试 2: 使用 RAG Tool(默认用小模型做查询改写)") + print("="*80) + + # 创建 tool(默认用小模型) + rag_tool = create_rag_tool( + collection_name="rag_documents", + num_queries=3, + rerank_top_n=5 + ) + + query = "黄双银的经历" + + print(f"\n用户查询: {query}") + print("-" * 80) + + # 使用 tool (异步调用 ainvoke) + result = await rag_tool.ainvoke(query) + + print("\nTool 返回结果:") + print("="*80) + print(result) + print("="*80) + + +async def test_custom_pipeline(): + """测试 3: 自定义参数的 RAGPipeline(默认用小模型)""" + print("\n"+"="*80) + print("测试 3: 自定义参数的 RAGPipeline(默认用小模型)") + print("="*80) + + # 自定义参数(默认用小模型) + pipeline = RAGPipeline( + collection_name="rag_documents", + num_queries=2, # 只生成 2 个查询变体 + rerank_top_n=3 # 只返回前 3 个最相关文档 + ) + + query = "黄双银的经历" + + print(f"\n用户查询: {query}") + print(f"配置: num_queries=2, rerank_top_n=3") + print("-" * 80) + + docs = await pipeline.aretrieve(query) + + if docs: + print(f"\n✓ 找到 {len(docs)} 个相关文档") + print("-" * 80) + + for i, doc in enumerate(docs, 1): + print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}") + preview = doc.page_content[:200].strip() + if len(doc.page_content) > 200: + preview += "..." + print(f" 内容预览: {preview}") + + print("\n" + "="*80) + print("格式化后的上下文:") + print("="*80) + print(pipeline.format_context(docs)) + else: + print("\n✗ 未找到相关文档") + + print("\n" + "="*80) + + +async def main(): + """主测试函数""" + print("\n" + "="*80) + print("完整 RAG Pipeline 测试") + print("查询: '黄双银的经历'") + print("="*80) + + # 测试 1: 直接使用 pipeline + await test_rag_pipeline_direct() + + # 测试 2: 使用 tool + await test_rag_tool() + + # 测试 3: 自定义参数 + await test_custom_pipeline() + + print("\n" + "="*80) + print("🎉 所有 RAG Pipeline 测试完成!") + print("="*80) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tools/test/test_retrievers.py b/tools/test/test_retrievers.py deleted file mode 100644 index b46b46c..0000000 --- a/tools/test/test_retrievers.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python3 -""" -测试 app/rag/retriever.py 里的混合检索函数 -""" - -import asyncio -import os -import sys - -from backend.app.rag.retriever import create_hybrid_retriever, create_parent_hybrid_retriever - - -def test_hybrid_retriever(): - """测试混合检索器""" - print("="*70) - print("测试 HybridRetriever...") - print("="*70) - - retriever = create_hybrid_retriever(collection_name="rag_documents", search_k=3) - results = retriever.invoke("黄双银") - - print(f"\n找到 {len(results)} 个结果\n") - for i, doc in enumerate(results): - print(f"--- 结果 {i+1} ---") - print(doc.page_content[:200]) - print() - - -def test_parent_hybrid_retriever(): - """测试父子混合检索器""" - print("\n" + "="*70) - print("测试 ParentHybridRetriever...") - print("="*70) - - retriever = create_parent_hybrid_retriever( - collection_name="rag_documents", - search_k=3, - use_docstore=False - ) - results = retriever.invoke("黄双银") - - print(f"\n找到 {len(results)} 个结果\n") - for i, doc in enumerate(results): - print(f"--- 结果 {i+1} ---") - print(doc.page_content[:300]) - print() - - -if __name__ == "__main__": - test_hybrid_retriever() - test_parent_hybrid_retriever()