容器处理
This commit is contained in:
306
test/test_backend.py
Normal file
306
test/test_backend.py
Normal file
@@ -0,0 +1,306 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
完整后端测试 - 验证 Agent 所有功能
|
||||
包括:短期记忆、长期记忆、工具调用、流式对话、历史查询
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录和 backend 目录到 Python 路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
backend_dir = os.path.join(project_root, "backend")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, backend_dir)
|
||||
load_dotenv()
|
||||
|
||||
from backend.app.config import DB_URI
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from backend.app.agent.service import AIAgentService
|
||||
from backend.app.agent.history import ThreadHistoryService
|
||||
from backend.app.logger import info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串
|
||||
|
||||
async def print_section(title):
|
||||
"""打印测试区块标题"""
|
||||
print("\n" + "=" * 70)
|
||||
print(f" {title}")
|
||||
print("=" * 70)
|
||||
|
||||
async def test_short_term_memory(agent_service):
|
||||
"""测试短期记忆(同一 thread_id 继续对话)"""
|
||||
await print_section("测试 1: 短期记忆(Short-term Memory)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_memory"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
# 第一轮对话
|
||||
print("\n[第一轮] 发送消息: '我叫张三,今年28岁'")
|
||||
result1 = await agent_service.process_message(
|
||||
"我叫张三,今年28岁", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result1['reply'][:100]}...")
|
||||
|
||||
# 第二轮对话 - 测试记忆
|
||||
print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'")
|
||||
result2 = await agent_service.process_message(
|
||||
"我叫什么名字?今年多大?", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result2['reply']}")
|
||||
|
||||
# 验证记忆是否存在
|
||||
if "张三" in result2['reply'] or "28" in result2['reply']:
|
||||
print("\n✅ 短期记忆测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ 短期记忆测试失败!")
|
||||
return False
|
||||
|
||||
async def test_tool_calling(agent_service):
|
||||
"""测试工具调用(RAG 搜索)"""
|
||||
await print_section("测试 2: 工具调用(Tool Calling)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_tools"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
# 发送需要 RAG 搜索的问题
|
||||
print("\n发送消息: '请告诉我,打虎英雄是谁?'")
|
||||
result = await agent_service.process_message(
|
||||
"请告诉我,打虎英雄是谁?", thread_id, "local", user_id
|
||||
)
|
||||
print(f"回复: {result['reply'][:200]}...")
|
||||
|
||||
# 检查是否调用了 RAG 工具(回复中会有水浒传相关内容)
|
||||
if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']:
|
||||
print("\n✅ 工具调用测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ 工具调用测试结果不确定,需要手动验证")
|
||||
return None
|
||||
|
||||
async def test_streaming(agent_service):
|
||||
"""测试流式对话"""
|
||||
await print_section("测试 3: 流式对话(Streaming)")
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
user_id = "test_user_stream"
|
||||
|
||||
print(f"\n使用 thread_id: {thread_id[:8]}...")
|
||||
print(f"使用 user_id: {user_id}")
|
||||
|
||||
print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...")
|
||||
print("流式输出: ", end="", flush=True)
|
||||
|
||||
full_reply = ""
|
||||
chunk_count = 0
|
||||
|
||||
try:
|
||||
async for chunk in agent_service.process_message_stream(
|
||||
"用100字介绍一下AI人工智能", thread_id, "local", user_id
|
||||
):
|
||||
chunk_count += 1
|
||||
if chunk.get("type") == "llm_token":
|
||||
token = chunk.get("token", "")
|
||||
print(token, end="", flush=True)
|
||||
full_reply += token
|
||||
elif chunk.get("type") == "state_update":
|
||||
pass # 状态更新不显示
|
||||
|
||||
print(f"\n\n共收到 {chunk_count} 个 chunk")
|
||||
print(f"完整回复长度: {len(full_reply)} 字")
|
||||
|
||||
if chunk_count > 0 and len(full_reply) > 10:
|
||||
print("\n✅ 流式对话测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ 流式对话测试失败!")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 流式对话异常: {e}")
|
||||
return False
|
||||
|
||||
async def test_history_service(agent_service, history_service):
|
||||
"""测试历史查询服务"""
|
||||
await print_section("测试 4: 历史查询服务(History Service)")
|
||||
|
||||
user_id = "test_user_history"
|
||||
|
||||
# 先创建几个对话
|
||||
print(f"\n为 user_id={user_id} 创建测试对话...")
|
||||
|
||||
thread_ids = []
|
||||
for i in range(3):
|
||||
thread_id = str(uuid.uuid4())
|
||||
thread_ids.append(thread_id)
|
||||
|
||||
await agent_service.process_message(
|
||||
f"这是第 {i+1} 个测试对话", thread_id, "local", user_id
|
||||
)
|
||||
print(f" 创建线程 {i+1}: {thread_id[:8]}...")
|
||||
|
||||
# 1. 测试获取用户线程列表
|
||||
print("\n[4.1] 测试获取用户线程列表...")
|
||||
threads = await history_service.get_user_threads(user_id, limit=10)
|
||||
print(f" 找到 {len(threads)} 个线程")
|
||||
|
||||
if len(threads) >= 3:
|
||||
print(" ✅ 线程列表查询通过")
|
||||
else:
|
||||
print(" ⚠️ 线程数量少于预期")
|
||||
|
||||
# 2. 测试获取单个线程的消息历史
|
||||
if thread_ids:
|
||||
test_thread_id = thread_ids[0]
|
||||
print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)")
|
||||
messages = await history_service.get_thread_messages(test_thread_id)
|
||||
print(f" 找到 {len(messages)} 条消息")
|
||||
|
||||
if len(messages) >= 2: # 至少有一问一答
|
||||
print(" ✅ 消息历史查询通过")
|
||||
else:
|
||||
print(" ⚠️ 消息数量少于预期")
|
||||
|
||||
# 3. 测试获取线程摘要
|
||||
print(f"\n[4.3] 测试获取线程摘要...")
|
||||
summary = await history_service.get_thread_summary(test_thread_id)
|
||||
print(f" 摘要: {summary.get('summary', '')[:50]}...")
|
||||
print(f" 消息数: {summary.get('message_count', 0)}")
|
||||
|
||||
if summary.get('message_count', 0) > 0:
|
||||
print(" ✅ 线程摘要查询通过")
|
||||
else:
|
||||
print(" ⚠️ 摘要查询结果不确定")
|
||||
|
||||
return len(threads) >= 3
|
||||
|
||||
async def test_long_term_memory(agent_service):
|
||||
"""测试长期记忆(mem0)"""
|
||||
await print_section("测试 5: 长期记忆(Long-term Memory - mem0)")
|
||||
|
||||
thread_id1 = str(uuid.uuid4())
|
||||
thread_id2 = str(uuid.uuid4()) # 不同的线程
|
||||
user_id = "test_user_longterm"
|
||||
|
||||
print(f"\n使用 user_id: {user_id}")
|
||||
print(f"线程 1: {thread_id1[:8]}...")
|
||||
print(f"线程 2: {thread_id2[:8]}...")
|
||||
|
||||
# 在第一个线程中保存信息
|
||||
print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'")
|
||||
result1 = await agent_service.process_message(
|
||||
"记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id
|
||||
)
|
||||
print(f"回复: {result1['reply'][:100]}...")
|
||||
|
||||
# 等待一下,让 mem0 保存
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 在第二个线程中询问(不同的 thread_id)
|
||||
print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'")
|
||||
result2 = await agent_service.process_message(
|
||||
"我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id
|
||||
)
|
||||
print(f"回复: {result2['reply']}")
|
||||
|
||||
# 验证长期记忆
|
||||
if "小白" in result2['reply'] or "猫" in result2['reply']:
|
||||
print("\n✅ 长期记忆测试通过!")
|
||||
return True
|
||||
else:
|
||||
print("\n⚠️ 长期记忆可能未启用,或需要手动验证")
|
||||
return None
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 后端完整功能测试")
|
||||
print("=" * 70)
|
||||
|
||||
results = {}
|
||||
|
||||
try:
|
||||
# 创建数据库连接和服务
|
||||
print("\n正在初始化数据库连接...")
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
print("✅ 数据库连接成功")
|
||||
|
||||
# 创建服务实例
|
||||
print("\n正在初始化 Agent 服务...")
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
await agent_service.initialize()
|
||||
print("✅ Agent 服务初始化成功")
|
||||
|
||||
history_service = ThreadHistoryService(checkpointer)
|
||||
print("✅ 历史服务初始化成功")
|
||||
|
||||
print(f"\n可用模型: {list(agent_service.graphs.keys())}")
|
||||
|
||||
# 运行测试
|
||||
results["短期记忆"] = await test_short_term_memory(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["工具调用"] = await test_tool_calling(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["流式对话"] = await test_streaming(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["历史查询"] = await test_history_service(agent_service, history_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
results["长期记忆"] = await test_long_term_memory(agent_service)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# 打印总结
|
||||
await print_section("测试总结")
|
||||
print("\n测试结果:")
|
||||
print("-" * 40)
|
||||
|
||||
pass_count = 0
|
||||
fail_count = 0
|
||||
skip_count = 0
|
||||
|
||||
for test_name, result in results.items():
|
||||
if result is True:
|
||||
status = "✅ 通过"
|
||||
pass_count += 1
|
||||
elif result is False:
|
||||
status = "❌ 失败"
|
||||
fail_count += 1
|
||||
else:
|
||||
status = "⚠️ 待验证"
|
||||
skip_count += 1
|
||||
print(f" {test_name:12s}: {status}")
|
||||
|
||||
print("-" * 40)
|
||||
print(f"总计: {len(results)} 个测试")
|
||||
print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}")
|
||||
|
||||
if fail_count == 0:
|
||||
print("\n🎉 所有核心测试通过!")
|
||||
else:
|
||||
print(f"\n⚠️ 有 {fail_count} 个测试失败")
|
||||
|
||||
except Exception as e:
|
||||
error(f"\n❌ 测试运行异常: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
|
||||
return 0 if fail_count == 0 else 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit_code = asyncio.run(main())
|
||||
sys.exit(exit_code)
|
||||
66
test/test_dqrant.py
Normal file
66
test/test_dqrant.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""检查 Qdrant 中存储的向量质量。"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
# 添加项目根目录和 backend 目录到 Python 路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
backend_dir = os.path.join(project_root, "backend")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, backend_dir)
|
||||
load_dotenv()
|
||||
|
||||
from rag_core import LlamaCppEmbedder
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
||||
COLLECTION_NAME = "rag_documents"
|
||||
|
||||
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
||||
embedder = LlamaCppEmbedder()
|
||||
|
||||
# 获取样本
|
||||
points, _ = client.scroll(
|
||||
collection_name=COLLECTION_NAME,
|
||||
limit=1,
|
||||
with_vectors=True,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
if not points:
|
||||
print(f"集合 '{COLLECTION_NAME}' 为空")
|
||||
exit()
|
||||
|
||||
sample = points[0]
|
||||
raw_vec = sample.vector
|
||||
if isinstance(raw_vec, dict):
|
||||
stored_vec = list(raw_vec.values())[0]
|
||||
elif isinstance(raw_vec, list):
|
||||
stored_vec = raw_vec
|
||||
else:
|
||||
stored_vec = []
|
||||
|
||||
stored_payload = sample.payload or {}
|
||||
stored_text = str(stored_payload.get("page_content", ""))[:200]
|
||||
|
||||
print(f"内容预览:\n{stored_text}...\n")
|
||||
print(f"向量维度: {len(stored_vec)}") # type: ignore
|
||||
print(f"前5个值: {stored_vec[:5]}") # type: ignore
|
||||
print(f"是否全零: {all(v == 0.0 for v in stored_vec)}") # type: ignore
|
||||
|
||||
# 重新编码对比
|
||||
if stored_text:
|
||||
new_vec = embedder.embed_query(stored_text)
|
||||
similarity = np.dot(stored_vec, new_vec) / (np.linalg.norm(stored_vec) * np.linalg.norm(new_vec)) # type: ignore
|
||||
print(f"\n重新编码前5个值: {new_vec[:5]}")
|
||||
print(f"余弦相似度: {similarity:.4f}")
|
||||
|
||||
if similarity < 0.8:
|
||||
print("\n⚠️ 相似度过低,建议删除集合并重建索引")
|
||||
else:
|
||||
print("\n✅ 向量一致")
|
||||
else:
|
||||
print("\n⚠️ 样本无文本内容")
|
||||
84
test/test_rag_indexer_result.py
Normal file
84
test/test_rag_indexer_result.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试重构后的 IndexBuilder 和 RAGRetriever
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from rag_indexer.index_builder import IndexBuilder
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
async def test_index_builder():
|
||||
"""测试索引构建功能"""
|
||||
print("测试索引构建功能...")
|
||||
|
||||
# 创建 IndexBuilder 实例
|
||||
builder = IndexBuilder(
|
||||
collection_name="test_collection",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200
|
||||
)
|
||||
|
||||
# 测试文档路径
|
||||
test_file = os.path.join(os.path.dirname(__file__), "..", "data", "user_docs", "a.txt")
|
||||
|
||||
if os.path.exists(test_file):
|
||||
# 构建索引
|
||||
print(f"正在为文件 {test_file} 构建索引...")
|
||||
processed = await builder.build_from_file(test_file)
|
||||
print(f"索引构建完成,处理了 {processed} 个文档")
|
||||
|
||||
# 获取集合信息
|
||||
info = builder.get_collection_info()
|
||||
print(f"集合信息: {info}")
|
||||
else:
|
||||
print(f"测试文件不存在: {test_file}")
|
||||
|
||||
# 测试搜索功能
|
||||
print("\n测试搜索功能...")
|
||||
try:
|
||||
results = builder.search("吕布", k=3)
|
||||
print(f"搜索结果数量: {len(results)}")
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n结果 {i+1}:")
|
||||
print(f"内容: {result.page_content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"搜索测试失败: {e}")
|
||||
|
||||
# 测试带父块上下文的搜索
|
||||
print("\n测试带父块上下文的搜索...")
|
||||
try:
|
||||
results = await builder.search_with_parent_context("吕布", k=3)
|
||||
print(f"搜索结果数量: {len(results)}")
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n结果 {i+1}:")
|
||||
print(f"内容: {result.page_content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"带父块上下文的搜索测试失败: {e}")
|
||||
|
||||
# 测试统一检索接口
|
||||
print("\n测试统一检索接口...")
|
||||
try:
|
||||
# 返回父块
|
||||
results_parent = await builder.retrieve("吕布", return_parent=True)
|
||||
print(f"返回父块的结果数量: {len(results_parent)}")
|
||||
|
||||
# 返回子块
|
||||
results_child = await builder.retrieve("吕布", return_parent=False)
|
||||
print(f"返回子块的结果数量: {len(results_child)}")
|
||||
except Exception as e:
|
||||
print(f"统一检索接口测试失败: {e}")
|
||||
|
||||
# 关闭资源
|
||||
builder.close()
|
||||
print("\n测试完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_index_builder())
|
||||
Reference in New Issue
Block a user