From 5e9bbd519fa106ca035294bb9327ff7f91526d9e Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Tue, 21 Apr 2026 20:49:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/rag_core/config.py | 3 +- backend/rag_core/embedders.py | 10 ++++-- rag_indexer/__init__.py | 2 +- rag_indexer/index_builder.py | 1 - backend/app/rag/test.py => test/test_rag.py | 18 ++++------ test/test_rag_indexer_result.py | 37 +-------------------- 6 files changed, 18 insertions(+), 53 deletions(-) rename backend/app/rag/test.py => test/test_rag.py (91%) diff --git a/backend/rag_core/config.py b/backend/rag_core/config.py index 0d73cdc..1a1ccd3 100644 --- a/backend/rag_core/config.py +++ b/backend/rag_core/config.py @@ -6,7 +6,8 @@ RAG Core 配置管理模块 """ import os - +import dotenv +dotenv.load_dotenv() # ========== 辅助函数:类型转换 ========== def _get_str(key: str) -> str | None: diff --git a/backend/rag_core/embedders.py b/backend/rag_core/embedders.py index 81c2267..91908df 100644 --- a/backend/rag_core/embedders.py +++ b/backend/rag_core/embedders.py @@ -21,6 +21,9 @@ class LlamaCppEmbedder: self.base_url = LLAMACPP_EMBEDDING_URL self.api_key = LLAMACPP_API_KEY self.model = model + print(f"初始化 base_url: { self.base_url}") + + def as_langchain_embeddings(self) -> Embeddings: """创建 LangChain 兼容的嵌入实例。""" @@ -41,13 +44,14 @@ class LlamaCppEmbedder: def _call_embedding_api(self, texts: List[str]) -> List[List[float]]: """直接调用 llama.cpp 嵌入 API。""" + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + base = self.base_url.rstrip("/") if not base.endswith("/v1"): base = base + "/v1" - headers = {"Content-Type": "application/json"} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" payload = { "input": texts, diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py index 4a1e2e3..68088c1 100644 --- a/rag_indexer/__init__.py +++ b/rag_indexer/__init__.py @@ -42,7 +42,7 @@ import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) -from rag_core import ( +from backend.rag_core import ( LlamaCppEmbedder, QdrantVectorStore, PostgresDocStore, diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index a348120..666fef7 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -94,7 +94,6 @@ class IndexBuilder: # 初始化向量存储 self.vector_store = QdrantVectorStore( collection_name=config.collection_name, - embeddings=self.embeddings, ) # 根据切分类型初始化相关组件 diff --git a/backend/app/rag/test.py b/test/test_rag.py similarity index 91% rename from backend/app/rag/test.py rename to test/test_rag.py index 5e5d19e..bcdd3ba 100644 --- a/backend/app/rag/test.py +++ b/test/test_rag.py @@ -18,18 +18,14 @@ from dotenv import load_dotenv load_dotenv() # 添加项目根目录到路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) - +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from pydantic import SecretStr +from langchain_openai import ChatOpenAI from rag_indexer.index_builder import IndexBuilderConfig from rag_indexer.splitters import SplitterType -from .pipeline import RAGPipeline -from .tools import create_rag_tool_sync -from pydantic import SecretStr -# 使用本地 LLM(通过 OpenAI 兼容接口) -from langchain_openai import ChatOpenAI -from rag_core.retriever_factory import create_parent_retriever - -load_dotenv() +from backend.app.rag.pipeline import RAGPipeline +from backend.app.rag.tools import create_rag_tool_sync +from backend.rag_core.retriever_factory import create_parent_retriever def create_llm(): """创建本地 vLLM 服务 LLM""" @@ -113,7 +109,7 @@ async def demonstrate_tool_creation(): collection_name="rag_documents", splitter_type=SplitterType.PARENT_CHILD, ) - retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) + retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) # 2. 创建 LLM llm = create_llm() diff --git a/test/test_rag_indexer_result.py b/test/test_rag_indexer_result.py index 70c7105..771a429 100644 --- a/test/test_rag_indexer_result.py +++ b/test/test_rag_indexer_result.py @@ -40,42 +40,7 @@ async def test_index_builder(): print(f"集合信息: {info}") else: print(f"测试文件不存在: {test_file}") - - # 测试搜索功能 - print("\n测试搜索功能...") - try: - results = builder.search("吕布", k=3) - print(f"搜索结果数量: {len(results)}") - for i, result in enumerate(results): - print(f"\n结果 {i+1}:") - print(f"内容: {result.page_content[:100]}...") - except Exception as e: - print(f"搜索测试失败: {e}") - - # 测试带父块上下文的搜索 - print("\n测试带父块上下文的搜索...") - try: - results = await builder.search_with_parent_context("吕布", k=3) - print(f"搜索结果数量: {len(results)}") - for i, result in enumerate(results): - print(f"\n结果 {i+1}:") - print(f"内容: {result.page_content[:100]}...") - except Exception as e: - print(f"带父块上下文的搜索测试失败: {e}") - - # 测试统一检索接口 - print("\n测试统一检索接口...") - try: - # 返回父块 - results_parent = await builder.retrieve("吕布", return_parent=True) - print(f"返回父块的结果数量: {len(results_parent)}") - - # 返回子块 - results_child = await builder.retrieve("吕布", return_parent=False) - print(f"返回子块的结果数量: {len(results_child)}") - except Exception as e: - print(f"统一检索接口测试失败: {e}") - + # 关闭资源 builder.close() print("\n测试完成")