文件变更
This commit is contained in:
@@ -39,7 +39,7 @@ from .retriever import (
|
||||
create_hybrid_retriever,
|
||||
create_qdrant_client,
|
||||
)
|
||||
from .reranker import CrossEncoderReranker
|
||||
from .reranker import LLaMaCPPReranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
from .fusion import reciprocal_rank_fusion
|
||||
from .pipeline import RAGPipeline
|
||||
@@ -53,7 +53,7 @@ __all__ = [
|
||||
"create_qdrant_client",
|
||||
|
||||
# 重排序器
|
||||
"CrossEncoderReranker",
|
||||
"LLaMaCPPReranker",
|
||||
|
||||
# 查询改写生成器
|
||||
"MultiQueryGenerator",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# rag/pipeline.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
@@ -23,7 +24,6 @@ class RAGPipeline:
|
||||
llm: BaseLanguageModel,
|
||||
num_queries: int = 3,
|
||||
rerank_top_n: int = 5,
|
||||
rerank_model: str = "BAAI/bge-reranker-base",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -41,9 +41,9 @@ class RAGPipeline:
|
||||
# 初始化组件
|
||||
self.query_generator = MultiQueryGenerator(llm=llm, num_queries=num_queries)
|
||||
self.reranker = LLaMaCPPReranker(
|
||||
base_url="http://127.0.0.1:8083",
|
||||
base_url=os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083"),
|
||||
api_key=os.getenv("LLAMACPP_API_KEY", "huang1998"),
|
||||
top_n=rerank_top_n,
|
||||
api_key="huang1998"
|
||||
)
|
||||
|
||||
async def aretrieve(self, query: str) -> List[Document]:
|
||||
@@ -68,9 +68,9 @@ class RAGPipeline:
|
||||
fused_docs = reciprocal_rank_fusion(doc_lists)
|
||||
|
||||
# Step 4: 重排序
|
||||
if self.reranker.model is not None:
|
||||
try:
|
||||
final_docs = self.reranker.compress_documents(fused_docs, query)
|
||||
else:
|
||||
except Exception:
|
||||
# 若重排序器不可用,直接返回融合后的前 N 条
|
||||
final_docs = fused_docs[:self.rerank_top_n]
|
||||
|
||||
|
||||
@@ -2,32 +2,33 @@
|
||||
重排序器模块 (适配版)
|
||||
使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder
|
||||
"""
|
||||
import os
|
||||
import requests
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
|
||||
class LLaMaCPPReranker:
|
||||
"""使用远程 llama.cpp 服务对检索结果重排序。"""
|
||||
|
||||
def __init__(self,
|
||||
base_url: str = "http://127.0.0.1:8083",
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
top_n: int = 5,
|
||||
api_key: str = "huang1998", # 你设置的 LLAMA_ARG_API_KEY
|
||||
timeout: int = 60):
|
||||
"""
|
||||
初始化远程重排序器
|
||||
|
||||
Args:
|
||||
base_url: llama.cpp 服务的地址和端口。
|
||||
base_url: llama.cpp 服务的地址和端口,默认为环境变量 LLAMACPP_RERANKER_URL 或 "http://127.0.0.1:8083"。
|
||||
top_n: 返回前 N 个结果。
|
||||
api_key: 在容器中设置的 API 密钥。
|
||||
api_key: API 密钥,默认为环境变量 LLAMACPP_API_KEY 或 "huang1998"。
|
||||
timeout: 请求超时时间(秒)。
|
||||
"""
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
self.top_n = top_n
|
||||
self.api_key = api_key
|
||||
self.timeout = timeout
|
||||
self.endpoint = f"{self.base_url}/v1/rerank"
|
||||
self.endpoint = f"{self.base_url}/rerank"
|
||||
|
||||
def compress_documents(
|
||||
self, documents: List[Document], query: str
|
||||
|
||||
@@ -4,74 +4,12 @@ RAG 工具模块
|
||||
将检索功能封装为 LangChain Tool,供 Agent 调用。
|
||||
采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。
|
||||
"""
|
||||
|
||||
from typing import Optional, Callable
|
||||
from langchain_core.tools import tool
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from .pipeline import RAGPipeline
|
||||
|
||||
|
||||
def create_rag_tool(
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    num_queries: int = 3,
    rerank_top_n: int = 5,
    collection_name: str = "rag_documents",
) -> Callable:
    """Build an asynchronous knowledge-base search tool backed by a RAGPipeline.

    Args:
        retriever: Base retriever handed to the pipeline.
        llm: Language model handed to the pipeline.
        num_queries: Forwarded to the pipeline as ``num_queries``.
        rerank_top_n: Forwarded to the pipeline as ``rerank_top_n``.
        collection_name: Only interpolated into the "no results" message;
            it is not passed to the pipeline.

    Returns:
        The ``search_knowledge_base`` coroutine wrapped by the ``@tool``
        decorator.
    """
    # One pipeline instance is created here and shared by every tool call
    # through the closure.
    rag_pipeline = RAGPipeline(
        retriever=retriever,
        llm=llm,
        num_queries=num_queries,
        rerank_top_n=rerank_top_n,
    )

    # NOTE(review): the function name and the docstring below are kept
    # verbatim; the @tool decorator presumably exposes both to the agent as
    # the tool's name and description, so they are runtime text.
    @tool
    async def search_knowledge_base(query: str) -> str:
        """在知识库中搜索与查询相关的文档片段。

        该工具会:
        1. 将用户问题改写成多个不同角度的查询
        2. 并行检索每个查询的相关父文档
        3. 使用倒数排名融合(RRF)合并结果
        4. 用 Cross-Encoder 重排序模型精选最相关的片段

        适用于需要精确、全面答案的事实性问题或背景知识查询。

        Args:
            query: 用户提出的问题或查询字符串

        Returns:
            格式化后的相关文档内容,若无结果则返回提示信息。
        """
        try:
            hits = await rag_pipeline.aretrieve(query)
            if hits:
                return rag_pipeline.format_context(hits)
            return f"在知识库 '{collection_name}' 中未找到与 '{query}' 相关的信息。"
        except Exception as exc:
            # Failures are reported back as text instead of being raised,
            # so the caller always receives a string.
            return f"检索过程中发生错误: {exc!s}"

    return search_knowledge_base
|
||||
|
||||
|
||||
def create_rag_tool_sync(
|
||||
retriever: BaseRetriever,
|
||||
llm: BaseLanguageModel,
|
||||
|
||||
Reference in New Issue
Block a user