#!/usr/bin/env python3
"""
RAG system usage example.

Demonstrates how to use the app/rag module for knowledge retrieval.
"""
import os
import sys

from dotenv import load_dotenv

# Load environment variables from a local .env file (if present) before
# anything below reads os.environ.
load_dotenv()

# Add the project root to sys.path so `app.*` and `rag_core` imports resolve
# when this script is executed directly from its own directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

# NOTE(review): these two imports are unused in the visible code — kept in
# case other parts of the project (or interactive use) rely on them.
from langchain_openai import OpenAIEmbeddings
from langchain_community.llms import VLLMOpenAI
def setup_environment():
    """Set default environment variables for the Qdrant connection.

    Uses ``setdefault`` so values already present in the shell or loaded
    from .env always take precedence over these demo defaults.
    """
    # Default Qdrant endpoint (adjust to your deployment).
    os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")

    # NOTE(review): deliberately NOT defaulting QDRANT_API_KEY to a
    # placeholder string — a bogus key would be sent to the server on every
    # request, and the "not set" status below could never be reported.
    # Export QDRANT_API_KEY externally if your instance requires auth.

    print("环境变量已设置")
    print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
    print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else '未设置'}")
def demonstrate_basic_rag():
    """Demonstrate Level 1: plain vector-similarity retrieval."""
    banner = "=" * 60
    print("\n" + banner)
    print("演示: 基础 RAG 检索 (Level 1)")
    print(banner)

    # Wrap the local LlamaCpp model as a LangChain embeddings object.
    from rag_core import LlamaCppEmbedder
    embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    # Pipeline configured for the basic retrieval level.
    from app.rag import RAGPipeline, RAGLevel
    pipeline = RAGPipeline(
        embeddings=embeddings,
        config={
            "collection_name": "rag_documents",  # target collection name
            "rag_level": RAGLevel.BASIC.value,
        },
    )

    query = "吕布"
    print(f"\n查询: {query}")
    try:
        docs = pipeline.retrieve(query)
        print(f"找到 {len(docs)} 个相关文档")
        # Render retrieved documents into a single prompt context string.
        preview = pipeline.format_context(docs)
        print(f"\n上下文预览:\n{preview[:500]}...")
    except Exception as exc:
        print(f"检索失败: {exc}")
        print("请确保 Qdrant 服务正常运行且集合存在")
def demonstrate_hybrid_rag():
    """Demonstrate Level 2: retrieval followed by reranking."""
    banner = "=" * 60
    print("\n" + banner)
    print("演示: 混合 RAG 检索 (Level 2)")
    print(banner)

    # Local LlamaCpp embedding model exposed through the LangChain API.
    from rag_core import LlamaCppEmbedder
    embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    # Pipeline configured for the rerank level, keeping the top 5 results.
    from app.rag import RAGPipeline, RAGLevel
    pipeline = RAGPipeline(
        embeddings=embeddings,
        config={
            "collection_name": "rag_documents",
            "rag_level": RAGLevel.RERANK.value,
            "rerank_top_n": 5,
        },
    )

    query = "吕布"
    print(f"\n查询: {query}")
    try:
        docs = pipeline.retrieve(query)
        print(f"找到 {len(docs)} 个重排序后的文档")
        # Format the reranked documents into a context string.
        preview = pipeline.format_context(docs)
        print(f"\n上下文预览:\n{preview[:500]}...")
    except Exception as exc:
        print(f"检索失败: {exc}")
def demonstrate_rag_fusion():
    """Demonstrate Level 3: RAG-Fusion (multi-query rewriting + fusion)."""
    banner = "=" * 60
    print("\n" + banner)
    print("演示: RAG-Fusion (Level 3)")
    print(banner)

    # Local LlamaCpp embedding model exposed through the LangChain API.
    from rag_core import LlamaCppEmbedder
    embeddings = LlamaCppEmbedder().as_langchain_embeddings()

    # Chat model used to rewrite the query into multiple variants; points at
    # an OpenAI-compatible local server, so no real API key is required.
    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(
        openai_api_base="http://localhost:8000/v1",
        openai_api_key="no-key-needed",
        model="Qwen2.5-7B-Instruct",  # local model name
        temperature=0.3,
        max_tokens=512,
    )

    # Pipeline configured for the fusion level with 3 rewritten queries.
    from app.rag import RAGPipeline, RAGLevel
    pipeline = RAGPipeline(
        embeddings=embeddings,
        llm=llm,
        config={
            "collection_name": "rag_documents",
            "rag_level": RAGLevel.FUSION.value,
            "num_queries": 3,
        },
    )

    query = "吕布"
    print(f"\n查询: {query}")
    try:
        docs = pipeline.retrieve(query)
        print(f"找到 {len(docs)} 个文档 (经过多路查询改写和重排序)")
        # Format the fused result set into a context string.
        preview = pipeline.format_context(docs)
        print(f"\n上下文预览:\n{preview[:500]}...")
    except Exception as exc:
        print(f"检索失败: {exc}")
def demonstrate_agentic_rag():
    """Demonstrate Level 4: tool-based (agentic) knowledge-base search."""
    banner = "=" * 60
    print("\n" + banner)
    print("演示: Agentic RAG (Level 4)")
    print(banner)

    from app.rag import search_knowledge_base_sync

    try:
        # Call the search tool directly, the same way an agent would.
        print("工具调用示例:")
        result = search_knowledge_base_sync("吕布")
        print(f"工具响应预览: {result[:200]}...")
    except Exception as exc:
        print(f"工具调用失败: {exc}")
        # Full traceback helps diagnose missing services during the demo.
        import traceback
        traceback.print_exc()
def main():
    """Entry point: configure the environment and run each demo in order."""
    banner = "=" * 60
    print("RAG 系统演示")
    print(banner)

    # Seed connection defaults before any pipeline is constructed.
    setup_environment()

    # Demos that only need Qdrant plus the local embedding model.
    demonstrate_basic_rag()
    demonstrate_hybrid_rag()

    # These two additionally require a running local LLM service:
    # demonstrate_rag_fusion()
    # demonstrate_agentic_rag()

    print("\n" + banner)
    print("演示完成!")
    print(banner)
    print("\n使用说明:")
    print("1. 确保 Qdrant 服务运行且集合已创建")
    print("2. 已使用本地 LlamaCpp 嵌入模型")
    print("3. 在 Agent 系统中导入并使用 app.rag.tools.search_knowledge_base")
    print("4. 将工具绑定到你的 Agent 模型")


if __name__ == "__main__":
    main()