@@ -7,6 +7,10 @@ RAG system usage example
 import sys
 import os
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
 
 # Add the project root directory to the path
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -19,10 +23,13 @@ def setup_environment():
     """Set up environment variables"""
     # Set the Qdrant connection info (adjust to your setup)
     os.environ.setdefault("QDRANT_URL", "http://115.190.121.151:6333")
+    # Set the Qdrant API key (adjust to your setup)
+    os.environ.setdefault("QDRANT_API_KEY", "your-api-key-here")
     # If an API key is required, set QDRANT_API_KEY
 
     print("Environment variables set")
     print(f"QDRANT_URL: {os.environ.get('QDRANT_URL')}")
+    print(f"QDRANT_API_KEY: {'***' if os.environ.get('QDRANT_API_KEY') else 'not set'}")
 
 
 def demonstrate_basic_rag():
@@ -31,37 +38,32 @@ def demonstrate_basic_rag():
     print("Demo: basic RAG retrieval (Level 1)")
     print("="*60)
 
-    # Create the embedding model (local OpenAI-compatible model)
-    embeddings = OpenAIEmbeddings(
-        openai_api_base="http://localhost:8000/v1",  # local vLLM service
-        openai_api_key="no-key-needed",
-        model="text-embedding-ada-002",  # placeholder model name
-    )
+    # Create the embedding model (local LlamaCpp model)
+    from rag_core import LlamaCppEmbedder
+    embedder = LlamaCppEmbedder()
+    embeddings = embedder.as_langchain_embeddings()
 
     # Create the RAG pipeline
-    from app.rag import RAGPipeline, RAGConfig, RAGLevel
-
-    config = RAGConfig(
-        collection_name="documents",  # your collection name
-        rag_level=RAGLevel.BASIC,
-        verbose=True,
-    )
+    from app.rag import RAGPipeline, RAGLevel
 
     pipeline = RAGPipeline(
         embeddings=embeddings,
-        config=config,
+        config={
+            "collection_name": "rag_documents",  # your collection name
+            "rag_level": RAGLevel.BASIC.value,
+        }
     )
 
     # Example query
-    query = "What is the company reimbursement process?"
+    query = "吕布"
     print(f"\nQuery: {query}")
 
     try:
-        result = pipeline.retrieve(query)
-        print(f"Found {len(result.documents)} relevant documents")
+        documents = pipeline.retrieve(query)
+        print(f"Found {len(documents)} relevant documents")
 
         # Format the context
-        context = pipeline.format_context(result.documents)
+        context = pipeline.format_context(documents)
         print(f"\nContext preview:\n{context[:500]}...")
 
     except Exception as e:
@@ -75,34 +77,31 @@ def demonstrate_hybrid_rag():
     print("Demo: hybrid RAG retrieval (Level 2)")
     print("="*60)
 
-    embeddings = OpenAIEmbeddings(
-        openai_api_base="http://localhost:8000/v1",
-        openai_api_key="no-key-needed",
-        model="text-embedding-ada-002",
-    )
+    from rag_core import LlamaCppEmbedder
+    embedder = LlamaCppEmbedder()
+    embeddings = embedder.as_langchain_embeddings()
 
-    from app.rag import RAGPipeline, RAGConfig, RAGLevel
-
-    config = RAGConfig(
-        collection_name="documents",
-        rag_level=RAGLevel.HYBRID,
-        dense_k=10,
-        sparse_k=10,
-        rerank_top_n=5,
-        verbose=True,
-    )
+    from app.rag import RAGPipeline, RAGLevel
 
     pipeline = RAGPipeline(
         embeddings=embeddings,
-        config=config,
+        config={
+            "collection_name": "rag_documents",
+            "rag_level": RAGLevel.RERANK.value,
+            "rerank_top_n": 5,
+        }
     )
 
-    query = "How do I apply for annual leave?"
+    query = "吕布"
    print(f"\nQuery: {query}")
 
     try:
-        result = pipeline.retrieve(query)
-        print(f"Found {len(result.documents)} reranked documents")
+        documents = pipeline.retrieve(query)
+        print(f"Found {len(documents)} reranked documents")
 
         # Format the context
         context = pipeline.format_context(documents)
         print(f"\nContext preview:\n{context[:500]}...")
 
     except Exception as e:
         print(f"Retrieval failed: {e}")
@@ -114,42 +113,42 @@ def demonstrate_rag_fusion():
     print("Demo: RAG-Fusion (Level 3)")
     print("="*60)
 
-    embeddings = OpenAIEmbeddings(
-        openai_api_base="http://localhost:8000/v1",
-        openai_api_key="no-key-needed",
-        model="text-embedding-ada-002",
-    )
+    from rag_core import LlamaCppEmbedder
+    embedder = LlamaCppEmbedder()
+    embeddings = embedder.as_langchain_embeddings()
 
-    # Create the language model for query rewriting
-    llm = VLLMOpenAI(
+    # Create the language model for query rewriting (local OpenAI-compatible model)
+    from langchain_openai import ChatOpenAI
+    llm = ChatOpenAI(
         openai_api_base="http://localhost:8000/v1",
         openai_api_key="no-key-needed",
-        model_name="Qwen2.5-7B-Instruct",  # your local model
+        model="Qwen2.5-7B-Instruct",  # your local model
         temperature=0.3,
         max_tokens=512,
     )
 
-    from app.rag import RAGPipeline, RAGConfig, RAGLevel
-
-    config = RAGConfig(
-        collection_name="documents",
-        rag_level=RAGLevel.FUSION,
-        num_queries=3,
-        verbose=True,
-    )
+    from app.rag import RAGPipeline, RAGLevel
 
     pipeline = RAGPipeline(
         embeddings=embeddings,
         llm=llm,
-        config=config,
+        config={
+            "collection_name": "rag_documents",
+            "rag_level": RAGLevel.FUSION.value,
+            "num_queries": 3,
+        }
     )
 
-    query = "What approvals are required for a project launch?"
+    query = "吕布"
     print(f"\nQuery: {query}")
 
     try:
-        result = pipeline.retrieve(query)
-        print(f"Found {len(result.documents)} documents (after multi-query rewriting and reranking)")
+        documents = pipeline.retrieve(query)
+        print(f"Found {len(documents)} documents (after multi-query rewriting and reranking)")
 
         # Format the context
         context = pipeline.format_context(documents)
         print(f"\nContext preview:\n{context[:500]}...")
 
     except Exception as e:
         print(f"Retrieval failed: {e}")
@@ -161,44 +160,16 @@ def demonstrate_agentic_rag():
     print("Demo: Agentic RAG (Level 4)")
     print("="*60)
 
-    embeddings = OpenAIEmbeddings(
-        openai_api_base="http://localhost:8000/v1",
-        openai_api_key="no-key-needed",
-        model="text-embedding-ada-002",
-    )
-
-    llm = VLLMOpenAI(
-        openai_api_base="http://localhost:8000/v1",
-        openai_api_key="no-key-needed",
-        model_name="Qwen2.5-7B-Instruct",
-        temperature=0.3,
-        max_tokens=512,
-    )
-
-    from app.rag import create_agentic_rag_pipeline
+    from app.rag import search_knowledge_base_sync
 
     try:
-        # Create the Agentic RAG pipeline
-        agentic_rag = create_agentic_rag_pipeline(
-            embeddings=embeddings,
-            agent_llm=llm,
-            config={
-                "collection_name": "documents",
-                "verbose": True,
-            },
-        )
-
-        print("Agentic RAG pipeline created successfully!")
-        print(f"- Bound model: {agentic_rag['llm']}")
-        print(f"- RAG tool: {agentic_rag['tool'].name}")
-
-        # Demonstrate a tool call
-        print("\nTool call example:")
-        response = agentic_rag["tool"].invoke({"query": "What employee benefits are available?"})
+        print("Tool call example:")
+        response = search_knowledge_base_sync("吕布")
         print(f"Tool response preview: {response[:200]}...")
 
     except Exception as e:
-        print(f"Failed to create Agentic RAG: {e}")
+        print(f"Tool call failed: {e}")
         import traceback
         traceback.print_exc()
@@ -211,11 +182,11 @@ def main():
     # Set up the environment
     setup_environment()
 
-    # Demonstrate each level
+    # Demonstrate the basic features
     demonstrate_basic_rag()
     demonstrate_hybrid_rag()
-    demonstrate_rag_fusion()
-    demonstrate_agentic_rag()
+    # demonstrate_rag_fusion()  # requires a local LLM service
+    # demonstrate_agentic_rag()  # requires a local LLM service
 
     print("\n" + "="*60)
     print("Demo complete!")
@@ -223,8 +194,8 @@ def main():
 
     print("\nUsage notes:")
     print("1. Make sure the Qdrant service is running and the collection has been created")
-    print("2. Adjust the embeddings and llm configuration as needed")
-    print("3. Import and use app.rag.tools.search_knowledge_base_tool in your Agent system")
+    print("2. The local LlamaCpp embedding model is now used")
+    print("3. Import and use app.rag.tools.search_knowledge_base in your Agent system")
     print("4. Bind the tool to your Agent model")
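
For reference, the new side of the basic demo above reduces to the following minimal flow. This is only a sketch assembled from the lines in this diff; it assumes rag_core.LlamaCppEmbedder, app.rag.RAGPipeline, and app.rag.RAGLevel behave exactly as used there, with retrieve() returning a list of documents that format_context() accepts.

# Minimal Level 1 retrieval flow, as used in the updated example (sketch)
from rag_core import LlamaCppEmbedder
from app.rag import RAGPipeline, RAGLevel

# Local LlamaCpp embedder exposed as a LangChain-compatible embeddings object
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()

# The pipeline is now configured with a plain dict instead of a RAGConfig object
pipeline = RAGPipeline(
    embeddings=embeddings,
    config={
        "collection_name": "rag_documents",
        "rag_level": RAGLevel.BASIC.value,
    },
)

documents = pipeline.retrieve("吕布")         # retrieve relevant documents
context = pipeline.format_context(documents)  # join them into a context string
print(context[:500])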
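
The usage notes suggest binding the knowledge-base tool to an agent model. A possible sketch, assuming app.rag.tools.search_knowledge_base is a LangChain tool object (as the removed agentic_rag["tool"] usage implies) and that a local OpenAI-compatible chat endpoint is available; the model name and endpoint are illustrative, not confirmed by this diff.

from langchain_openai import ChatOpenAI
from app.rag.tools import search_knowledge_base  # assumed to be a LangChain tool

# Local OpenAI-compatible chat model, mirroring the ChatOpenAI usage in the diff
llm = ChatOpenAI(
    openai_api_base="http://localhost:8000/v1",
    openai_api_key="no-key-needed",
    model="Qwen2.5-7B-Instruct",
)

# Bind the retrieval tool so the model can decide when to call it
agent_llm = llm.bind_tools([search_knowledge_base])

reply = agent_llm.invoke("Who is 吕布?")
print(reply.tool_calls)  # any requested knowledge-base lookups show up here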