修改rag,实现混合检索
All checks were successful
构建并部署 AI Agent 服务 / deploy (push) Successful in 5m42s

This commit is contained in:
2026-05-04 04:28:32 +08:00
parent d0590240f9
commit 82dde7113e
15 changed files with 536 additions and 65 deletions

View File

@@ -70,7 +70,7 @@ class AIAgentService:
raise RuntimeError("没有可用的模型")
return self
async def process_message(self, message: str, thread_id: str, model: str = "zhipu", user_id: str = "default_user") -> dict:
async def process_message(self, message: str, thread_id: str, model: str = "local", user_id: str = "default_user") -> dict:
"""处理用户消息返回包含回复、token统计和耗时的字典"""
if model not in self.graphs:
# 回退到第一个可用模型
@@ -175,6 +175,8 @@ class AIAgentService:
try:
info(f"📡 开始调用 graph.astream()...")
chunk_count = 0
full_message_content = "" # 收集完整消息内容
async for chunk in graph.astream(
input_state,
config=config,
@@ -184,21 +186,11 @@ class AIAgentService:
):
chunk_count += 1
chunk_type = chunk["type"]
info(f"📦 收到第 {chunk_count} 个chunk, type: {chunk_type}")
processed_event = {}
if chunk_type == "messages":
message_chunk, metadata = chunk["data"]
node_name = metadata.get("langgraph_node", "unknown")
info(f"📨 处理消息chunk, node: {node_name}")
# 详细记录消息内容,看看这些 chunk 到底是什么
if hasattr(message_chunk, "content"):
content_preview = str(message_chunk.content)[:200]
info(f"📄 消息内容预览 ({len(content_preview)} chars): {repr(content_preview)}")
if hasattr(message_chunk, "type"):
info(f"📋 消息类型: {message_chunk.type}")
if hasattr(message_chunk, "tool_calls"):
info(f"🔧 包含工具调用: {message_chunk.tool_calls}")
# 检测节点变化,发送节点开始事件
if node_name != current_node:
@@ -218,8 +210,6 @@ class AIAgentService:
reasoning_token = ""
if hasattr(message_chunk, 'additional_kwargs'):
reasoning_token = message_chunk.additional_kwargs.get("reasoning_content", "")
info(f"💬 消息token: token_content='{repr(token_content[:50])}', reasoning_token='{repr(reasoning_token[:50])}', node_name='{node_name}'")
# 处理思考过程
if reasoning_token:
@@ -228,7 +218,6 @@ class AIAgentService:
"node": node_name,
"reasoning_token": reasoning_token
}
info(f"✅ 生成 reasoning_token 事件: {processed_event}")
# 处理工具调用
elif hasattr(message_chunk, 'tool_calls') and message_chunk.tool_calls:
for tool_call in message_chunk.tool_calls:
@@ -248,7 +237,7 @@ class AIAgentService:
"args": tool_args,
"id": tool_call_id
}
# 处理普通 token
# 处理普通 token - 只收集,不打印单个 token
elif token_content:
processed_event = {
"type": "llm_token",
@@ -256,18 +245,13 @@ class AIAgentService:
"token": token_content,
"reasoning_token": reasoning_token
}
info(f"✅ 生成 llm_token 事件: {processed_event}")
else:
info(f"⚠️ 没有生成任何事件token_content='{repr(token_content)}', reasoning_token='{repr(reasoning_token)}'")
if node_name == "llm_call":
full_message_content += token_content
elif chunk_type == "updates":
info(f"🔄 处理updates chunk")
updates_data = chunk["data"]
serialized_data = self._serialize_value(updates_data)
# 关键修复:不再从 updates 中读取 latest_reasoning避免重复
# 因为我们现在直接通过 custom 事件发送推理结果了
# 检查是否有人工审核请求
if "review_pending" in serialized_data and serialized_data["review_pending"]:
review_id = serialized_data.get("review_id", "")
@@ -302,18 +286,12 @@ class AIAgentService:
}
elif chunk_type == "custom":
info(f"🎯 处理custom chunk, 完整数据: {repr(chunk)}")
custom_data = chunk["data"]
info(f"🎯 custom_data 内容: {repr(custom_data)}")
info(f"🎯 custom_data 类型: {type(custom_data)}")
# 关键修复:处理我们从 react_reason_node 发送的自定义推理事件
# LangGraph 的 adispatch_custom_event 发送的事件格式:
# chunk["data"] 是我们传的第二个参数dict
# 处理我们从 react_reason_node 发送的自定义推理事件
if isinstance(custom_data, dict):
# 检查是否是我们的推理事件
if "action" in custom_data and "reasoning" in custom_data:
info(f"[Agent Service] 收到自定义推理事件: {custom_data}")
yield {
"type": "react_reasoning",
"step": custom_data.get("step", 1),
@@ -339,7 +317,10 @@ class AIAgentService:
if processed_event:
yield processed_event
# 完整消息集合完成后,一次性打印
info(f"✅ graph.astream() 完成,共 {chunk_count} 个chunks")
if full_message_content:
info(f"📄 完整消息内容: {repr(full_message_content)}")
except Exception as e:
error(f"❌ 执行 React 图时出错: {e}")

View File

@@ -12,18 +12,21 @@ class ThreadHistoryService:
def __init__(self, checkpointer):
self.checkpointer = checkpointer
async def get_user_threads(self, user_id: str, limit: int = 50) -> List[Dict[str, Any]]:
async def get_user_threads(self, user_id: str, limit: int = 4) -> List[Dict[str, Any]]:
"""
获取指定用户的所有线程摘要信息
Args:
user_id: 用户 ID
limit: 返回数量限制
limit: 返回数量限制强制最多4条
Returns:
线程列表,每个包含 thread_id, last_updated, summary, message_count
"""
try:
# 强制限制最多4条
actual_limit = min(limit, 4)
# 查询 checkpoints 表获取用户的线程列表
async with self.checkpointer.conn.cursor() as cur:
# 在较新的 LangGraph 版本中AsyncPostgresSaver 创建的 checkpoints 表
@@ -40,7 +43,7 @@ class ThreadHistoryService:
ORDER BY last_updated DESC
LIMIT %s
"""
await cur.execute(query, (user_id, limit))
await cur.execute(query, (user_id, actual_limit))
rows = await cur.fetchall()
threads = []

View File

@@ -98,7 +98,7 @@ async def health_check():
class ChatRequest(BaseModel):
message: str
thread_id: str | None = None
model: str = "zhipu"
model: str = "local"
user_id: str = "default_user"
class ChatResponse(BaseModel):
@@ -212,7 +212,7 @@ async def chat_endpoint(
@app.get("/threads")
async def list_threads(
user_id: str = Query("default_user", description="用户 ID"),
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
limit: int = Query(4, ge=1, le=200, description="返回数量限制"),
history_service: ThreadHistoryService = Depends(get_history_service)
):
"""获取当前用户的对话历史列表"""
@@ -312,7 +312,7 @@ async def websocket_endpoint(
data = await websocket.receive_json()
message = data.get("message")
thread_id = data.get("thread_id", str(uuid.uuid4()))
model = data.get("model", "zhipu")
model = data.get("model", "local")
user_id = data.get("user_id", "default_user")
if not message:
await websocket.send_json({"error": "missing message"})
@@ -435,4 +435,10 @@ if __name__ == "__main__":
import uvicorn
# 使用环境变量或默认端口 8079避免与 llama.cpp 的 8081 端口冲突)
port = int(BACKEND_PORT)
uvicorn.run(app, host="0.0.0.0", port=port)
uvicorn.run(
app,
host="0.0.0.0",
port=port,
log_level="debug",
access_log=True
)

View File

@@ -130,12 +130,16 @@ class ReactIntentReasoner:
retrieved_docs = context.get("retrieved_docs", [])
messages = context.get("messages", [])
# 关键修复 2如果已经有 rag_context 或 web_search_results通过 messages 推断),直接回答
# 检查是否已经执行过 rag_retrieve 或 web_search
if "rag_retrieve" in previous_actions or "web_search" in previous_actions:
# 关键修改:不要在第一次 rag_retrieve 后就直接回答,允许再推理一次
# 让推理逻辑有机会判断 RAG 结果好不好,要不要再检索或转 web search
rag_count = previous_actions.count("rag_retrieve")
web_search_count = previous_actions.count("web_search")
# 只有当 rag 或 web search 已经超过 1 次,或者已经有推理在 rag 之后,才直接回答
if rag_count >= 2 or web_search_count >= 1:
result.action = ReasoningAction.DIRECT_RESPONSE
result.confidence = 0.95
result.reasoning = "已获取信息,直接回答"
result.reasoning = "已获取足够信息,直接回答"
return result
# 策略1尝试使用 LLM 推理

View File

@@ -95,10 +95,10 @@ def inject_rag_tool_to_state(state: MainGraphState, rag_tool: Any) -> MainGraphS
return state
# ========== RAG 检索核心逻辑(真正利用已有代码)==========
def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
# ========== RAG 检索核心逻辑(真正利用已有代码) ==========
async def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
"""
RAG 检索核心逻辑(真正利用 rag/tools.py
RAG 检索核心逻辑(真正利用 rag/tools.py - 异步版本
Args:
state: 主图状态
@@ -119,10 +119,10 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
rag_tool = get_rag_tool_from_state(state)
if rag_tool:
# 使用真正的 RAG 工具(来自 rag/tools.py
# 使用真正的 RAG 工具(来自 rag/tools.py- 异步版本
try:
# 调用 LangChain Tool 的 invoke 方法
rag_context = rag_tool.invoke(retrieval_query)
# 直接 await 异步工具ainvoke 方法
rag_context = await rag_tool.ainvoke(retrieval_query)
state.rag_context = rag_context
state.rag_docs = [
{"source": "rag_retrieval", "content": rag_context}
@@ -134,9 +134,9 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
except Exception as e:
raise RuntimeError(f"RAG 工具调用失败: {str(e)}") from e
elif _GLOBAL_RAG_PIPELINE:
# 使用 RAG Pipeline 直接检索
# 使用 RAG Pipeline 直接检索 - 直接用异步方法
try:
documents = _GLOBAL_RAG_PIPELINE.retrieve(retrieval_query)
documents = await _GLOBAL_RAG_PIPELINE.aretrieve(retrieval_query)
if documents:
rag_context = _GLOBAL_RAG_PIPELINE.format_context(documents)
state.rag_context = rag_context
@@ -158,7 +158,7 @@ def _rag_retrieve_core(state: MainGraphState) -> MainGraphState:
raise RuntimeError("RAG 工具未初始化,请先调用 set_global_rag_tool() 或 set_global_rag_pipeline()")
# ========== RAG 检索节点(带超时和重试)==========
# ========== RAG 检索节点(带超时和重试) ==========
async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, Any]] = None) -> MainGraphState:
"""
RAG 检索节点:带超时和重试,真正利用已有 RAG 代码
@@ -196,8 +196,13 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
for attempt in range(RAG_RETRY_CONFIG.max_retries + 1):
try:
# 执行核心逻辑
result = _rag_retrieve_core(state)
# 执行核心逻辑 - 异步 await
result = await _rag_retrieve_core(state)
info(f"[rag_retrieve_node] RAG 检索成功,获取到上下文长度: {len(result.rag_context)} 字符")
if result.rag_docs:
for i, doc in enumerate(result.rag_docs[:3]): # 只显示前3条
info(f"[rag_retrieve_node] 文档 {i+1}: {doc.get('content', '')[:100]}...")
# 成功
state.debug_info["rag_retrieval"] = {
@@ -226,6 +231,15 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
except Exception as e:
info(f"[rag_retrieve_node] 无法发送完成事件: {e}")
# 关键修复:把 rag_retrieve 加到 reasoning_history 里,让下次推理知道
state.reasoning_history.append({
"step": state.reasoning_step,
"action": "rag_retrieve",
"confidence": 1.0,
"reasoning": "RAG 检索完成",
"timestamp": datetime.now().isoformat()
})
return result
except Exception as e:
@@ -255,7 +269,7 @@ async def rag_retrieve_node(state: MainGraphState, config: Optional[Dict[str, An
# 指数退避等待
delay = RAG_RETRY_CONFIG.base_delay * (2 ** attempt)
time.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
await asyncio.sleep(min(delay, RAG_RETRY_CONFIG.max_delay))
# 所有重试都失败,记录结构化错误
error_record = ErrorRecord(

View File

@@ -364,20 +364,27 @@ def route_by_reasoning(state: MainGraphState) -> str:
if "subgraph_completed" in previous_actions or state.final_result:
return "llm_call"
# 检查是否刚刚执行 rag 或 web search应该继续推理一次然后去 llm_call
# 但为了避免死循环,我们设置一个简单的规则
if len(previous_actions) > 3:
# 关键修复:如果已经执行 rag_retrieve 并且又执行过推理,直接去 LLM_CALL
# 这样的流程推理1 → RAG → 推理2判断 RAG 结果) → LLM_CALL
rag_count = previous_actions.count("rag_retrieve")
if rag_count >= 1 and len(previous_actions) >= rag_count + 1:
info(f"[route_by_reasoning] 已完成 RAG 检索和结果判断,直接去 llm_call")
return "llm_call"
# 关键修复:限制最多 3 次推理,避免无限循环
if len(previous_actions) >= 3:
info(f"[route_by_reasoning] 已达到最大推理次数 ({len(previous_actions)}),直接去 llm_call")
return "llm_call"
# 获取推理结果
reasoning_result: Optional[ReasoningResult] = state.debug_info.get("reasoning_result")
if not reasoning_result:
return "llm_call"
# 使用 intent.py 提供的路由函数
route = get_route_by_reasoning(reasoning_result)
# 映射到我们的节点名称
# 注意:这些名称必须与 main_graph_builder.py 中定义的节点名称一致
route_mapping = {
@@ -391,7 +398,8 @@ def route_by_reasoning(state: MainGraphState) -> str:
"dictionary": "dictionary_subgraph",
"news_analysis": "news_analysis_subgraph",
}
info(f"[route_by_reasoning] 推理结果={reasoning_result.action.name}, 路由={route_mapping.get(route, 'llm_call')}, 历史动作={previous_actions}")
return route_mapping.get(route, "llm_call")

View File

@@ -1,5 +1,5 @@
# app/rag_initializer.py
from app.rag.tools import create_rag_tool_sync
from app.rag.tools import create_rag_tool_sync, create_rag_tool_async
from rag_core import create_parent_retriever
from app.model_services import get_embedding_service
from app.logger import info, warning
@@ -16,11 +16,11 @@ async def init_rag_tool(local_llm_creator):
embeddings=embeddings
)
rewrite_llm = local_llm_creator()
rag_tool = create_rag_tool_sync(
rag_tool = create_rag_tool_async(
retriever, rewrite_llm,
num_queries=3, rerank_top_n=5
)
info("✅ RAG 检索工具初始化成功")
info("✅ RAG 检索工具初始化成功(异步版本)")
return rag_tool
except Exception as e:
warning(f"⚠️ RAG 检索工具初始化失败: {e}")

View File

@@ -70,6 +70,63 @@ def create_rag_tool_sync(
return search_knowledge_base_sync
def create_rag_tool_async(
retriever: Optional[BaseRetriever] = None,
llm: Optional[BaseLanguageModel] = None,
num_queries: int = 3,
rerank_top_n: int = 5,
collection_name: str = "rag_documents",
) -> Callable:
"""
创建一个配置好的 RAG 检索工具(异步版本)。
默认使用混合检索(稠密+BM25稀疏+ 父子文档模式。
Args:
retriever: 基础检索器对象(可选,不提供则自动创建)
llm: 用于生成多路查询的语言模型(可选)
num_queries: 生成的查询变体数量
rerank_top_n: 最终返回的文档数量
collection_name: Qdrant 集合名称
Returns:
Async LangChain Tool 函数
"""
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=num_queries,
rerank_top_n=rerank_top_n,
collection_name=collection_name,
)
@tool
async def search_knowledge_base_async(query: str) -> str:
"""
在知识库中搜索与查询相关的文档片段(异步版本)。
使用混合检索(稠密向量语义 + BM25 关键词)+ 父子文档模式,
检索效果最优。
Args:
query: 用户提出的问题或查询字符串
Returns:
格式化后的相关文档内容
"""
try:
documents = await pipeline.aretrieve(query)
if not documents:
return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
context = pipeline.format_context(documents)
return context
except Exception as e:
return f"检索过程中发生错误: {str(e)}"
return search_knowledge_base_async
def create_rag_tool(
collection_name: str = "rag_documents",
llm: Optional[BaseLanguageModel] = None,

View File

@@ -51,7 +51,7 @@ class FrontendConfig:
layout: str = "wide"
# ==================== 模型配置(固定值,无需环境变量) ====================
default_model: str = "zhipu"
default_model: str = "local"
model_options: Optional[dict] = None
# ==================== 用户配置(固定值,无需环境变量) ====================
@@ -73,7 +73,7 @@ class FrontendConfig:
if self.model_options is None:
self.model_options = {
"zhipu": "智谱 GLM-5.1(在线)",
"local": "本地 llama.cppGemma-4",
"local": "本地 llama.cppQwen3.5-9B",
"deepseek": "DeepSeek V4-Pro在线"
}

View File

@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
检查 Qdrant 集合里的数据结构
"""
import asyncio
import os
import sys
# 添加项目根目录到 Python 路径
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
sys.path.insert(0, project_root)
from rag_core import QdrantVectorStore
from app.model_services import get_embedding_service
def check_qdrant_data():
"""检查 Qdrant 中的数据结构"""
print("="*70)
print("检查 Qdrant 中的数据结构...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
# 先获取几个点看看 payload 结构
print("\n获取 5 个随机文档:")
results = client.scroll(
collection_name="rag_documents",
limit=5,
with_payload=True,
with_vectors=True
)
for i, point in enumerate(results[0], 1):
print(f"\n{i}. ID: {point.id}")
print(f" Payload: {point.payload}")
print(f" Payload 键: {list(point.payload.keys())}")
if "text" in point.payload:
text = point.payload["text"]
print(f" Text 长度: {len(text)}")
print(f" Text 预览: {text[:150]}...")
if "page_content" in point.payload:
print(f" page_content: {point.payload['page_content'][:150]}...")
# 看看向量
if point.vector:
print(f" 向量存在: {type(point.vector)}")
if isinstance(point.vector, dict):
print(f" 向量键: {list(point.vector.keys())}")
def check_sparse_embedder():
"""检查稀疏嵌入器"""
from rag_core import get_sparse_embedder
print("\n" + "="*70)
print("检查稀疏嵌入器...")
print("="*70)
sparse_embedder = get_sparse_embedder()
print(f"\n稀疏嵌入器: {sparse_embedder}")
print(f"Vocabulary 大小: {len(sparse_embedder.model.vocab)}")
print(f"示例查询: '冬天 食物'")
# 用中文试试
sparse_vec = sparse_embedder.embed_query("冬天 食物")
print(f"\n生成的稀疏向量:")
print(f" 索引数量: {len(sparse_vec['indices'])}")
print(f" 索引: {sparse_vec['indices'][:10]}")
print(f" 值: {sparse_vec['values'][:10]}")
if __name__ == "__main__":
check_qdrant_data()
check_sparse_embedder()

40
tools/test/quick_test.py Normal file
View File

@@ -0,0 +1,40 @@
#!/usr/bin/env python3
"""
简单测试脚本:测试文档里真正有的内容
"""
import asyncio
import os
import sys
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from qdrant_client import models
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
def test_dense_retrieval():
"""测试稠密检索"""
print("="*70)
print("测试稠密检索...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
query = "黄双银" # 用文档里真正有的名字查询
print(f"\n查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
if __name__ == "__main__":
test_dense_retrieval()

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env python3
"""
删除 Qdrant 集合并重新索引
"""
import asyncio
import os
import sys
# 添加项目根目录到 Python 路径
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
sys.path.insert(0, project_root)
from rag_core import QdrantVectorStore
from app.model_services import get_embedding_service
async def delete_and_recreate():
"""删除并重新创建集合"""
print("="*70)
print("删除旧集合并重新创建...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
# 删除旧集合
try:
vs.delete_collection()
print("✅ 旧集合已删除")
except Exception as e:
print(f"⚠️ 删除集合时出错(可能不存在): {e}")
# 重新创建
vs.create_collection()
print("✅ 新集合已创建")
if __name__ == "__main__":
asyncio.run(delete_and_recreate())

View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
"""
简单删除 Qdrant 集合
"""
import sys
import os
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from rag_core.client import create_qdrant_client
def delete_collection():
print("="*70)
print("删除 rag_documents 集合...")
print("="*70)
client = create_qdrant_client()
try:
client.delete_collection("rag_documents")
print("✅ 删除成功")
except Exception as e:
print(f"⚠️ 删除失败: {e}")
if __name__ == "__main__":
delete_collection()

153
tools/test/simple_test.py Normal file
View File

@@ -0,0 +1,153 @@
#!/usr/bin/env python3
"""
简单测试脚本:检查 Qdrant 内容,测试各种检索方式
"""
import asyncio
import os
import sys
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from qdrant_client import models
from rag_core import QdrantVectorStore, get_sparse_embedder
from app.model_services import get_embedding_service
def check_qdrant_content():
"""检查 Qdrant 里的内容"""
print("="*70)
print("检查 Qdrant 内容...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
# 滚动获取前 5 个点
points, _ = client.scroll(
collection_name="rag_documents",
limit=5,
with_payload=True,
with_vectors=False
)
print(f"\n找到 {len(points)} 个文档\n")
for i, point in enumerate(points):
print(f"--- 文档 {i+1} ---")
print(f"ID: {point.id}")
print(f"Payload 键: {list(point.payload.keys())}")
# 打印完整 payload
for k, v in point.payload.items():
if isinstance(v, str) and len(v) > 150:
v = v[:150] + "..."
print(f" {k}: {v}")
print()
def test_dense_retrieval():
"""测试稠密检索"""
print("="*70)
print("测试稠密检索...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
query = "蚂蚁" # 用中文查询
print(f"\n查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
def test_sparse_retrieval():
"""测试稀疏检索"""
print("="*70)
print("测试稀疏检索BM25...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
query = "冬天"
print(f"\n查询: {query}")
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
response = client.query_points(
collection_name="rag_documents",
query=sparse_vec,
using="sparse",
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果\n")
for i, point in enumerate(response.points):
print(f"--- 结果 {i+1} ---")
print(f"分数: {point.score:.4f}")
text = point.payload.get("page_content", point.payload.get("text", ""))
print(text[:200])
print()
def test_hybrid_retrieval():
"""测试混合检索"""
print("="*70)
print("测试混合检索(稠密+稀疏 RRF 融合)...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantVectorStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
query = "蚂蚁和蚱蜢"
print(f"\n查询: {query}")
dense_query = embeddings.embed_query(query)
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
response = client.query_points(
collection_name="rag_documents",
prefetch=[
models.Prefetch(query=dense_query, using="dense", limit=3),
models.Prefetch(query=sparse_vec, using="sparse", limit=3)
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果\n")
for i, point in enumerate(response.points):
print(f"--- 结果 {i+1} ---")
print(f"分数: {point.score:.4f}")
text = point.payload.get("page_content", point.payload.get("text", ""))
print(text[:200])
print()
if __name__ == "__main__":
check_qdrant_content()
test_dense_retrieval()
test_sparse_retrieval()
test_hybrid_retrieval()

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
"""
测试 app/rag/retriever.py 里的混合检索函数
"""
import asyncio
import os
import sys
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
sys.path.insert(0, os.path.join(project_root, "backend"))
from app.rag.retriever import create_hybrid_retriever, create_parent_hybrid_retriever
def test_hybrid_retriever():
"""测试混合检索器"""
print("="*70)
print("测试 HybridRetriever...")
print("="*70)
retriever = create_hybrid_retriever(collection_name="rag_documents", search_k=3)
results = retriever.invoke("黄双银")
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
def test_parent_hybrid_retriever():
"""测试父子混合检索器"""
print("\n" + "="*70)
print("测试 ParentHybridRetriever...")
print("="*70)
retriever = create_parent_hybrid_retriever(
collection_name="rag_documents",
search_k=3,
use_docstore=False
)
results = retriever.invoke("黄双银")
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:300])
print()
if __name__ == "__main__":
test_hybrid_retriever()
test_parent_hybrid_retriever()