refactor: 重构RAG核心组件,简化代码结构和测试文件
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 6m53s

This commit is contained in:
2026-05-04 17:58:10 +08:00
parent a07e398739
commit 9841f47432
31 changed files with 578 additions and 1496 deletions

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python3
"""
检查 Qdrant 集合里的数据结构
"""
import asyncio
import os
import sys
from backend.rag_core import QdrantHybridStore
from backend.app.model_services import get_embedding_service
def check_qdrant_data():
"""检查 Qdrant 中的数据结构"""
print("="*70)
print("检查 Qdrant 中的数据结构...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
# 先获取几个点看看 payload 结构
print("\n获取 5 个随机文档:")
results = client.scroll(
collection_name="rag_documents",
limit=5,
with_payload=True,
with_vectors=True
)
for i, point in enumerate(results[0], 1):
print(f"\n{i}. ID: {point.id}")
print(f" Payload: {point.payload}")
print(f" Payload 键: {list(point.payload.keys())}")
if "text" in point.payload:
text = point.payload["text"]
print(f" Text 长度: {len(text)}")
print(f" Text 预览: {text[:150]}...")
if "page_content" in point.payload:
print(f" page_content: {point.payload['page_content'][:150]}...")
# 看看向量
if point.vector:
print(f" 向量存在: {type(point.vector)}")
if isinstance(point.vector, dict):
print(f" 向量键: {list(point.vector.keys())}")
def check_sparse_embedder():
"""检查稀疏嵌入器"""
from backend.rag_core import get_sparse_embedder
print("\n" + "="*70)
print("检查稀疏嵌入器...")
print("="*70)
sparse_embedder = get_sparse_embedder()
print(f"\n稀疏嵌入器: {sparse_embedder}")
print(f"Vocabulary 大小: {len(sparse_embedder.model.vocab)}")
print(f"示例查询: '冬天 食物'")
# 用中文试试
sparse_vec = sparse_embedder.embed_query("冬天 食物")
print(f"\n生成的稀疏向量:")
print(f" 索引数量: {len(sparse_vec['indices'])}")
print(f" 索引: {sparse_vec['indices'][:10]}")
print(f" 值: {sparse_vec['values'][:10]}")
if __name__ == "__main__":
check_qdrant_data()
check_sparse_embedder()

View File

@@ -1,37 +0,0 @@
#!/usr/bin/env python3
"""
简单测试脚本:测试文档里真正有的内容
"""
import asyncio
import os
import sys
from qdrant_client import models
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
from backend.app.model_services import get_embedding_service
def test_dense_retrieval():
"""测试稠密检索"""
print("="*70)
print("测试稠密检索...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
query = "黄双银" # 用文档里真正有的名字查询
print(f"\n查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
if __name__ == "__main__":
test_dense_retrieval()

View File

@@ -1,36 +0,0 @@
#!/usr/bin/env python3
"""
删除 Qdrant 集合并重新索引
"""
import asyncio
import os
import sys
from backend.rag_core import QdrantHybridStore
from backend.app.model_services import get_embedding_service
async def delete_and_recreate():
"""删除并重新创建集合"""
print("="*70)
print("删除旧集合并重新创建...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
# 删除旧集合
try:
vs.delete_collection()
print("✅ 旧集合已删除")
except Exception as e:
print(f"⚠️ 删除集合时出错(可能不存在): {e}")
# 重新创建
vs.create_collection()
print("✅ 新集合已创建")
if __name__ == "__main__":
asyncio.run(delete_and_recreate())

View File

@@ -1,27 +0,0 @@
#!/usr/bin/env python3
"""
简单删除 Qdrant 集合
"""
import sys
import os
from backend.rag_core.client import create_qdrant_client
def delete_collection():
print("="*70)
print("删除 rag_documents 集合...")
print("="*70)
client = create_qdrant_client()
try:
client.delete_collection("rag_documents")
print("✅ 删除成功")
except Exception as e:
print(f"⚠️ 删除失败: {e}")
if __name__ == "__main__":
delete_collection()

View File

@@ -1,150 +0,0 @@
#!/usr/bin/env python3
"""
简单测试脚本:检查 Qdrant 内容,测试各种检索方式
"""
import asyncio
import os
import sys
from qdrant_client import models
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
from backend.app.model_services import get_embedding_service
def check_qdrant_content():
"""检查 Qdrant 里的内容"""
print("="*70)
print("检查 Qdrant 内容...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
# 滚动获取前 5 个点
points, _ = client.scroll(
collection_name="rag_documents",
limit=5,
with_payload=True,
with_vectors=False
)
print(f"\n找到 {len(points)} 个文档\n")
for i, point in enumerate(points):
print(f"--- 文档 {i+1} ---")
print(f"ID: {point.id}")
print(f"Payload 键: {list(point.payload.keys())}")
# 打印完整 payload
for k, v in point.payload.items():
if isinstance(v, str) and len(v) > 150:
v = v[:150] + "..."
print(f" {k}: {v}")
print()
def test_dense_retrieval():
"""测试稠密检索"""
print("="*70)
print("测试稠密检索...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
query = "蚂蚁" # 用中文查询
print(f"\n查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
def test_sparse_retrieval():
"""测试稀疏检索"""
print("="*70)
print("测试稀疏检索BM25...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
query = "冬天"
print(f"\n查询: {query}")
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
response = client.query_points(
collection_name="rag_documents",
query=sparse_vec,
using="sparse",
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果\n")
for i, point in enumerate(response.points):
print(f"--- 结果 {i+1} ---")
print(f"分数: {point.score:.4f}")
text = point.payload.get("page_content", point.payload.get("text", ""))
print(text[:200])
print()
def test_hybrid_retrieval():
"""测试混合检索"""
print("="*70)
print("测试混合检索(稠密+稀疏 RRF 融合)...")
print("="*70)
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
query = "蚂蚁和蚱蜢"
print(f"\n查询: {query}")
dense_query = embeddings.embed_query(query)
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
response = client.query_points(
collection_name="rag_documents",
prefetch=[
models.Prefetch(query=dense_query, using="dense", limit=3),
models.Prefetch(query=sparse_vec, using="sparse", limit=3)
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果\n")
for i, point in enumerate(response.points):
print(f"--- 结果 {i+1} ---")
print(f"分数: {point.score:.4f}")
text = point.payload.get("page_content", point.payload.get("text", ""))
print(text[:200])
print()
if __name__ == "__main__":
check_qdrant_content()
test_dense_retrieval()
test_sparse_retrieval()
test_hybrid_retrieval()

View File

@@ -1,303 +0,0 @@
#!/usr/bin/env python3
"""
完整后端测试 - 验证 Agent 所有功能
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
"""
import asyncio
import os
import sys
import uuid
from dotenv import load_dotenv
# 加载环境变量
project_root = os.path.join(os.path.dirname(__file__), "..")
load_dotenv(os.path.join(project_root, ".env"))
from backend.app.config import DB_URI
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from backend.app.agent.agent_service import AIAgentService
from backend.app.agent.history import ThreadHistoryService
from backend.app.logger import info, warning, error
# PostgreSQL 连接字符串
async def print_section(title):
"""打印测试区块标题"""
print("\n" + "=" * 70)
print(f" {title}")
print("=" * 70)
async def test_short_term_memory(agent_service):
"""测试短期记忆(同一 thread_id 继续对话)"""
await print_section("测试 1: 短期记忆Short-term Memory")
thread_id = str(uuid.uuid4())
user_id = "test_user_memory"
print(f"\n使用 thread_id: {thread_id[:8]}...")
print(f"使用 user_id: {user_id}")
# 第一轮对话
print("\n[第一轮] 发送消息: '我叫张三今年28岁'")
result1 = await agent_service.process_message(
"我叫张三今年28岁", thread_id, "local", user_id
)
print(f"回复: {result1['reply'][:100]}...")
# 第二轮对话 - 测试记忆
print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
result2 = await agent_service.process_message(
"我叫什么名字?今年多大?", thread_id, "local", user_id
)
print(f"回复: {result2['reply']}")
# 验证记忆是否存在
if "张三" in result2['reply'] or "28" in result2['reply']:
print("\n✅ 短期记忆测试通过!")
return True
else:
print("\n❌ 短期记忆测试失败!")
return False
async def test_tool_calling(agent_service):
"""测试工具调用RAG 搜索)"""
await print_section("测试 2: 工具调用Tool Calling")
thread_id = str(uuid.uuid4())
user_id = "test_user_tools"
print(f"\n使用 thread_id: {thread_id[:8]}...")
print(f"使用 user_id: {user_id}")
# 发送需要 RAG 搜索的问题
print("\n发送消息: '请告诉我,黄双银在魔王大陆的故事?'")
result = await agent_service.process_message(
"请告诉我,黄双银在魔王大陆的故事?", thread_id, "local", user_id
)
print(f"回复: {result['reply'][:200]}...")
# 检查是否调用了 RAG 工具(回复中会有黄双银相关内容)
if "黄双银" in result['reply']:
print("\n✅ 工具调用测试通过!")
return True
else:
print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
return None
async def test_streaming(agent_service):
"""测试流式对话"""
await print_section("测试 3: 流式对话Streaming")
thread_id = str(uuid.uuid4())
user_id = "test_user_stream"
print(f"\n使用 thread_id: {thread_id[:8]}...")
print(f"使用 user_id: {user_id}")
print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
print("流式输出: ", end="", flush=True)
full_reply = ""
chunk_count = 0
try:
async for chunk in agent_service.process_message_stream(
"用100字介绍一下AI人工智能", thread_id, "local", user_id
):
chunk_count += 1
if chunk.get("type") == "llm_token":
token = chunk.get("token", "")
print(token, end="", flush=True)
full_reply += token
elif chunk.get("type") == "state_update":
pass # 状态更新不显示
print(f"\n\n共收到 {chunk_count} 个 chunk")
print(f"完整回复长度: {len(full_reply)}")
if chunk_count > 0 and len(full_reply) > 10:
print("\n✅ 流式对话测试通过!")
return True
else:
print("\n❌ 流式对话测试失败!")
return False
except Exception as e:
print(f"\n❌ 流式对话异常: {e}")
return False
async def test_history_service(agent_service, history_service):
"""测试历史查询服务"""
await print_section("测试 4: 历史查询服务History Service")
user_id = "test_user_history"
# 先创建几个对话
print(f"\n为 user_id={user_id} 创建测试对话...")
thread_ids = []
for i in range(3):
thread_id = str(uuid.uuid4())
thread_ids.append(thread_id)
await agent_service.process_message(
f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
)
print(f" 创建线程 {i+1}: {thread_id[:8]}...")
# 1. 测试获取用户线程列表
print("\n[4.1] 测试获取用户线程列表...")
threads = await history_service.get_user_threads(user_id, limit=10)
print(f" 找到 {len(threads)} 个线程")
if len(threads) >= 3:
print(" ✅ 线程列表查询通过")
else:
print(" ⚠️ 线程数量少于预期")
# 2. 测试获取单个线程的消息历史
if thread_ids:
test_thread_id = thread_ids[0]
print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
messages = await history_service.get_thread_messages(test_thread_id)
print(f" 找到 {len(messages)} 条消息")
if len(messages) >= 2: # 至少有一问一答
print(" ✅ 消息历史查询通过")
else:
print(" ⚠️ 消息数量少于预期")
# 3. 测试获取线程摘要
print(f"\n[4.3] 测试获取线程摘要...")
summary = await history_service.get_thread_summary(test_thread_id)
print(f" 摘要: {summary.get('summary', '')[:50]}...")
print(f" 消息数: {summary.get('message_count', 0)}")
if summary.get('message_count', 0) > 0:
print(" ✅ 线程摘要查询通过")
else:
print(" ⚠️ 摘要查询结果不确定")
return len(threads) >= 3
async def test_long_term_memory(agent_service):
"""测试长期记忆mem0"""
await print_section("测试 5: 长期记忆Long-term Memory - mem0")
thread_id1 = str(uuid.uuid4())
thread_id2 = str(uuid.uuid4()) # 不同的线程
user_id = "test_user_longterm"
print(f"\n使用 user_id: {user_id}")
print(f"线程 1: {thread_id1[:8]}...")
print(f"线程 2: {thread_id2[:8]}...")
# 在第一个线程中保存信息
print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
result1 = await agent_service.process_message(
"记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
)
print(f"回复: {result1['reply'][:100]}...")
# 等待一下,让 mem0 保存
await asyncio.sleep(1)
# 在第二个线程中询问(不同的 thread_id
print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
result2 = await agent_service.process_message(
"我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
)
print(f"回复: {result2['reply']}")
# 验证长期记忆
if "小白" in result2['reply'] or "" in result2['reply']:
print("\n✅ 长期记忆测试通过!")
return True
else:
print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
return None
async def main():
"""主测试函数"""
print("\n" + "=" * 70)
print(" 后端完整功能测试")
print("=" * 70)
results = {}
try:
# 创建数据库连接和服务
print("\n正在初始化数据库连接...")
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
await checkpointer.setup()
print("✅ 数据库连接成功")
# 创建服务实例
print("\n正在初始化 Agent 服务...")
agent_service = AIAgentService(checkpointer)
await agent_service.initialize()
print("✅ Agent 服务初始化成功")
history_service = ThreadHistoryService(checkpointer)
print("✅ 历史服务初始化成功")
print(f"\n可用模型: {list(agent_service.graphs.keys())}")
# 运行测试
results["短期记忆"] = await test_short_term_memory(agent_service)
await asyncio.sleep(1)
results["工具调用"] = await test_tool_calling(agent_service)
await asyncio.sleep(1)
results["流式对话"] = await test_streaming(agent_service)
await asyncio.sleep(1)
results["历史查询"] = await test_history_service(agent_service, history_service)
await asyncio.sleep(1)
results["长期记忆"] = await test_long_term_memory(agent_service)
await asyncio.sleep(1)
# 打印总结
await print_section("测试总结")
print("\n测试结果:")
print("-" * 40)
pass_count = 0
fail_count = 0
skip_count = 0
for test_name, result in results.items():
if result is True:
status = "✅ 通过"
pass_count += 1
elif result is False:
status = "❌ 失败"
fail_count += 1
else:
status = "⚠️ 待验证"
skip_count += 1
print(f" {test_name:12s}: {status}")
print("-" * 40)
print(f"总计: {len(results)} 个测试")
print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")
if fail_count == 0:
print("\n🎉 所有核心测试通过!")
else:
print(f"\n⚠️ 有 {fail_count} 个测试失败")
except Exception as e:
error(f"\n❌ 测试运行异常: {e}")
import traceback
traceback.print_exc()
return 1
return 0 if fail_count == 0 else 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)

View File

@@ -1,63 +0,0 @@
"""检查 Qdrant 中存储的向量质量。"""
import os
import sys
import numpy as np
from dotenv import load_dotenv
from qdrant_client import QdrantClient
# 加载环境变量
project_root = os.path.join(os.path.dirname(__file__), "..")
load_dotenv(os.path.join(project_root, ".env"))
from backend.rag_core import LlamaCppEmbedder
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COLLECTION_NAME = "rag_documents"
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
embedder = LlamaCppEmbedder()
# 获取样本
points, _ = client.scroll(
collection_name=COLLECTION_NAME,
limit=1,
with_vectors=True,
with_payload=True,
)
if not points:
print(f"集合 '{COLLECTION_NAME}' 为空")
exit()
sample = points[0]
raw_vec = sample.vector
if isinstance(raw_vec, dict):
stored_vec = list(raw_vec.values())[0]
elif isinstance(raw_vec, list):
stored_vec = raw_vec
else:
stored_vec = []
stored_payload = sample.payload or {}
stored_text = str(stored_payload.get("page_content", ""))[:200]
print(f"内容预览:\n{stored_text}...\n")
print(f"向量维度: {len(stored_vec)}") # type: ignore
print(f"前5个值: {stored_vec[:5]}") # type: ignore
print(f"是否全零: {all(v == 0.0 for v in stored_vec)}") # type: ignore
# 重新编码对比
if stored_text:
new_vec = embedder.embed_query(stored_text)
similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec)) # type: ignore
print(f"\n重新编码前5个值: {new_vec[:5]}")
print(f"余弦相似度: {similarity:.4f}")
if similarity < 0.8:
print("\n⚠️ 相似度过低,建议删除集合并重建索引")
else:
print("\n✅ 向量一致")
else:
print("\n⚠️ 样本无文本内容")

View File

@@ -1,60 +0,0 @@
#!/usr/bin/env python3
"""
前端快速测试脚本
验证前端导入是否正常工作
"""
import sys
import os
print("=" * 60)
print("前端导入测试")
print("=" * 60)
# 测试 1: 直接导入前端模块
print("\n[测试 1] 直接导入前端模块...")
try:
from frontend.src.frontend_main import main
print("✅ frontend_main 导入成功")
except Exception as e:
print(f"❌ 导入失败: {e}")
sys.exit(1)
# 测试 2: 导入配置
print("\n[测试 2] 导入配置...")
try:
from frontend.src.config import config
print(f"✅ config 导入成功: page_title={config.page_title}")
except Exception as e:
print(f"❌ 导入失败: {e}")
# 测试 3: 导入状态管理
print("\n[测试 3] 导入状态管理...")
try:
from frontend.src.state import AppState
print("✅ AppState 导入成功")
except Exception as e:
print(f"❌ 导入失败: {e}")
# 测试 4: 导入 API 客户端
print("\n[测试 4] 导入 API 客户端...")
try:
from frontend.src.api_client import api_client
print("✅ api_client 导入成功")
except Exception as e:
print(f"❌ 导入失败: {e}")
# 测试 5: 导入组件
print("\n[测试 5] 导入组件...")
try:
from frontend.src.components.sidebar import render_sidebar
from frontend.src.components.chat_area import render_chat_area
from frontend.src.components.info_panel import render_info_panel
print("✅ 所有组件导入成功")
except Exception as e:
print(f"❌ 导入失败: {e}")
print("\n" + "=" * 60)
print("🎉 所有前端导入测试通过!")
print("=" * 60)
print("\n现在可以使用 ./scripts/start.sh both 启动完整服务")

View File

@@ -1,141 +0,0 @@
#!/usr/bin/env python3
"""
RAG 系统使用示例(重构版)
演示:
1. 使用 IndexBuilder 获取父子块检索器
2. 创建固定流程的 RAGPipeline多路改写 → RRF融合 → 重排序 → 返回父文档)
3. 将流水线封装为 LangChain 工具,供 Agent 调用
"""
import asyncio
import sys
import os
from dotenv import load_dotenv
# 加载环境变量Qdrant URL、PostgreSQL 连接等)
project_root = os.path.join(os.path.dirname(__file__), "..")
load_dotenv(os.path.join(project_root, ".env"))
from pydantic import SecretStr
from langchain_openai import ChatOpenAI
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from backend.app.rag.pipeline import RAGPipeline
from backend.app.rag.tools import create_rag_tool_sync
from backend.rag_core.retriever_factory import create_parent_retriever
def create_llm():
"""创建本地 vLLM 服务 LLM"""
vllm_base_url = os.getenv(
"VLLM_BASE_URL",
"http://127.0.0.1:8081/v1"
)
return ChatOpenAI(
base_url=vllm_base_url,
api_key=SecretStr(os.getenv("LLAMACPP_API_KEY", "token-abc123")),
model="gemma-4-E2B-it",
timeout=60.0, # 请求超时时间(秒)
max_retries=2, # 失败后自动重试次数
streaming=True, # 确保开启流式输出
)
async def demonstrate_full_pipeline():
"""
完整流水线演示:
- 从 IndexBuilder 获取 ParentDocumentRetriever
- 创建 RAGPipeline
- 执行检索并打印结果
"""
print("=" * 60)
print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)")
print("=" * 60)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
if retriever is None:
print("错误:检索器未初始化,请确保索引已构建。")
return
# 3. 创建 LLM 用于查询改写
llm = create_llm()
# 4. 创建 RAGPipeline固定流程
pipeline = RAGPipeline(
retriever=retriever,
llm=llm,
num_queries=3, # 生成 3 个查询变体
rerank_top_n=5, # 最终返回 5 个父文档
)
# 5. 执行检索
query = "打虎英雄是谁?"
print(f"\n查询: {query}")
print("-" * 40)
try:
documents = await pipeline.aretrieve(query)
print(f"返回 {len(documents)} 个父文档\n")
# 打印结果预览
for i, doc in enumerate(documents, 1):
content_preview = doc.page_content.replace("\n", " ")[:150]
source = doc.metadata.get("source", "未知来源")
print(f"{i}. 【来源:{source}")
print(f" {content_preview}...\n")
# 可选:格式化完整上下文
# context = pipeline.format_context(documents)
# print(context)
except Exception as e:
print(f"检索失败: {e}")
import traceback
traceback.print_exc()
async def demonstrate_tool_creation():
"""
演示创建 RAG 工具(供 Agent 使用)
"""
print("\n" + "=" * 60)
print("演示:创建 RAG 工具(供 LangGraph Agent 调用)")
print("=" * 60)
# 1. 获取检索器(同上)
config = IndexBuilderConfig(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM
llm = create_llm()
# 3. 创建工具
rag_tool = create_rag_tool_sync(
retriever=retriever,
llm=llm,
num_queries=3,
rerank_top_n=5,
collection_name="rag_documents",
)
print(f"工具名称: {rag_tool.name}")
print(f"工具描述: {rag_tool.description[:100]}...")
# 4. 模拟 Agent 调用工具
query = "请告诉我 打虎英雄是谁?"
print(f"\n模拟调用: {query}")
print("-" * 40)
result = await rag_tool.ainvoke({"query": query})
print(result[:800] + "..." if len(result) > 800 else result)
async def main():
await demonstrate_full_pipeline()
await demonstrate_tool_creation()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,305 +1,130 @@
#!/usr/bin/env python3
"""
测试重构后的 IndexBuilder 和 RAG 检索
包括:索引构建、稠密检索、稀疏检索、混合检索、父子文档检索
简单的 RAG 检索测试
使用 app/rag/retriever 提供的功能
"""
import asyncio
import os
from rag_indexer.index_builder import IndexBuilder
from rag_indexer.splitters import SplitterType
from backend.rag_core import QdrantHybridStore, get_sparse_embedder
from backend.app.model_services import get_embedding_service
from qdrant_client import models
from backend.app.rag.retriever import (
create_parent_hybrid_retriever,
create_hybrid_retriever
)
from backend.rag_core import QdrantHybridStore
async def test_index_builder():
"""测试索引构建功能"""
print("="*70)
print("1. 测试索引构建功能...")
print("="*70)
# 统一的测试查询列表
TEST_QUERIES = [
"黄双银",
]
async def test_simple_vector_store_search():
"""测试:直接使用 QdrantHybridStore 的 asimilarity_search"""
print("="*80)
print("测试 1: QdrantHybridStore.asimilarity_search")
print("="*80)
# 创建 IndexBuilder 实例
builder = IndexBuilder(
vs = QdrantHybridStore(collection_name="rag_documents")
for query in TEST_QUERIES:
print(f"\n查询: {query}")
print("-" * 60)
docs = await vs.asimilarity_search(query, k=10)
if docs:
print(f"✓ 找到 {len(docs)} 个文档")
for i, doc in enumerate(docs, 1):
print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}")
preview = doc.page_content[:120].strip()
if len(doc.page_content) > 120:
preview += "..."
print(f" 内容: {preview}")
else:
print("✗ 未找到结果")
await vs.close_async_client()
print("\n" + "="*80)
async def test_hybrid_retriever():
"""测试HybridRetriever子文档检索"""
print("\n" + "="*80)
print("测试 2: HybridRetriever (子文档混合检索)")
print("="*80)
retriever = create_hybrid_retriever(
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000,
child_chunk_size=200
search_k=10
)
# 测试文档路径
project_root = os.path.join(os.path.dirname(__file__), "..", "..")
test_file = os.path.join(project_root, "data", "user_docs", "doublestory.txt")
if os.path.exists(test_file):
# 构建索引
print(f"正在为文件 {test_file} 构建索引...")
processed = await builder.build_from_file(test_file)
print(f"索引构建完成,处理了 {processed} 个文档")
for query in TEST_QUERIES:
print(f"\n查询: {query}")
print("-" * 60)
# 获取集合信息
info = builder.get_collection_info()
print(f"集合信息: {info}")
else:
print(f"测试文件不存在: {test_file}")
docs = await retriever.ainvoke(query)
if docs:
print(f"✓ 找到 {len(docs)} 个子文档")
for i, doc in enumerate(docs, 1):
print(f"\n {i}. parent_id: {doc.metadata.get('parent_id', 'none')}")
preview = doc.page_content[:100].strip()
if len(doc.page_content) > 100:
preview += "..."
print(f" 内容: {preview}")
else:
print("✗ 未找到结果")
# 关闭资源
builder.close()
print("\n索引构建测试完成")
return processed
print("\n" + "="*80)
def test_dense_retrieval():
"""测试稠密检索"""
print("\n" + "="*70)
print("2. 测试稠密检索...")
print("="*70)
async def test_parent_hybrid_retriever():
"""测试ParentHybridRetriever父子文档混合检索"""
print("\n" + "="*80)
print("测试 3: ParentHybridRetriever (父子文档混合检索)")
print("="*80)
# 获取嵌入服务
embeddings = get_embedding_service()
# 创建向量存储
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
# 测试查询
query = "The Ant and the Grasshopper"
print(f"查询: {query}")
results = vs.similarity_search(query, k=3)
print(f"\n找到 {len(results)} 个结果:")
for i, doc in enumerate(results, 1):
print(f"\n{i}. (来源: {doc.metadata.get('source', 'unknown')})")
print(f" 元数据: {doc.metadata}")
content = doc.page_content.strip()
if len(content) > 200:
content = content[:200] + "..."
print(f" 内容: {content}")
def test_sparse_retrieval_simple():
"""简单测试稀疏检索"""
print("\n" + "="*70)
print("3. 测试稀疏检索BM25...")
print("="*70)
# 获取嵌入服务和稀疏嵌入器
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
# 测试查询 - 用关键词
query = "winter work food"
print(f"查询关键词: {query}")
# 生成稀疏查询向量
sparse_query = sparse_embedder.embed_query(query)
# 包装成 SparseVector 对象
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 直接查询稀疏向量
response = client.query_points(
retriever = create_parent_hybrid_retriever(
collection_name="rag_documents",
query=sparse_vec,
using="sparse",
limit=3,
with_payload=True
search_k=10
)
print(f"\n找到 {len(response.points)} 个结果:")
for i, point in enumerate(response.points, 1):
print(f"\n{i}. (分数: {point.score:.4f})")
text = point.payload.get("text", "")
metadata = {k: v for k, v in point.payload.items() if k != "text"}
print(f" 元数据: {metadata}")
content = text.strip()
if len(content) > 200:
content = content[:200] + "..."
print(f" 内容: {content}")
def test_hybrid_retrieval_simple():
"""简单测试混合检索(稠密+稀疏 RRF 融合)"""
print("\n" + "="*70)
print("4. 测试混合检索(稠密+稀疏 RRF 融合)...")
print("="*70)
# 获取嵌入服务和稀疏嵌入器
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
# 测试查询
query = "Ant and Grasshopper story"
print(f"查询: {query}")
# 生成双向量
dense_query = embeddings.embed_query(query)
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 使用 Qdrant 的 query_points 做混合检索
response = client.query_points(
collection_name="rag_documents",
prefetch=[
models.Prefetch(
query=dense_query,
using="dense",
limit=3
),
models.Prefetch(
query=sparse_vec,
using="sparse",
limit=3
)
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=3,
with_payload=True
)
print(f"\n找到 {len(response.points)} 个结果:")
for i, point in enumerate(response.points, 1):
print(f"\n{i}. (RRF 融合分数: {point.score:.4f})")
text = point.payload.get("text", "")
metadata = {k: v for k, v in point.payload.items() if k != "text"}
print(f" 元数据: {metadata}")
content = text.strip()
if len(content) > 200:
content = content[:200] + "..."
print(f" 内容: {content}")
def test_parent_child_retrieval_simple():
"""简单测试父子文档检索"""
print("\n" + "="*70)
print("5. 测试父子文档混合检索...")
print("="*70)
# 获取嵌入服务和稀疏嵌入器
embeddings = get_embedding_service()
vs = QdrantHybridStore(collection_name="rag_documents", embeddings=embeddings)
client = vs.get_qdrant_client()
sparse_embedder = get_sparse_embedder()
# 测试查询
query = "The Ant and the Grasshopper story moral"
print(f"查询: {query}")
# 生成双向量
dense_query = embeddings.embed_query(query)
sparse_query = sparse_embedder.embed_query(query)
sparse_vec = models.SparseVector(
indices=sparse_query["indices"],
values=sparse_query["values"]
)
# 先做混合检索找到子文档
response = client.query_points(
collection_name="rag_documents",
prefetch=[
models.Prefetch(
query=dense_query,
using="dense",
limit=5
),
models.Prefetch(
query=sparse_vec,
using="sparse",
limit=5
)
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
limit=5,
with_payload=True
)
# 收集 parent_id
parent_score_map = {}
child_points = {}
for point in response.points:
parent_id = point.payload.get("parent_id", point.id)
score = point.score
if parent_id not in parent_score_map or score > parent_score_map[parent_id]:
parent_score_map[parent_id] = score
child_points[parent_id] = point
parent_ids = list(parent_score_map.keys())
print(f"\n找到 {len(parent_ids)} 个不同的 parent_id:")
# 查找父文档
if parent_ids:
parent_docs = client.retrieve(
collection_name="rag_documents",
ids=parent_ids,
with_payload=True
)
for query in TEST_QUERIES:
print(f"\n查询: {query}")
print("-" * 60)
found_parent_ids = {p.id for p in parent_docs}
docs = await retriever.ainvoke(query)
# 准备结果列表
results = []
for p in parent_docs:
score = parent_score_map[p.id]
results.append((p, score))
# 处理没找到父文档的情况 - 用子文档代替
missing = set(parent_ids) - found_parent_ids
for parent_id in missing:
child_point = child_points[parent_id]
print(f"\n注意: parent_id {parent_id} 未找到,使用子文档代替")
results.append((child_point, parent_score_map[parent_id]))
# 按分数排序
results.sort(key=lambda x: x[1], reverse=True)
# 显示
print(f"\n{len(results)} 个结果(去重后):")
for i, (point, score) in enumerate(results[:3], 1):
print(f"\n{i}. (分数: {score:.4f})")
text = point.payload.get("text", "")
metadata = {k: v for k, v in point.payload.items() if k != "text"}
print(f" 元数据: {metadata}")
content = text.strip()
if len(content) > 400:
content = content[:400] + "..."
print(f" 内容: {content}")
else:
print("\n未找到结果")
if docs:
print(f"✓ 找到 {len(docs)} 个父文档")
for i, doc in enumerate(docs, 1):
print(f"\n {i}. 来源: {doc.metadata.get('source', 'unknown')}")
preview = doc.page_content[:150].strip()
if len(doc.page_content) > 150:
preview += "..."
print(f" 内容:\n {preview}")
else:
print("✗ 未找到结果")
print("\n" + "="*80)
async def main():
"""主测试函数"""
# 1. 先构建索引
await test_index_builder()
print("\n" + "="*80)
print("RAG 检索功能测试")
print("="*80)
# 2. 测试稠密检索
test_dense_retrieval()
# 测试 1: 直接使用 vector store
await test_simple_vector_store_search()
# 3. 测试稀疏检索
test_sparse_retrieval_simple()
# 测试 2: HybridRetriever
await test_hybrid_retriever()
# 4. 测试混合检索
test_hybrid_retrieval_simple()
# 测试 3: ParentHybridRetriever
await test_parent_hybrid_retriever()
# 5. 测试父子文档检索
test_parent_child_retrieval_simple()
print("\n" + "="*70)
print("所有测试完成!")
print("="*70)
print("\n🎉 所有测试完成!")
if __name__ == "__main__":

View File

@@ -0,0 +1,145 @@
#!/usr/bin/env python3
"""
完整的 RAG Pipeline 测试
测试从查询改写 → 检索 → RRF融合 → 重排序 → 格式化输出的整个流程
"""
import asyncio
from backend.app.rag.pipeline import RAGPipeline, create_rag_pipeline
from backend.app.rag.tools import create_rag_tool
async def test_rag_pipeline_direct():
"""测试 1: 直接使用 RAGPipeline默认用小模型做查询改写"""
print("="*80)
print("测试 1: 直接使用 RAGPipeline默认用小模型做查询改写")
print("="*80)
# 创建 pipeline默认用小模型
pipeline = create_rag_pipeline(
collection_name="rag_documents",
num_queries=3,
rerank_top_n=5
)
query = "黄双银的经历"
print(f"\n用户查询: {query}")
print("-" * 80)
# 执行检索
docs = await pipeline.aretrieve(query)
if docs:
print(f"\n✓ 找到 {len(docs)} 个相关文档")
print("-" * 80)
for i, doc in enumerate(docs, 1):
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
print(f" 内容:\n{doc.page_content}")
print("-" * 80)
# 格式化输出
print("\n" + "="*80)
print("格式化后的上下文:")
print("="*80)
formatted_context = pipeline.format_context(docs)
print(formatted_context)
else:
print("\n✗ 未找到相关文档")
print("\n" + "="*80)
async def test_rag_tool():
"""测试 2: 使用 RAG Tool默认用小模型做查询改写"""
print("\n"+"="*80)
print("测试 2: 使用 RAG Tool默认用小模型做查询改写")
print("="*80)
# 创建 tool默认用小模型
rag_tool = create_rag_tool(
collection_name="rag_documents",
num_queries=3,
rerank_top_n=5
)
query = "黄双银的经历"
print(f"\n用户查询: {query}")
print("-" * 80)
# 使用 tool (异步调用 ainvoke)
result = await rag_tool.ainvoke(query)
print("\nTool 返回结果:")
print("="*80)
print(result)
print("="*80)
async def test_custom_pipeline():
"""测试 3: 自定义参数的 RAGPipeline默认用小模型"""
print("\n"+"="*80)
print("测试 3: 自定义参数的 RAGPipeline默认用小模型")
print("="*80)
# 自定义参数(默认用小模型)
pipeline = RAGPipeline(
collection_name="rag_documents",
num_queries=2, # 只生成 2 个查询变体
rerank_top_n=3 # 只返回前 3 个最相关文档
)
query = "黄双银的经历"
print(f"\n用户查询: {query}")
print(f"配置: num_queries=2, rerank_top_n=3")
print("-" * 80)
docs = await pipeline.aretrieve(query)
if docs:
print(f"\n✓ 找到 {len(docs)} 个相关文档")
print("-" * 80)
for i, doc in enumerate(docs, 1):
print(f"\n{i}. 来源: {doc.metadata.get('source', 'unknown')}")
preview = doc.page_content[:200].strip()
if len(doc.page_content) > 200:
preview += "..."
print(f" 内容预览: {preview}")
print("\n" + "="*80)
print("格式化后的上下文:")
print("="*80)
print(pipeline.format_context(docs))
else:
print("\n✗ 未找到相关文档")
print("\n" + "="*80)
async def main():
"""主测试函数"""
print("\n" + "="*80)
print("完整 RAG Pipeline 测试")
print("查询: '黄双银的经历'")
print("="*80)
# 测试 1: 直接使用 pipeline
await test_rag_pipeline_direct()
# 测试 2: 使用 tool
await test_rag_tool()
# 测试 3: 自定义参数
await test_custom_pipeline()
print("\n" + "="*80)
print("🎉 所有 RAG Pipeline 测试完成!")
print("="*80)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,51 +0,0 @@
#!/usr/bin/env python3
"""
测试 app/rag/retriever.py 里的混合检索函数
"""
import asyncio
import os
import sys
from backend.app.rag.retriever import create_hybrid_retriever, create_parent_hybrid_retriever
def test_hybrid_retriever():
"""测试混合检索器"""
print("="*70)
print("测试 HybridRetriever...")
print("="*70)
retriever = create_hybrid_retriever(collection_name="rag_documents", search_k=3)
results = retriever.invoke("黄双银")
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:200])
print()
def test_parent_hybrid_retriever():
"""测试父子混合检索器"""
print("\n" + "="*70)
print("测试 ParentHybridRetriever...")
print("="*70)
retriever = create_parent_hybrid_retriever(
collection_name="rag_documents",
search_k=3,
use_docstore=False
)
results = retriever.invoke("黄双银")
print(f"\n找到 {len(results)} 个结果\n")
for i, doc in enumerate(results):
print(f"--- 结果 {i+1} ---")
print(doc.page_content[:300])
print()
if __name__ == "__main__":
test_hybrid_retriever()
test_parent_hybrid_retriever()