From 82dde7113e5d0813e0c30e53a5ee63f434102213 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 4 May 2026 04:28:32 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9rag=EF=BC=8C=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E6=B7=B7=E5=90=88=E6=A3=80=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/agent/agent_service.py | 39 ++--- backend/app/agent/history.py | 9 +- backend/app/backend.py | 14 +- backend/app/core/intent.py | 12 +- backend/app/main_graph/nodes/rag_nodes.py | 38 +++-- backend/app/main_graph/nodes/react_nodes.py | 24 ++- .../app/main_graph/utils/rag_initializer.py | 6 +- backend/app/rag/tools.py | 57 +++++++ frontend/src/config.py | 4 +- tools/test/check_qdrant.py | 80 +++++++++ tools/test/quick_test.py | 40 +++++ tools/test/reset_qdrant.py | 41 +++++ tools/test/simple_delete.py | 30 ++++ tools/test/simple_test.py | 153 ++++++++++++++++++ tools/test/test_retrievers.py | 54 +++++++ 15 files changed, 536 insertions(+), 65 deletions(-) create mode 100644 tools/test/check_qdrant.py create mode 100644 tools/test/quick_test.py create mode 100644 tools/test/reset_qdrant.py create mode 100644 tools/test/simple_delete.py create mode 100644 tools/test/simple_test.py create mode 100644 tools/test/test_retrievers.py diff --git a/backend/app/agent/agent_service.py b/backend/app/agent/agent_service.py index 5b31f22..24c134b 100644 --- a/backend/app/agent/agent_service.py +++ b/backend/app/agent/agent_service.py @@ -70,7 +70,7 @@ class AIAgentService: raise RuntimeError("没有可用的模型") return self - async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict: + async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict: """处理用户消息,返回包含回复、token统计和耗时的字典""" if model not in self.graphs: # 回退到第一个可用模型 @@ -175,6 +175,8 @@ class AIAgentService: try: info(f"📡 开始调用 graph.astream()...") chunk_count = 0 + full_message_content = "" # 收集完整消息内容 + async for chunk in graph.astream( input_state, config=config, @@ -184,21 +186,11 @@ class AIAgentService: ): chunk_count += 1 chunk_type = chunk["type"] - info(f"📦 收到第 {chunk_count} 个chunk, type: {chunk_type}") processed_event = {} if chunk_type == "messages": message_chunk, metadata = chunk["data"] node_name = metadata.get("langgraph_node", "unknown") - info(f"📨 处理消息chunk, node: {node_name}") - # 详细记录消息内容,看看这些 chunk 到底是什么 - if hasattr(message_chunk, "content"): - content_preview = str(message_chunk.content)[:200] - info(f"📄 消息内容预览 ({len(content_preview)} chars): {repr(content_preview)}") - if hasattr(message_chunk, "type"): - info(f"📋 消息类型: {message_chunk.type}") - if hasattr(message_chunk, "tool_calls"): - info(f"🔧 包含工具调用: {message_chunk.tool_calls}") # 检测节点变化,发送节点开始事件 if node_name != current_node: @@ -218,8 +210,6 @@ class AIAgentService: reasoning_token = "" if hasattr(message_chunk, 'additional_kwargs'): reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "") - - info(f"💬 消息token: token_content='{repr(token_content[:50])}', reasoning_token='{repr(reasoning_token[:50])}', node_name='{node_name}'") # 处理思考过程 if reasoning_token: @@ -228,7 +218,6 @@ class AIAgentService: "node": node_name, "reasoning_token": reasoning_token } - info(f"✅ 生成 reasoning_token 事件: {processed_event}") # 处理工具调用 elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls: for tool_call in message_chunk.tool_calls: @@ -248,7 +237,7 @@ class AIAgentService: "args": tool_args, "id": tool_call_id } - # 处理普通 token + # 处理普通 token - 只收集,不打印单个 token elif token_content: processed_event = { "type": "llm_token", @@ -256,18 +245,13 @@ class AIAgentService: "token": token_content, "reasoning_token": reasoning_token } - info(f"✅ 生成 llm_token 事件: {processed_event}") - else: - info(f"⚠️ 没有生成任何事件,token_content='{repr(token_content)}', reasoning_token='{repr(reasoning_token)}'") + if node_name == "llm_call": + full_message_content += token_content elif chunk_type == "updates": - info(f"🔄 处理updates chunk") updates_data = chunk["data"] serialized_data = self._serialize_value(updates_data) - # 关键修复:不再从 updates 中读取 latest_reasoning,避免重复 - # 因为我们现在直接通过 custom 事件发送推理结果了 - # 检查是否有人工审核请求 if "review_pending" in serialized_data and serialized_data["review_pending"]: review_id = serialized_data.get("review_id", "") @@ -302,18 +286,12 @@ class AIAgentService: } elif chunk_type == "custom": - info(f"🎯 处理custom chunk, 完整数据: {repr(chunk)}") custom_data = chunk["data"] - info(f"🎯 custom_data 内容: {repr(custom_data)}") - info(f"🎯 custom_data 类型: {type(custom_data)}") - # 关键修复:处理我们从 react_reason_node 发送的自定义推理事件 - # LangGraph 的 adispatch_custom_event 发送的事件格式: - # chunk["data"] 是我们传的第二个参数(dict) + # 处理我们从 react_reason_node 发送的自定义推理事件 if isinstance(custom_data, dict): # 检查是否是我们的推理事件 if "action" in custom_data and "reasoning" in custom_data: - info(f"[Agent Service] 收到自定义推理事件: {custom_data}") yield { "type": "react_reasoning", "step": custom_data.get("step", 1), @@ -339,7 +317,10 @@ class AIAgentService: if processed_event: yield processed_event + # 完整消息集合完成后,一次性打印 info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks") + if full_message_content: + info(f"📄 完整消息内容: {repr(full_message_content)}") except Exception as e: error(f"❌ 执行 React 图时出错: {e}") diff --git a/backend/app/agent/history.py b/backend/app/agent/history.py index 09f7124..c2107fe 100644 --- a/backend/app/agent/history.py +++ b/backend/app/agent/history.py @@ -12,18 +12,21 @@ class ThreadHistoryService: def __init__(self, checkpointer): self.checkpointer = checkpointer - async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]: + async def get_user_threads(self, user_id: str, limit: int = 4) -> List[Dict[str, Any]]: """ 获取指定用户的所有线程摘要信息 Args: user_id: 用户 ID - limit: 返回数量限制 + limit: 返回数量限制(强制最多4条) Returns: 线程列表,每个包含 thread_id, last_updated, summary, message_count """ try: + # 强制限制最多4条 + actual_limit = min(limit, 4) + # 查询 checkpoints 表获取用户的线程列表 async with self.checkpointer.conn.cursor() as cur: # 在较新的 LangGraph 版本中,AsyncPostgresSaver 创建的 checkpoints 表 @@ -40,7 +43,7 @@ class ThreadHistoryService: ORDER BY last_updated DESC LIMIT %s """ - await cur.execute(query, (user_id, limit)) + await cur.execute(query, (user_id, actual_limit)) rows = await cur.fetchall() threads = [] diff --git a/backend/app/backend.py b/backend/app/backend.py index e1b0d52..76b7675 100644 --- a/backend/app/backend.py +++ b/backend/app/backend.py @@ -98,7 +98,7 @@ async def health_check(): class ChatRequest(BaseModel): message: str thread_id: str | None = None - model: str = "zhipu" + model: str = "local" user_id: str = "default_user" class ChatResponse(BaseModel): @@ -212,7 +212,7 @@ async def chat_endpoint( @app.get("/threads") async def list_threads( user_id: str = Query("default_user", description="用户 ID"), - limit: int = Query(50, ge=1, le=200, description="返回数量限制"), + limit: int = Query(4, ge=1, le=200, description="返回数量限制"), history_service: ThreadHistoryService = Depends(get_history_service) ): """获取当前用户的对话历史列表""" @@ -312,7 +312,7 @@ async def websocket_endpoint( data = await websocket.receive_json() message = data.get("message") thread_id = data.get("thread_id", str(uuid.uuid4())) - model = data.get("model", "zhipu") + model = data.get("model", "local") user_id = data.get("user_id", "default_user") if not message: await websocket.send_json({"error": "missing message"}) @@ -435,4 +435,10 @@ if __name__ == "__main__": import uvicorn # 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突) port = int(BACKEND_PORT) - uvicorn.run(app, host="0.0.0.0", port=port) + uvicorn.run( + app, + host="0.0.0.0", + port=port, + log_level="debug", + access_log=True + ) diff --git a/backend/app/core/intent.py b/backend/app/core/intent.py index 58b3262..dd7639f 100644 --- a/backend/app/core/intent.py +++ b/backend/app/core/intent.py @@ -130,12 +130,16 @@ class ReactIntentReasoner: retrieved_docs = context.get("retrieved_docs", []) messages = context.get("messages", []) - # 关键修复 2:如果已经有 rag_context 或 web_search_results(通过 messages 推断),直接回答 - # 检查是否已经执行过 rag_retrieve 或 web_search - if "rag_retrieve" in previous_actions or "web_search" in previous_actions: + # 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次 + # 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search + rag_count = previous_actions.count("rag_retrieve") + web_search_count = previous_actions.count("web_search") + + # 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答 + if rag_count >= 2 or web_search_count >= 1: result.action = ReasoningAction.DIRECT_RESPONSE result.confidence = 0.95 - result.reasoning = "已获取信息,直接回答" + result.reasoning = "已获取足够信息,直接回答" return result # 策略1:尝试使用 LLM 推理 diff --git a/backend/app/main_graph/nodes/rag_nodes.py b/backend/app/main_graph/nodes/rag_nodes.py index a0bb68b..ea3889c 100644 --- a/backend/app/main_graph/nodes/rag_nodes.py +++ b/backend/app/main_graph/nodes/rag_nodes.py @@ -95,10 +95,10 @@ def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphS return state -# ========== RAG 检索核心逻辑(真正利用已有代码)========== -def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: +# ========== RAG 检索核心逻辑(真正利用已有代码) ========== +async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: """ - RAG 检索核心逻辑(真正利用 rag/tools.py) + RAG 检索核心逻辑(真正利用 rag/tools.py) - 异步版本 Args: state: 主图状态 @@ -119,10 +119,10 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: rag_tool = get_rag_tool_from_state(state) if rag_tool: - # 使用真正的 RAG 工具(来自 rag/tools.py) + # 使用真正的 RAG 工具(来自 rag/tools.py)- 异步版本 try: - # 调用 LangChain Tool 的 invoke 方法 - rag_context = rag_tool.invoke(retrieval_query) + # 直接 await 异步工具的 ainvoke 方法 + rag_context = await rag_tool.ainvoke(retrieval_query) state.rag_context = rag_context state.rag_docs = [ {"source": "rag_retrieval", "content": rag_context} @@ -134,9 +134,9 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: except Exception as e: raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e elif _GLOBAL_RAG_PIPELINE: - # 使用 RAG Pipeline 直接检索 + # 使用 RAG Pipeline 直接检索 - 直接用异步方法 try: - documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query) + documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query) if documents: rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents) state.rag_context = rag_context @@ -158,7 +158,7 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState: raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()") -# ========== RAG 检索节点(带超时和重试)========== +# ========== RAG 检索节点(带超时和重试) ========== async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState: """ RAG 检索节点:带超时和重试,真正利用已有 RAG 代码 @@ -196,8 +196,13 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An for attempt in range(RAG_RETRY_CONFIG.max_retries + 1): try: - # 执行核心逻辑 - result = _rag_retrieve_core(state) + # 执行核心逻辑 - 异步 await + result = await _rag_retrieve_core(state) + + info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符") + if result.rag_docs: + for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条 + info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...") # 成功 state.debug_info["rag_retrieval"] = { @@ -226,6 +231,15 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An except Exception as e: info(f"[rag_retrieve_node] 无法发送完成事件: {e}") + # 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道 + state.reasoning_history.append({ + "step": state.reasoning_step, + "action": "rag_retrieve", + "confidence": 1.0, + "reasoning": "RAG 检索完成", + "timestamp": datetime.now().isoformat() + }) + return result except Exception as e: @@ -255,7 +269,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An # 指数退避等待 delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt) - time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay)) + await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay)) # 所有重试都失败,记录结构化错误 error_record = ErrorRecord( diff --git a/backend/app/main_graph/nodes/react_nodes.py b/backend/app/main_graph/nodes/react_nodes.py index 07da816..7e6f102 100644 --- a/backend/app/main_graph/nodes/react_nodes.py +++ b/backend/app/main_graph/nodes/react_nodes.py @@ -364,20 +364,27 @@ def route_by_reasoning(state: MainGraphState) -> str: if "subgraph_completed" in previous_actions or state.final_result: return "llm_call" - # 检查是否刚刚执行完 rag 或 web search,应该继续推理一次然后去 llm_call - # 但为了避免死循环,我们设置一个简单的规则 - if len(previous_actions) > 3: + # 关键修复:如果已经执行过 rag_retrieve 并且又执行过推理,直接去 LLM_CALL + # 这样的流程:推理1 → RAG → 推理2(判断 RAG 结果) → LLM_CALL + rag_count = previous_actions.count("rag_retrieve") + if rag_count >= 1 and len(previous_actions) >= rag_count + 1: + info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call") return "llm_call" - + + # 关键修复:限制最多 3 次推理,避免无限循环 + if len(previous_actions) >= 3: + info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call") + return "llm_call" + # 获取推理结果 reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result") - + if not reasoning_result: return "llm_call" - + # 使用 intent.py 提供的路由函数 route = get_route_by_reasoning(reasoning_result) - + # 映射到我们的节点名称 # 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致 route_mapping = { @@ -391,7 +398,8 @@ def route_by_reasoning(state: MainGraphState) -> str: "dictionary": "dictionary_subgraph", "news_analysis": "news_analysis_subgraph", } - + + info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}") return route_mapping.get(route, "llm_call") diff --git a/backend/app/main_graph/utils/rag_initializer.py b/backend/app/main_graph/utils/rag_initializer.py index 5ab6833..f83ccca 100644 --- a/backend/app/main_graph/utils/rag_initializer.py +++ b/backend/app/main_graph/utils/rag_initializer.py @@ -1,5 +1,5 @@ # app/rag_initializer.py -from app.rag.tools import create_rag_tool_sync +from app.rag.tools import create_rag_tool_sync, create_rag_tool_async from rag_core import create_parent_retriever from app.model_services import get_embedding_service from app.logger import info, warning @@ -16,11 +16,11 @@ async def init_rag_tool(local_llm_creator): embeddings=embeddings ) rewrite_llm = local_llm_creator() - rag_tool = create_rag_tool_sync( + rag_tool = create_rag_tool_async( retriever, rewrite_llm, num_queries=3, rerank_top_n=5 ) - info("✅ RAG 检索工具初始化成功") + info("✅ RAG 检索工具初始化成功(异步版本)") return rag_tool except Exception as e: warning(f"⚠️ RAG 检索工具初始化失败: {e}") diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 1daec9b..ee1b0d4 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -70,6 +70,63 @@ def create_rag_tool_sync( return search_knowledge_base_sync +def create_rag_tool_async( + retriever: Optional[BaseRetriever] = None, + llm: Optional[BaseLanguageModel] = None, + num_queries: int = 3, + rerank_top_n: int = 5, + collection_name: str = "rag_documents", +) -> Callable: + """ + 创建一个配置好的 RAG 检索工具(异步版本)。 + + 默认使用混合检索(稠密+BM25稀疏)+ 父子文档模式。 + + Args: + retriever: 基础检索器对象(可选,不提供则自动创建) + llm: 用于生成多路查询的语言模型(可选) + num_queries: 生成的查询变体数量 + rerank_top_n: 最终返回的文档数量 + collection_name: Qdrant 集合名称 + + Returns: + Async LangChain Tool 函数 + """ + pipeline = RAGPipeline( + retriever=retriever, + llm=llm, + num_queries=num_queries, + rerank_top_n=rerank_top_n, + collection_name=collection_name, + ) + + @tool + async def search_knowledge_base_async(query: str) -> str: + """ + 在知识库中搜索与查询相关的文档片段(异步版本)。 + + 使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式, + 检索效果最优。 + + Args: + query: 用户提出的问题或查询字符串 + + Returns: + 格式化后的相关文档内容 + """ + try: + documents = await pipeline.aretrieve(query) + if not documents: + return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。" + + context = pipeline.format_context(documents) + return context + except Exception as e: + return f"检索过程中发生错误: {str(e)}" + + return search_knowledge_base_async + + def create_rag_tool( collection_name: str = "rag_documents", llm: Optional[BaseLanguageModel] = None, diff --git a/frontend/src/config.py b/frontend/src/config.py index 2ab205a..8a9f867 100644 --- a/frontend/src/config.py +++ b/frontend/src/config.py @@ -51,7 +51,7 @@ class FrontendConfig: layout: str = "wide" # ==================== 模型配置(固定值,无需环境变量) ==================== - default_model: str = "zhipu" + default_model: str = "local" model_options: Optional[dict] = None # ==================== 用户配置(固定值,无需环境变量) ==================== @@ -73,7 +73,7 @@ class FrontendConfig: if self.model_options is None: self.model_options = { "zhipu": "智谱 GLM-5.1(在线)", - "local": "本地 llama.cpp(Gemma-4)", + "local": "本地 llama.cpp(Qwen3.5-9B)", "deepseek": "DeepSeek V4-Pro(在线)" } diff --git a/tools/test/check_qdrant.py b/tools/test/check_qdrant.py new file mode 100644 index 0000000..42483bf --- /dev/null +++ b/tools/test/check_qdrant.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +""" +检查 Qdrant 集合里的数据结构 +""" + +import asyncio +import os +import sys + +# 添加项目根目录到 Python 路径 +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.join(project_root, "backend")) +sys.path.insert(0, project_root) + +from rag_core import QdrantVectorStore +from app.model_services import get_embedding_service + + +def check_qdrant_data(): + """检查 Qdrant 中的数据结构""" + print("="*70) + print("检查 Qdrant 中的数据结构...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + client = vs.get_qdrant_client() + + # 先获取几个点看看 payload 结构 + print("\n获取 5 个随机文档:") + results = client.scroll( + collection_name="rag_documents", + limit=5, + with_payload=True, + with_vectors=True + ) + + for i, point in enumerate(results[0], 1): + print(f"\n{i}. ID: {point.id}") + print(f" Payload: {point.payload}") + print(f" Payload 键: {list(point.payload.keys())}") + if "text" in point.payload: + text = point.payload["text"] + print(f" Text 长度: {len(text)}") + print(f" Text 预览: {text[:150]}...") + if "page_content" in point.payload: + print(f" page_content: {point.payload['page_content'][:150]}...") + + # 看看向量 + if point.vector: + print(f" 向量存在: {type(point.vector)}") + if isinstance(point.vector, dict): + print(f" 向量键: {list(point.vector.keys())}") + + +def check_sparse_embedder(): + """检查稀疏嵌入器""" + from rag_core import get_sparse_embedder + + print("\n" + "="*70) + print("检查稀疏嵌入器...") + print("="*70) + + sparse_embedder = get_sparse_embedder() + + print(f"\n稀疏嵌入器: {sparse_embedder}") + print(f"Vocabulary 大小: {len(sparse_embedder.model.vocab)}") + print(f"示例查询: '冬天 食物'") + + # 用中文试试 + sparse_vec = sparse_embedder.embed_query("冬天 食物") + print(f"\n生成的稀疏向量:") + print(f" 索引数量: {len(sparse_vec['indices'])}") + print(f" 索引: {sparse_vec['indices'][:10]}") + print(f" 值: {sparse_vec['values'][:10]}") + + +if __name__ == "__main__": + check_qdrant_data() + check_sparse_embedder() diff --git a/tools/test/quick_test.py b/tools/test/quick_test.py new file mode 100644 index 0000000..3214014 --- /dev/null +++ b/tools/test/quick_test.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +""" +简单测试脚本:测试文档里真正有的内容 +""" + +import asyncio +import os +import sys + +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.join(project_root, "backend")) + +from qdrant_client import models +from rag_core import QdrantVectorStore, get_sparse_embedder +from app.model_services import get_embedding_service + + +def test_dense_retrieval(): + """测试稠密检索""" + print("="*70) + print("测试稠密检索...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + + query = "黄双银" # 用文档里真正有的名字查询 + print(f"\n查询: {query}") + + results = vs.similarity_search(query, k=3) + + print(f"\n找到 {len(results)} 个结果\n") + for i, doc in enumerate(results): + print(f"--- 结果 {i+1} ---") + print(doc.page_content[:200]) + print() + + +if __name__ == "__main__": + test_dense_retrieval() diff --git a/tools/test/reset_qdrant.py b/tools/test/reset_qdrant.py new file mode 100644 index 0000000..d08959f --- /dev/null +++ b/tools/test/reset_qdrant.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +""" +删除 Qdrant 集合并重新索引 +""" + +import asyncio +import os +import sys + +# 添加项目根目录到 Python 路径 +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.join(project_root, "backend")) +sys.path.insert(0, project_root) + +from rag_core import QdrantVectorStore +from app.model_services import get_embedding_service + + +async def delete_and_recreate(): + """删除并重新创建集合""" + print("="*70) + print("删除旧集合并重新创建...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + + # 删除旧集合 + try: + vs.delete_collection() + print("✅ 旧集合已删除") + except Exception as e: + print(f"⚠️ 删除集合时出错(可能不存在): {e}") + + # 重新创建 + vs.create_collection() + print("✅ 新集合已创建") + + +if __name__ == "__main__": + asyncio.run(delete_and_recreate()) diff --git a/tools/test/simple_delete.py b/tools/test/simple_delete.py new file mode 100644 index 0000000..7f5d60f --- /dev/null +++ b/tools/test/simple_delete.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +""" +简单删除 Qdrant 集合 +""" + +import sys +import os + +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.join(project_root, "backend")) + +from rag_core.client import create_qdrant_client + + +def delete_collection(): + print("="*70) + print("删除 rag_documents 集合...") + print("="*70) + + client = create_qdrant_client() + + try: + client.delete_collection("rag_documents") + print("✅ 删除成功") + except Exception as e: + print(f"⚠️ 删除失败: {e}") + + +if __name__ == "__main__": + delete_collection() diff --git a/tools/test/simple_test.py b/tools/test/simple_test.py new file mode 100644 index 0000000..24e532e --- /dev/null +++ b/tools/test/simple_test.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +简单测试脚本:检查 Qdrant 内容,测试各种检索方式 +""" + +import asyncio +import os +import sys + +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.join(project_root, "backend")) + +from qdrant_client import models +from rag_core import QdrantVectorStore, get_sparse_embedder +from app.model_services import get_embedding_service + + +def check_qdrant_content(): + """检查 Qdrant 里的内容""" + print("="*70) + print("检查 Qdrant 内容...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + client = vs.get_qdrant_client() + + # 滚动获取前 5 个点 + points, _ = client.scroll( + collection_name="rag_documents", + limit=5, + with_payload=True, + with_vectors=False + ) + + print(f"\n找到 {len(points)} 个文档\n") + for i, point in enumerate(points): + print(f"--- 文档 {i+1} ---") + print(f"ID: {point.id}") + print(f"Payload 键: {list(point.payload.keys())}") + + # 打印完整 payload + for k, v in point.payload.items(): + if isinstance(v, str) and len(v) > 150: + v = v[:150] + "..." + print(f" {k}: {v}") + print() + + +def test_dense_retrieval(): + """测试稠密检索""" + print("="*70) + print("测试稠密检索...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + + query = "蚂蚁" # 用中文查询 + print(f"\n查询: {query}") + + results = vs.similarity_search(query, k=3) + + print(f"\n找到 {len(results)} 个结果\n") + for i, doc in enumerate(results): + print(f"--- 结果 {i+1} ---") + print(doc.page_content[:200]) + print() + + +def test_sparse_retrieval(): + """测试稀疏检索""" + print("="*70) + print("测试稀疏检索(BM25)...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + client = vs.get_qdrant_client() + sparse_embedder = get_sparse_embedder() + + query = "冬天" + print(f"\n查询: {query}") + + sparse_query = sparse_embedder.embed_query(query) + sparse_vec = models.SparseVector( + indices=sparse_query["indices"], + values=sparse_query["values"] + ) + + response = client.query_points( + collection_name="rag_documents", + query=sparse_vec, + using="sparse", + limit=3, + with_payload=True + ) + + print(f"\n找到 {len(response.points)} 个结果\n") + for i, point in enumerate(response.points): + print(f"--- 结果 {i+1} ---") + print(f"分数: {point.score:.4f}") + text = point.payload.get("page_content", point.payload.get("text", "")) + print(text[:200]) + print() + + +def test_hybrid_retrieval(): + """测试混合检索""" + print("="*70) + print("测试混合检索(稠密+稀疏 RRF 融合)...") + print("="*70) + + embeddings = get_embedding_service() + vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings) + client = vs.get_qdrant_client() + sparse_embedder = get_sparse_embedder() + + query = "蚂蚁和蚱蜢" + print(f"\n查询: {query}") + + dense_query = embeddings.embed_query(query) + sparse_query = sparse_embedder.embed_query(query) + sparse_vec = models.SparseVector( + indices=sparse_query["indices"], + values=sparse_query["values"] + ) + + response = client.query_points( + collection_name="rag_documents", + prefetch=[ + models.Prefetch(query=dense_query, using="dense", limit=3), + models.Prefetch(query=sparse_vec, using="sparse", limit=3) + ], + query=models.FusionQuery(fusion=models.Fusion.RRF), + limit=3, + with_payload=True + ) + + print(f"\n找到 {len(response.points)} 个结果\n") + for i, point in enumerate(response.points): + print(f"--- 结果 {i+1} ---") + print(f"分数: {point.score:.4f}") + text = point.payload.get("page_content", point.payload.get("text", "")) + print(text[:200]) + print() + + +if __name__ == "__main__": + check_qdrant_content() + test_dense_retrieval() + test_sparse_retrieval() + test_hybrid_retrieval() diff --git a/tools/test/test_retrievers.py b/tools/test/test_retrievers.py new file mode 100644 index 0000000..f13398d --- /dev/null +++ b/tools/test/test_retrievers.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +""" +测试 app/rag/retriever.py 里的混合检索函数 +""" + +import asyncio +import os +import sys + +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, os.path.join(project_root, "backend")) + +from app.rag.retriever import create_hybrid_retriever, create_parent_hybrid_retriever + + +def test_hybrid_retriever(): + """测试混合检索器""" + print("="*70) + print("测试 HybridRetriever...") + print("="*70) + + retriever = create_hybrid_retriever(collection_name="rag_documents", search_k=3) + results = retriever.invoke("黄双银") + + print(f"\n找到 {len(results)} 个结果\n") + for i, doc in enumerate(results): + print(f"--- 结果 {i+1} ---") + print(doc.page_content[:200]) + print() + + +def test_parent_hybrid_retriever(): + """测试父子混合检索器""" + print("\n" + "="*70) + print("测试 ParentHybridRetriever...") + print("="*70) + + retriever = create_parent_hybrid_retriever( + collection_name="rag_documents", + search_k=3, + use_docstore=False + ) + results = retriever.invoke("黄双银") + + print(f"\n找到 {len(results)} 个结果\n") + for i, doc in enumerate(results): + print(f"--- 结果 {i+1} ---") + print(doc.page_content[:300]) + print() + + +if __name__ == "__main__": + test_hybrid_retriever() + test_parent_hybrid_retriever()