Compare commits
6 Commits
8b354b7ccc
...
efa8bbcd03
| Author | SHA1 | Date | |
|---|---|---|---|
| efa8bbcd03 | |||
| aa8072369c | |||
| 5e9bbd519f | |||
| 37e86f3bb1 | |||
| e2eaac9498 | |||
| 08826c70a3 |
111
.env.docker
111
.env.docker
@@ -1,19 +1,62 @@
|
||||
# =============================================================================
|
||||
# Docker Compose 服务器部署配置
|
||||
# 用法: cp .env.docker .env 然后填入 API Key
|
||||
# Docker Compose 服务器部署配置模板
|
||||
# 用法: cp .env.docker .env 然后填入敏感密钥
|
||||
# =============================================================================
|
||||
|
||||
# ⭐ 敏感密钥配置(必须在此配置)
|
||||
# =============================================================================
|
||||
# AI 模型 API 密钥
|
||||
ZHIPUAI_API_KEY=your_zhipuai_api_key_here
|
||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||
# -----------------------------------------------------------------------------
|
||||
# AI 模型 API 密钥(⭐ 敏感配置 - 必须填入真实值)
|
||||
# -----------------------------------------------------------------------------
|
||||
ZHIPUAI_API_KEY=your_zhipuai_api_key_here # ⭐ 敏感密钥配置
|
||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here # ⭐ 敏感密钥配置
|
||||
LLAMACPP_API_KEY=your_llamacpp_api_key_here # ⭐ 敏感密钥配置
|
||||
|
||||
# llama.cpp 服务认证 Token(与容器启动参数一致)
|
||||
LLAMACPP_API_KEY=token-abc123
|
||||
# -----------------------------------------------------------------------------
|
||||
# PostgreSQL 数据库配置(分离配置,易于管理)
|
||||
# -----------------------------------------------------------------------------
|
||||
DB_HOST=115.190.121.151
|
||||
DB_PORT=5432
|
||||
DB_USER=postgres
|
||||
DB_PASSWORD=your_db_password_here # ⭐ 敏感密钥配置
|
||||
DB_NAME=langgraph_db
|
||||
# 完整连接字符串(也支持直接配置,优先使用分离配置)
|
||||
DB_URI=postgresql://postgres:${DB_PASSWORD}@115.190.121.151:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# ⭐ 日志调试配置(部署时可灵活调整)
|
||||
# =============================================================================
|
||||
# -----------------------------------------------------------------------------
|
||||
# Qdrant 向量数据库配置(URL + API密钥 配对)
|
||||
# -----------------------------------------------------------------------------
|
||||
QDRANT_URL=http://115.190.121.151:6333
|
||||
QDRANT_API_KEY=your_qdrant_api_key_here # ⭐ 敏感密钥配置
|
||||
QDRANT_COLLECTION_NAME=mem0_user_memories
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# llama.cpp 服务配置(URL + API密钥 配对)
|
||||
# -----------------------------------------------------------------------------
|
||||
# 主 LLM 服务 (Gemma-4-E2B GGUF) - 端口 18000 (Docker host 映射)
|
||||
VLLM_BASE_URL=http://host.docker.internal:18000/v1
|
||||
|
||||
# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - 端口 18001
|
||||
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
|
||||
# LLAMACPP_API_KEY=your_llamacpp_api_key_here (已在上面配置)
|
||||
|
||||
# Reranker 服务 (bge-reranker-v2-m3) - 端口 18002
|
||||
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# RAG 索引构建配置(非敏感,可直接使用)
|
||||
# -----------------------------------------------------------------------------
|
||||
RAG_COLLECTION_NAME=rag_documents
|
||||
RAG_CHUNK_SIZE=500
|
||||
RAG_CHUNK_OVERLAP=50
|
||||
RAG_PARENT_CHUNK_SIZE=1000
|
||||
RAG_CHILD_CHUNK_SIZE=200
|
||||
RAG_PARENT_CHUNK_OVERLAP=100
|
||||
RAG_CHILD_CHUNK_OVERLAP=20
|
||||
RAG_STRATEGY=parent-child
|
||||
RAG_STORAGE_TYPE=postgres
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 日志调试配置(部署时可灵活调整)
|
||||
# -----------------------------------------------------------------------------
|
||||
# 日志级别:DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
# 生产环境推荐 WARNING,排查问题时改为 DEBUG
|
||||
LOG_LEVEL=WARNING
|
||||
@@ -28,53 +71,13 @@ DEBUG=false
|
||||
# false: 关闭追踪,减少日志量
|
||||
ENABLE_GRAPH_TRACE=false
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# llama.cpp 服务配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# 主 LLM 服务 (Gemma-4-E2B GGUF) - 端口 8081
|
||||
VLLM_BASE_URL=http://host.docker.internal:18000/v1
|
||||
|
||||
# Embedding 服务 (embeddinggemma-300M GGUF) - 端口 8082
|
||||
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
|
||||
|
||||
# Reranker 服务 (bge-reranker-v2-m3) - 端口 8083
|
||||
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Mem0 记忆层配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# Qdrant 向量数据库(远程服务器上的独立容器)
|
||||
QDRANT_URL=http://115.190.121.151:6333
|
||||
QDRANT_COLLECTION_NAME=mem0_user_memories
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 数据库配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# PostgreSQL 连接字符串(远程服务器上的独立容器)
|
||||
DB_URI=postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 前端配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# Docker Compose 内部网络,使用服务名 'backend'
|
||||
API_URL=http://backend:8083/chat
|
||||
|
||||
# ⭐ 前端通信地址(Docker 内部网络)
|
||||
# 注意:这里只需要域名和端口,不需要 /chat 路径
|
||||
- API_URL=http://backend:8083
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 应用行为配置
|
||||
# -----------------------------------------------------------------------------
|
||||
MEMORY_SUMMARIZE_INTERVAL=10
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# unstructured 库 spaCy 模型配置
|
||||
# 前端配置
|
||||
# -----------------------------------------------------------------------------
|
||||
# 指定文档解析使用的语言: eng (英语) 或 zho (中文)
|
||||
UNSTRUCTURED_LANGUAGE=zho
|
||||
|
||||
# 指定 spaCy 模型名称(需与 UNSTRUCTURED_LANGUAGE 对应)
|
||||
# eng -> en_core_web_sm
|
||||
# zho -> zh_core_web_sm
|
||||
SPACY_MODEL=zh_core_web_sm
|
||||
# Docker Compose 内部网络,使用服务名 'backend'
|
||||
API_URL=http://backend:8079/chat
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -17,6 +17,8 @@
|
||||
!rag_indexer/**
|
||||
!docker/
|
||||
!docker/**
|
||||
!test/
|
||||
!test/**
|
||||
!.gitea/
|
||||
!.gitea/**
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
AI Agent 应用模块
|
||||
"""
|
||||
|
||||
from ..agent import AIAgentService
|
||||
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
from .agent.service import AIAgentService
|
||||
from .graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
|
||||
|
||||
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]
|
||||
|
||||
@@ -1,50 +1,91 @@
|
||||
"""
|
||||
环境变量集中管理模块
|
||||
所有配置项统一定义,避免散落在各个文件中
|
||||
配置分组:相关配置放在一起,URL 和 API Key 配对
|
||||
所有配置直接从环境变量读取,无默认值,避免配置混乱
|
||||
需要类型转换的配置在此处理
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# ========== 辅助函数:类型转换 ==========
|
||||
def _get_str(key: str) -> str | None:
|
||||
"""获取字符串配置"""
|
||||
return os.getenv(key)
|
||||
|
||||
|
||||
def _get_int(key: str) -> int | None:
|
||||
"""获取整数配置,自动转换"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _get_bool(key: str) -> bool | None:
|
||||
"""获取布尔配置,自动转换"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
return None
|
||||
|
||||
|
||||
# ========== 第三方 API 密钥 ==========
|
||||
ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
|
||||
DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY")
|
||||
|
||||
|
||||
# ========== llama.cpp 服务配置(URL + API密钥 配对) ==========
|
||||
# 主 LLM 服务
|
||||
VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
|
||||
LLM_API_KEY = _get_str("LLAMACPP_API_KEY")
|
||||
|
||||
# Embedding 服务 (用于 Mem0 的向量化)
|
||||
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
|
||||
LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
|
||||
|
||||
# Reranker 服务
|
||||
LLAMACPP_RERANKER_URL = _get_str("LLAMACPP_RERANKER_URL")
|
||||
|
||||
|
||||
# ========== Qdrant 向量数据库配置(URL + API密钥 配对) ==========
|
||||
QDRANT_URL = _get_str("QDRANT_URL")
|
||||
QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
|
||||
QDRANT_COLLECTION_NAME = _get_str("QDRANT_COLLECTION_NAME")
|
||||
|
||||
|
||||
# ========== PostgreSQL 数据库配置(分离配置 + 完整URI) ==========
|
||||
# 分离配置(优先使用)
|
||||
DB_HOST = _get_str("DB_HOST")
|
||||
DB_PORT = _get_int("DB_PORT")
|
||||
DB_USER = _get_str("DB_USER")
|
||||
DB_PASSWORD = _get_str("DB_PASSWORD")
|
||||
DB_NAME = _get_str("DB_NAME")
|
||||
|
||||
# 完整连接字符串(直接从环境变量读取)
|
||||
DB_URI = _get_str("DB_URI")
|
||||
|
||||
|
||||
# ========== 后端服务配置 ==========
|
||||
BACKEND_PORT = _get_int("BACKEND_PORT")
|
||||
|
||||
|
||||
# ========== Mem0 记忆层配置 ==========
|
||||
# 记忆提取间隔:每 N 轮对话生成一次摘要
|
||||
MEMORY_SUMMARIZE_INTERVAL = _get_int("MEMORY_SUMMARIZE_INTERVAL")
|
||||
|
||||
|
||||
# ========== Graph 执行追踪配置 ==========
|
||||
# 是否启用 Graph 流转追踪(通过环境变量控制)
|
||||
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
|
||||
ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE")
|
||||
|
||||
# ========== 记忆提取配置 ==========
|
||||
# 记忆提取间隔:每 N 轮对话生成一次摘要
|
||||
MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
|
||||
|
||||
# ========== Mem0 记忆层配置 ==========
|
||||
# Qdrant 向量数据库地址
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
|
||||
|
||||
# ========== llm 配置 ==========
|
||||
# LLM 模型配置
|
||||
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
|
||||
LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
|
||||
|
||||
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
|
||||
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
|
||||
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
|
||||
|
||||
# ========== 后端服务配置 ==========
|
||||
# 数据库连接字符串
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
# 后端服务端口
|
||||
BACKEND_PORT = int(os.getenv("BACKEND_PORT", "8079"))
|
||||
|
||||
# ========== 日志配置 ==========
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
|
||||
|
||||
# ========== Reranker 服务配置 ==========
|
||||
LLAMACPP_RERANKER_URL = os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083")
|
||||
|
||||
# ========== 第三方 API 密钥 ==========
|
||||
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY", "")
|
||||
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
|
||||
LOG_LEVEL = _get_str("LOG_LEVEL")
|
||||
DEBUG = _get_bool("DEBUG")
|
||||
|
||||
82
backend/app/graph/visualize_graph.py
Normal file
82
backend/app/graph/visualize_graph.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LangGraph 图结构可视化脚本
|
||||
快速查看节点和边的连接关系
|
||||
运行方式:python backend/app/graph/visualize_graph.py
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 确定项目根目录(Agent1 目录)
|
||||
# 当前文件位置:backend/app/graph/visualize_graph.py
|
||||
# 向上 4 级到 Agent1
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
BACKEND_DIR = PROJECT_ROOT / "backend"
|
||||
|
||||
# 关键:把 backend 目录加入 sys.path,这样才能找到 rag_core
|
||||
# 注意:这只对直接运行脚本有效,对 -m 方式无效(因为 -m 方式在脚本运行前就导入了)
|
||||
if str(BACKEND_DIR) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_DIR))
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
load_dotenv(PROJECT_ROOT / ".env")
|
||||
|
||||
from app.agent.service import AIAgentService
|
||||
from app.config import DB_URI
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
import asyncio
|
||||
|
||||
|
||||
async def visualize_graph():
|
||||
"""可视化 LangGraph 结构"""
|
||||
print("=" * 80)
|
||||
print(" LangGraph 图结构可视化")
|
||||
print("=" * 80)
|
||||
print(f"项目根目录: {PROJECT_ROOT}")
|
||||
print(f"Backend 目录: {BACKEND_DIR}")
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer:
|
||||
await checkpointer.setup()
|
||||
|
||||
# 创建服务实例
|
||||
print("\n正在初始化 Agent 服务...")
|
||||
agent_service = AIAgentService(checkpointer)
|
||||
await agent_service.initialize()
|
||||
|
||||
for model_name, graph in agent_service.graphs.items():
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f" 模型: {model_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# 获取图结构
|
||||
graph_structure = graph.get_graph()
|
||||
|
||||
# 1. 直接打印节点和边
|
||||
print("\n[1] 节点列表:")
|
||||
print("-" * 80)
|
||||
for node_id, node in graph_structure.nodes.items():
|
||||
print(f" - {node_id}: {node.name}")
|
||||
|
||||
print("\n[2] 边列表:")
|
||||
print("-" * 80)
|
||||
for edge in graph_structure.edges:
|
||||
print(f" {edge.source} --> {edge.target}")
|
||||
|
||||
# 3. ASCII 字符画(需要 grandalf)
|
||||
print("\n[3] ASCII 字符画:")
|
||||
print("-" * 80)
|
||||
try:
|
||||
print(graph_structure.draw_ascii())
|
||||
except Exception as e:
|
||||
print(f"⚠️ ASCII 绘制失败: {e}")
|
||||
|
||||
# 4. Mermaid 源码
|
||||
print("\n[4] Mermaid 源码 (可复制到 https://mermaid.live/):")
|
||||
print("-" * 80)
|
||||
print(graph_structure.draw_mermaid())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(visualize_graph())
|
||||
@@ -7,14 +7,6 @@ import os
|
||||
from .config import LOG_LEVEL, DEBUG
|
||||
import logging
|
||||
from typing import Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 先加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 从环境变量读取日志级别,默认 INFO
|
||||
|
||||
|
||||
# 根据环境变量控制是否显示详细调试信息
|
||||
DEBUG_MODE = DEBUG
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ RAG 检索与生成模块
|
||||
from .retriever import (
|
||||
create_base_retriever,
|
||||
create_hybrid_retriever,
|
||||
create_qdrant_client,
|
||||
)
|
||||
from .reranker import LLaMaCPPReranker
|
||||
from .query_transform import MultiQueryGenerator
|
||||
@@ -50,7 +49,6 @@ __all__ = [
|
||||
# 检索器工厂函数
|
||||
"create_base_retriever",
|
||||
"create_hybrid_retriever",
|
||||
"create_qdrant_client",
|
||||
|
||||
# 重排序器
|
||||
"LLaMaCPPReranker",
|
||||
|
||||
@@ -25,66 +25,25 @@ Qdrant 向量检索器模块
|
||||
>>> docs = retriever.invoke("什么是 RAG?")
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Dict, Any
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from langchain_qdrant import QdrantVectorStore
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
from rag_core import QDRANT_URL, QDRANT_API_KEY
|
||||
from rag_core import QDRANT_URL, QDRANT_API_KEY, LlamaCppEmbedder
|
||||
from rag_core.client import create_qdrant_client as create_core_qdrant_client
|
||||
|
||||
# 模块级常量
|
||||
DEFAULT_SEARCH_K = 20
|
||||
DEFAULT_SCORE_THRESHOLD = 0.3
|
||||
|
||||
|
||||
def create_qdrant_client(
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 30,
|
||||
) -> QdrantClient:
|
||||
"""
|
||||
创建并返回一个配置好的 Qdrant 客户端。
|
||||
|
||||
优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。
|
||||
|
||||
Args:
|
||||
url: Qdrant 服务地址,例如 "http://localhost:6333"。
|
||||
默认从环境变量 QDRANT_URL 读取。
|
||||
api_key: API 密钥(若 Qdrant 启用了认证)。
|
||||
默认从环境变量 QDRANT_API_KEY 读取。
|
||||
timeout: 请求超时时间(秒),默认 30 秒。
|
||||
|
||||
Returns:
|
||||
配置好的 QdrantClient 实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 url 为空且环境变量也未设置。
|
||||
"""
|
||||
effective_url = url or QDRANT_URL
|
||||
if not effective_url:
|
||||
raise ValueError(
|
||||
"Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL"
|
||||
)
|
||||
|
||||
effective_api_key = api_key or QDRANT_API_KEY
|
||||
|
||||
client_kwargs = {
|
||||
"url": effective_url,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if effective_api_key:
|
||||
client_kwargs["api_key"] = effective_api_key
|
||||
|
||||
return QdrantClient(**client_kwargs)
|
||||
|
||||
|
||||
def create_base_retriever(
|
||||
collection_name: str,
|
||||
embeddings: Embeddings,
|
||||
search_kwargs: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[QdrantClient] = None,
|
||||
search_kwargs: Dict[str, Any] | None = None,
|
||||
client: QdrantClient | None = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
创建基础向量检索器(仅稠密向量检索)。
|
||||
@@ -94,7 +53,6 @@ def create_base_retriever(
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称(需预先创建并索引)。
|
||||
embeddings: LangChain 兼容的嵌入模型实例。
|
||||
search_kwargs: 搜索参数,可包含:
|
||||
- k (int): 返回的文档数量,默认 20。
|
||||
- score_threshold (float): 相似度阈值,仅返回高于此分数的文档。
|
||||
@@ -108,6 +66,10 @@ def create_base_retriever(
|
||||
Raises:
|
||||
ValueError: 如果集合不存在或嵌入模型无效。
|
||||
"""
|
||||
# 嵌入模型
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 合并默认搜索参数
|
||||
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
|
||||
if search_kwargs:
|
||||
@@ -115,7 +77,7 @@ def create_base_retriever(
|
||||
|
||||
# 创建或复用 Qdrant 客户端
|
||||
if client is None:
|
||||
client = create_qdrant_client()
|
||||
client = create_core_qdrant_client()
|
||||
|
||||
# 验证集合是否存在(可选,便于提前发现问题)
|
||||
try:
|
||||
@@ -140,11 +102,10 @@ def create_base_retriever(
|
||||
|
||||
def create_hybrid_retriever(
|
||||
collection_name: str,
|
||||
embeddings: Embeddings,
|
||||
dense_k: int = 10,
|
||||
sparse_k: int = 10,
|
||||
score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
|
||||
client: Optional[QdrantClient] = None,
|
||||
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
|
||||
client: QdrantClient | None = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
创建混合检索器(稠密向量 + BM25 稀疏向量)。
|
||||
@@ -157,7 +118,6 @@ def create_hybrid_retriever(
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称。
|
||||
embeddings: 嵌入模型(用于稠密向量)。
|
||||
dense_k: 稠密向量检索返回数量,默认 10。
|
||||
sparse_k: 稀疏向量检索返回数量,默认 10。
|
||||
score_threshold: 相似度阈值,默认 0.3。
|
||||
@@ -177,7 +137,6 @@ def create_hybrid_retriever(
|
||||
# 复用基础检索器创建逻辑,只需调整搜索参数
|
||||
return create_base_retriever(
|
||||
collection_name=collection_name,
|
||||
embeddings=embeddings,
|
||||
search_kwargs=search_kwargs,
|
||||
client=client,
|
||||
)
|
||||
@@ -186,9 +145,8 @@ def create_hybrid_retriever(
|
||||
# 可选:提供异步友好的辅助函数
|
||||
async def acreate_base_retriever(
|
||||
collection_name: str,
|
||||
embeddings: Embeddings,
|
||||
search_kwargs: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[QdrantClient] = None,
|
||||
search_kwargs: Dict[str, Any] | None = None,
|
||||
client: QdrantClient | None = None,
|
||||
) -> BaseRetriever:
|
||||
"""
|
||||
异步创建基础向量检索器(与同步版本功能相同)。
|
||||
@@ -196,4 +154,4 @@ async def acreate_base_retriever(
|
||||
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
|
||||
"""
|
||||
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
|
||||
return create_base_retriever(collection_name, embeddings, search_kwargs, client)
|
||||
return create_base_retriever(collection_name, search_kwargs, client)
|
||||
|
||||
@@ -5,9 +5,17 @@ RAG Core - 公共 RAG 组件包
|
||||
"""
|
||||
|
||||
from .embedders import LlamaCppEmbedder
|
||||
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
|
||||
from .vector_store import QdrantVectorStore
|
||||
from .store import PostgresDocStore, create_docstore
|
||||
from .retriever_factory import create_parent_retriever
|
||||
from .config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
LLAMACPP_EMBEDDING_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
DB_URI,
|
||||
DOCSTORE_URI,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -15,6 +23,10 @@ __all__ = [
|
||||
"QdrantVectorStore",
|
||||
"QDRANT_URL",
|
||||
"QDRANT_API_KEY",
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
"LLAMACPP_API_KEY",
|
||||
"DB_URI",
|
||||
"DOCSTORE_URI",
|
||||
"PostgresDocStore",
|
||||
"create_docstore",
|
||||
"create_parent_retriever",
|
||||
|
||||
@@ -1,27 +1,30 @@
|
||||
# rag_core/client.py
|
||||
import os
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||
from typing import Optional
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
|
||||
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
|
||||
"""
|
||||
创建并返回一个配置好的 Qdrant 客户端。
|
||||
|
||||
def create_qdrant_client(
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: int = 300, # 索引构建需要较长超时
|
||||
) -> QdrantClient:
|
||||
effective_url = url or QDRANT_URL
|
||||
effective_api_key = api_key or QDRANT_API_KEY
|
||||
Args:
|
||||
timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。
|
||||
|
||||
if not effective_url:
|
||||
Returns:
|
||||
配置好的 QdrantClient 实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 QDRANT_URL 未配置。
|
||||
"""
|
||||
if not QDRANT_URL:
|
||||
raise ValueError("Qdrant URL 未配置")
|
||||
|
||||
client_kwargs = {
|
||||
"url": effective_url,
|
||||
"url": QDRANT_URL,
|
||||
"timeout": timeout,
|
||||
}
|
||||
if effective_api_key:
|
||||
client_kwargs["api_key"] = effective_api_key
|
||||
if QDRANT_API_KEY:
|
||||
client_kwargs["api_key"] = QDRANT_API_KEY
|
||||
|
||||
return QdrantClient(**client_kwargs)
|
||||
return QdrantClient(**client_kwargs)
|
||||
|
||||
@@ -1,24 +1,55 @@
|
||||
"""
|
||||
RAG Core 配置管理模块
|
||||
集中管理所有环境变量配置项,避免散落在各个文件中
|
||||
所有配置直接从环境变量读取,无默认值,避免配置混乱
|
||||
需要类型转换的配置在此处理
|
||||
"""
|
||||
|
||||
import os
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# ========== 向量数据库配置 ==========
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
|
||||
# ========== 辅助函数:类型转换 ==========
|
||||
def _get_str(key: str) -> str | None:
|
||||
"""获取字符串配置"""
|
||||
return os.getenv(key)
|
||||
|
||||
# ========== 嵌入服务配置 ==========
|
||||
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
|
||||
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "")
|
||||
|
||||
# ========== 文档存储配置 ==========
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
DOCSTORE_URI = os.getenv("DOCSTORE_URI", DB_URI)
|
||||
def _get_int(key: str) -> int | None:
|
||||
"""获取整数配置,自动转换"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
# ========== 向量数据库配置(URL + API密钥 配对) ==========
|
||||
QDRANT_URL = _get_str("QDRANT_URL")
|
||||
QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
|
||||
|
||||
|
||||
# ========== 嵌入服务配置(URL + API密钥 配对) ==========
|
||||
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
|
||||
LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
|
||||
|
||||
|
||||
# ========== 文档存储配置(分离配置 + 完整URI) ==========
|
||||
# 分离配置(优先使用)
|
||||
DB_HOST = _get_str("DB_HOST")
|
||||
DB_PORT = _get_int("DB_PORT")
|
||||
DB_USER = _get_str("DB_USER")
|
||||
DB_PASSWORD = _get_str("DB_PASSWORD")
|
||||
DB_NAME = _get_str("DB_NAME")
|
||||
|
||||
# 完整连接字符串(直接从环境变量读取)
|
||||
DB_URI = _get_str("DB_URI")
|
||||
|
||||
# 文档存储 URI(直接从环境变量读取,默认同 DB_URI)
|
||||
DOCSTORE_URI = _get_str("DOCSTORE_URI") or DB_URI
|
||||
|
||||
|
||||
# ========== 其他配置 ==========
|
||||
# 可以在此添加其他 RAG Core 专用的配置项
|
||||
# 可以在此添加其他 RAG Core 专用的配置项
|
||||
|
||||
@@ -5,22 +5,25 @@
|
||||
import os
|
||||
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
|
||||
import httpx
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
|
||||
class LlamaCppEmbedder:
|
||||
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "Qwen3-Embedding-0.6B-Q8_0",
|
||||
):
|
||||
self.base_url = base_url or LLAMACPP_EMBEDDING_URL
|
||||
self.api_key = api_key or LLAMACPP_API_KEY
|
||||
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
|
||||
"""
|
||||
Args:
|
||||
model: 嵌入模型名称,默认 "Qwen3-Embedding-0.6B-Q8_0"。
|
||||
"""
|
||||
self.base_url = LLAMACPP_EMBEDDING_URL
|
||||
self.api_key = LLAMACPP_API_KEY
|
||||
self.model = model
|
||||
print(f"初始化 base_url: { self.base_url}")
|
||||
|
||||
|
||||
|
||||
def as_langchain_embeddings(self) -> Embeddings:
|
||||
"""创建 LangChain 兼容的嵌入实例。"""
|
||||
@@ -30,7 +33,7 @@ class LlamaCppEmbedder:
|
||||
"""嵌入一批文档。"""
|
||||
return self._call_embedding_api(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
def embed_query(self, text: str) -> List[List[float]]:
|
||||
"""嵌入单个查询。"""
|
||||
return self._call_embedding_api([text])[0]
|
||||
|
||||
@@ -41,13 +44,14 @@ class LlamaCppEmbedder:
|
||||
|
||||
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
|
||||
"""直接调用 llama.cpp 嵌入 API。"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
base = self.base_url.rstrip("/")
|
||||
if not base.endswith("/v1"):
|
||||
base = base + "/v1"
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
payload = {
|
||||
"input": texts,
|
||||
@@ -70,6 +74,7 @@ class LlamaCppEmbedder:
|
||||
else:
|
||||
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
|
||||
|
||||
|
||||
class _LlamaCppLangchainAdapter(Embeddings):
|
||||
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
|
||||
|
||||
@@ -79,5 +84,5 @@ class _LlamaCppLangchainAdapter(Embeddings):
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self._embedder.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embedder.embed_query(text)
|
||||
def embed_query(self, text: str) -> List[List[float]]:
|
||||
return self._embedder.embed_query(text)
|
||||
|
||||
@@ -1,38 +1,46 @@
|
||||
# rag_core/retriever_factory.py
|
||||
# rag_core/retriever_factory.py
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from typing import Optional
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.stores import BaseStore
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
|
||||
from langchain_classic.retrievers import ParentDocumentRetriever
|
||||
from langchain_core.stores import BaseStore
|
||||
|
||||
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
|
||||
|
||||
|
||||
def create_parent_retriever(
|
||||
collection_name: str = "rag_documents",
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
parent_splitter: Optional[TextSplitter] = None,
|
||||
child_splitter: Optional[TextSplitter] = None,
|
||||
docstore: Optional[BaseStore] = None,
|
||||
parent_splitter: TextSplitter | None = None,
|
||||
child_splitter: TextSplitter | None = None,
|
||||
docstore: BaseStore | None = None,
|
||||
search_k: int = 5,
|
||||
# 若未传入切分器,则用以下参数创建默认切分器
|
||||
parent_chunk_size: int = 1000,
|
||||
parent_chunk_overlap: int = 100,
|
||||
child_chunk_size: int = 200,
|
||||
child_chunk_overlap: int = 20,
|
||||
) -> ParentDocumentRetriever:
|
||||
"""
|
||||
创建 ParentDocumentRetriever 实例。
|
||||
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称,默认 "rag_documents"
|
||||
parent_splitter: 父文档切分器,默认 None(使用默认参数创建)
|
||||
child_splitter: 子文档切分器,默认 None(使用默认参数创建)
|
||||
docstore: 文档存储实例,默认 None(使用默认参数创建)
|
||||
search_k: 检索时返回的结果数,默认 5
|
||||
parent_chunk_size: 父文档块大小,默认 1000
|
||||
parent_chunk_overlap: 父文档块重叠大小,默认 100
|
||||
child_chunk_size: 子文档块大小,默认 200
|
||||
child_chunk_overlap: 子文档块重叠大小,默认 20
|
||||
|
||||
Returns:
|
||||
ParentDocumentRetriever 实例
|
||||
"""
|
||||
# 嵌入模型
|
||||
if embeddings is None:
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
embedder = LlamaCppEmbedder()
|
||||
embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
# 向量存储(只读)
|
||||
vector_store = QdrantVectorStore(
|
||||
collection_name=collection_name,
|
||||
embeddings=embeddings,
|
||||
)
|
||||
vector_store = QdrantVectorStore(collection_name=collection_name)
|
||||
|
||||
# 切分器(若未提供则创建默认)
|
||||
if parent_splitter is None:
|
||||
@@ -48,7 +56,7 @@ def create_parent_retriever(
|
||||
|
||||
# 文档存储
|
||||
if docstore is None:
|
||||
docstore, _ = create_docstore() # 从环境变量读取连接
|
||||
docstore, _ = create_docstore()
|
||||
|
||||
return ParentDocumentRetriever(
|
||||
vectorstore=vector_store.get_langchain_vectorstore(),
|
||||
@@ -56,4 +64,4 @@ def create_parent_retriever(
|
||||
child_splitter=child_splitter,
|
||||
parent_splitter=parent_splitter,
|
||||
search_kwargs={"k": search_k},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -9,14 +9,13 @@
|
||||
|
||||
>>> # 创建 PostgreSQL 存储
|
||||
>>> store, conn = create_docstore(
|
||||
... connection_string="postgresql://user:pass@host:5432/db",
|
||||
... table_name="parent_docs"
|
||||
... )
|
||||
"""
|
||||
|
||||
|
||||
from .postgres import PostgresDocStore
|
||||
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
|
||||
from .factory import create_docstore, get_docstore_uri
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
||||
@@ -27,5 +26,4 @@ __all__ = [
|
||||
# 工厂函数
|
||||
"create_docstore",
|
||||
"get_docstore_uri",
|
||||
"DEFAULT_DB_URI",
|
||||
]
|
||||
|
||||
@@ -5,17 +5,14 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
from ..config import DB_URI, DOCSTORE_URI
|
||||
from ..config import DOCSTORE_URI
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from langchain_core.stores import BaseStore
|
||||
from .postgres import PostgresDocStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认连接字符串(从环境变量读取)
|
||||
DEFAULT_DB_URI = DB_URI
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_docstore_uri() -> str:
|
||||
@@ -24,48 +21,36 @@ def get_docstore_uri() -> str:
|
||||
|
||||
|
||||
def create_docstore(
|
||||
store_type: str = "postgres",
|
||||
connection_string: Optional[str] = None,
|
||||
table_name: str = "parent_documents",
|
||||
pool_config: Optional[dict] = None,
|
||||
max_concurrency: Optional[int] = None
|
||||
) -> Tuple[BaseStore, Optional[str]]:
|
||||
pool_config: dict | None = None,
|
||||
max_concurrency: int | None = None
|
||||
) -> Tuple[BaseStore, str]:
|
||||
"""
|
||||
工厂函数,创建 PostgreSQL 文档存储。
|
||||
|
||||
|
||||
Args:
|
||||
store_type: 存储类型,目前仅支持 "postgres"(默认)
|
||||
connection_string: PostgreSQL 连接字符串
|
||||
table_name: PostgreSQL 表名(默认:parent_documents)
|
||||
pool_config: 连接池配置
|
||||
max_concurrency: 最大并发操作数,如果为 None 则不限制
|
||||
|
||||
|
||||
Returns:
|
||||
元组 (存储实例, 连接字符串)
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 不支持的存储类型
|
||||
ImportError: 缺少必要的依赖
|
||||
|
||||
|
||||
Example:
|
||||
>>> # 创建 PostgreSQL 存储
|
||||
>>> store, conn = create_docstore(
|
||||
... connection_string="postgresql://user:pass@host:5432/db",
|
||||
... table_name="parent_docs",
|
||||
... max_concurrency=10
|
||||
... )
|
||||
"""
|
||||
store_type = store_type.lower()
|
||||
|
||||
if store_type == "postgres":
|
||||
conn_str = connection_string or get_docstore_uri()
|
||||
store = PostgresDocStore(
|
||||
connection_string=conn_str,
|
||||
table_name=table_name,
|
||||
pool_config=pool_config,
|
||||
max_concurrency=max_concurrency
|
||||
)
|
||||
return store, conn_str
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
|
||||
conn_str = get_docstore_uri()
|
||||
store = PostgresDocStore(
|
||||
connection_string=conn_str,
|
||||
table_name=table_name,
|
||||
pool_config=pool_config,
|
||||
max_concurrency=max_concurrency
|
||||
)
|
||||
return store, conn_str
|
||||
|
||||
@@ -4,7 +4,6 @@ Qdrant 向量数据库包装器。
|
||||
|
||||
import logging
|
||||
import os
|
||||
from .config import QDRANT_URL, QDRANT_API_KEY
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
@@ -14,31 +13,28 @@ from qdrant_client import QdrantClient
|
||||
from qdrant_client.http.models import Distance, VectorParams
|
||||
from httpx import RemoteProtocolError
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
|
||||
from .client import create_qdrant_client
|
||||
from .embedders import LlamaCppEmbedder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class QdrantVectorStore:
|
||||
"""Qdrant 向量数据库操作包装器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: Optional[Any] = None,
|
||||
):
|
||||
def __init__(self, collection_name: str):
|
||||
"""
|
||||
Args:
|
||||
collection_name: Qdrant 集合名称。
|
||||
"""
|
||||
self.collection_name = collection_name
|
||||
self._client: Optional[QdrantClient] = None
|
||||
self._connection_attempts = 0
|
||||
self._last_connection_time: Optional[float] = None
|
||||
|
||||
if embeddings is None:
|
||||
from rag_core.embedders import LlamaCppEmbedder
|
||||
embedder = LlamaCppEmbedder()
|
||||
self.embeddings = embedder.as_langchain_embeddings()
|
||||
else:
|
||||
self.embeddings = embeddings
|
||||
embedder = LlamaCppEmbedder()
|
||||
self.embeddings = embedder.as_langchain_embeddings()
|
||||
|
||||
self.create_collection()
|
||||
|
||||
@@ -92,12 +88,10 @@ class QdrantVectorStore:
|
||||
"client_initialized": self._client is not None,
|
||||
}
|
||||
|
||||
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
|
||||
def create_collection(self, force_recreate: bool = False):
|
||||
"""创建集合,设置合适的向量维度。"""
|
||||
if vector_size is None:
|
||||
from rag_core.embedders import LlamaCppEmbedder
|
||||
embedder = LlamaCppEmbedder()
|
||||
vector_size = embedder.get_embedding_dimension()
|
||||
embedder = LlamaCppEmbedder()
|
||||
vector_size = embedder.get_embedding_dimension()
|
||||
|
||||
max_retries = 3
|
||||
base_delay = 2
|
||||
@@ -177,4 +171,4 @@ class QdrantVectorStore:
|
||||
|
||||
def get_qdrant_client(self):
|
||||
"""返回原生 Qdrant 客户端(如需手动管理 collection)"""
|
||||
return self.get_client()
|
||||
return self.get_client()
|
||||
|
||||
@@ -3,39 +3,69 @@ FROM python:3.11-slim
|
||||
WORKDIR /app
|
||||
|
||||
# =============================================================================
|
||||
# 非敏感环境变量(固化在镜像中,无需通过 .env 配置)
|
||||
# 非敏感环境变量(固化在镜像中,可通过 .env 覆盖)
|
||||
# =============================================================================
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# llama.cpp 服务配置(本地部署标准端口)
|
||||
# =============================================================================
|
||||
# llama.cpp 服务配置(Docker 部署标准端口映射)
|
||||
# =============================================================================
|
||||
# 主 LLM 服务 - Docker host 端口 18000
|
||||
ENV VLLM_BASE_URL=http://host.docker.internal:18000/v1
|
||||
# Embedding 服务 - Docker host 端口 18001
|
||||
ENV LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
|
||||
ENV LLAMACPP_RERENT_URL=http://host.docker.internal:18002/v1
|
||||
# Reranker 服务 - Docker host 端口 18002
|
||||
ENV LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
|
||||
|
||||
# Mem0 记忆层配置
|
||||
# =============================================================================
|
||||
# 数据库 & 向量库配置(非敏感部分)
|
||||
# =============================================================================
|
||||
# PostgreSQL(敏感信息通过 .env 注入)
|
||||
ENV DB_HOST=115.190.121.151
|
||||
ENV DB_PORT=5432
|
||||
ENV DB_USER=postgres
|
||||
ENV DB_NAME=langgraph_db
|
||||
|
||||
# Qdrant(敏感信息通过 .env 注入)
|
||||
ENV QDRANT_URL=http://115.190.121.151:6333
|
||||
ENV QDRANT_COLLECTION_NAME=mem0_user_memories
|
||||
|
||||
# 应用行为配置(可通过 .env 覆盖)
|
||||
# =============================================================================
|
||||
# RAG 索引构建配置(非敏感)
|
||||
# =============================================================================
|
||||
ENV RAG_COLLECTION_NAME=rag_documents
|
||||
ENV RAG_CHUNK_SIZE=500
|
||||
ENV RAG_CHUNK_OVERLAP=50
|
||||
ENV RAG_PARENT_CHUNK_SIZE=1000
|
||||
ENV RAG_CHILD_CHUNK_SIZE=200
|
||||
ENV RAG_PARENT_CHUNK_OVERLAP=100
|
||||
ENV RAG_CHILD_CHUNK_OVERLAP=20
|
||||
ENV RAG_STRATEGY=parent-child
|
||||
ENV RAG_STORAGE_TYPE=postgres
|
||||
|
||||
# =============================================================================
|
||||
# 应用行为配置
|
||||
# =============================================================================
|
||||
ENV MEMORY_SUMMARIZE_INTERVAL=10
|
||||
ENV ENABLE_GRAPH_TRACE=false
|
||||
|
||||
# unstructured 库 spaCy 模型配置
|
||||
ENV UNSTRUCTURED_LANGUAGE=eng
|
||||
ENV SPACY_MODEL=en_core_web_sm
|
||||
|
||||
# 日志配置
|
||||
# =============================================================================
|
||||
# 日志配置(生产环境默认值)
|
||||
# =============================================================================
|
||||
ENV LOG_LEVEL=WARNING
|
||||
ENV DEBUG=false
|
||||
|
||||
# =============================================================================
|
||||
# 安装依赖
|
||||
# =============================================================================
|
||||
# 复制本地模型文件到镜像
|
||||
COPY docker/models/*.whl /tmp/models/
|
||||
# 复制本地模型文件到镜像(如果有)
|
||||
COPY docker/models/*.whl /tmp/models/ 2>/dev/null || true
|
||||
|
||||
# 安装
|
||||
RUN pip install --no-cache-dir /tmp/models/*.whl && \
|
||||
rm -rf /tmp/models
|
||||
# 安装本地模型 wheel(如果有)
|
||||
RUN if [ -n "$(ls -A /tmp/models/ 2>/dev/null)" ]; then \
|
||||
pip install --no-cache-dir /tmp/models/*.whl && \
|
||||
rm -rf /tmp/models; \
|
||||
fi
|
||||
|
||||
# 设置 pip 国内镜像源
|
||||
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
@@ -54,7 +84,6 @@ COPY backend/ ./
|
||||
# =============================================================================
|
||||
EXPOSE 8079
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 启动命令
|
||||
# =============================================================================
|
||||
|
||||
@@ -1,38 +1,79 @@
|
||||
services:
|
||||
# ⭐ PostgreSQL 和 Qdrant 已迁移到远程服务器 (115.190.121.151)
|
||||
# 不再需要在本地 Docker Compose 中运行这些服务
|
||||
|
||||
backend:
|
||||
build:
|
||||
context: .. # 构建上下文为项目根目录
|
||||
dockerfile: docker/backend/Dockerfile
|
||||
container_name: ai-backend
|
||||
environment:
|
||||
# ⭐ 敏感密钥:通过 .env 注入
|
||||
- ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY}
|
||||
- DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY}
|
||||
- LLAMACPP_API_KEY=${LLAMACPP_API_KEY}
|
||||
# =========================================================================
|
||||
# ⭐ 敏感密钥配置 - 必须通过 .env 文件注入
|
||||
# =========================================================================
|
||||
- ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY:?请在 .env 中配置 ZHIPUAI_API_KEY} # ⭐ 敏感密钥配置
|
||||
- DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY:?请在 .env 中配置 DEEPSEEK_API_KEY} # ⭐ 敏感密钥配置
|
||||
- LLAMACPP_API_KEY=${LLAMACPP_API_KEY:?请在 .env 中配置 LLAMACPP_API_KEY} # ⭐ 敏感密钥配置
|
||||
|
||||
# ⭐ 日志调试配置:通过 .env 注入(支持灵活调整)
|
||||
# =========================================================================
|
||||
# PostgreSQL 数据库配置
|
||||
# =========================================================================
|
||||
- DB_HOST=115.190.121.151
|
||||
- DB_PORT=5432
|
||||
- DB_USER=postgres
|
||||
- DB_PASSWORD=${DB_PASSWORD:?请在 .env 中配置 DB_PASSWORD} # ⭐ 敏感密钥配置
|
||||
- DB_NAME=langgraph_db
|
||||
|
||||
# =========================================================================
|
||||
# Qdrant 向量数据库配置(URL + API密钥 配对)
|
||||
# =========================================================================
|
||||
- QDRANT_URL=http://115.190.121.151:6333
|
||||
- QDRANT_API_KEY=${QDRANT_API_KEY:?请在 .env 中配置 QDRANT_API_KEY} # ⭐ 敏感密钥配置
|
||||
- QDRANT_COLLECTION_NAME=mem0_user_memories
|
||||
|
||||
# =========================================================================
|
||||
# llama.cpp 服务配置(URL + API密钥 配对)
|
||||
# =========================================================================
|
||||
# 主 LLM 服务 (Gemma-4-E2B GGUF) - Docker host 端口 18000
|
||||
- VLLM_BASE_URL=http://host.docker.internal:18000/v1
|
||||
# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - Docker host 端口 18001
|
||||
- LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
|
||||
# Reranker 服务 (bge-reranker-v2-m3) - Docker host 端口 18002
|
||||
- LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
|
||||
|
||||
# =========================================================================
|
||||
# RAG 索引构建配置(非敏感)
|
||||
# =========================================================================
|
||||
- RAG_COLLECTION_NAME=rag_documents
|
||||
- RAG_CHUNK_SIZE=500
|
||||
- RAG_CHUNK_OVERLAP=50
|
||||
- RAG_PARENT_CHUNK_SIZE=1000
|
||||
- RAG_CHILD_CHUNK_SIZE=200
|
||||
- RAG_PARENT_CHUNK_OVERLAP=100
|
||||
- RAG_CHILD_CHUNK_OVERLAP=20
|
||||
- RAG_STRATEGY=parent-child
|
||||
- RAG_STORAGE_TYPE=postgres
|
||||
|
||||
# =========================================================================
|
||||
# 日志调试配置(可通过 .env 覆盖)
|
||||
# =========================================================================
|
||||
- LOG_LEVEL=${LOG_LEVEL:-WARNING}
|
||||
- DEBUG=${DEBUG:-false}
|
||||
- ENABLE_GRAPH_TRACE=${ENABLE_GRAPH_TRACE:-false}
|
||||
|
||||
# ⭐ 基础设施配置:固化在 compose 文件中
|
||||
# PostgreSQL 连接(远程服务器)
|
||||
- DB_URI=postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable
|
||||
|
||||
# Qdrant 向量数据库(远程服务器)
|
||||
- QDRANT_URL=http://115.190.121.151:6333
|
||||
# =========================================================================
|
||||
# 应用行为配置
|
||||
# =========================================================================
|
||||
- MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10}
|
||||
|
||||
# =========================================================================
|
||||
# 前端通信地址(Docker 内部网络)
|
||||
# =========================================================================
|
||||
- API_URL=http://backend:8079/chat
|
||||
|
||||
volumes:
|
||||
- ../data/user_docs:/app/data/user_docs # 挂载文档目录
|
||||
- ../logs:/app/logs
|
||||
networks:
|
||||
- ai-network
|
||||
# ⭐ 移除对 postgres 和 qdrant 的依赖
|
||||
# ⭐ 移除对 postgres 和 qdrant 的依赖(使用远程服务)
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8079:8079"
|
||||
|
||||
@@ -12,10 +12,10 @@ COPY frontend/requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制前端代码
|
||||
COPY frontend/src/ ./frontend/
|
||||
COPY frontend/src/ ./src/
|
||||
|
||||
# 暴露端口
|
||||
EXPOSE 8501
|
||||
|
||||
# 启动命令
|
||||
CMD ["streamlit", "run", "frontend/frontend_main.py", "--server.port", "8501", "--server.address", "0.0.0.0", "--server.baseUrlPath", "/ai"]
|
||||
CMD ["streamlit", "run", "src/frontend_main.py", "--server.port", "8501", "--server.address", "0.0.0.0", "--server.baseUrlPath", "/ai"]
|
||||
|
||||
30
frontend/run.py
Normal file
30
frontend/run.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
前端启动包装器
|
||||
保持相对导入的同时,让 Streamlit 能正常运行
|
||||
本地和容器环境使用相同的启动方式
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录和 backend 目录到 Python 路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
backend_dir = os.path.join(project_root, "backend")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
# 现在用正确的方式启动 Streamlit
|
||||
# 我们不直接运行 frontend_main.py,而是先加载它作为模块
|
||||
from streamlit.web import cli as stcli
|
||||
|
||||
# 设置工作目录到项目根
|
||||
os.chdir(project_root)
|
||||
|
||||
# 构建 Streamlit 参数
|
||||
frontend_main = os.path.join(project_root, "frontend", "src", "frontend_main.py")
|
||||
sys.argv = ["streamlit", "run", frontend_main, "--server.port", "8501", "--server.address", "0.0.0.0"]
|
||||
|
||||
# 启动 Streamlit
|
||||
if __name__ == "__main__":
|
||||
stcli.main()
|
||||
@@ -1,4 +1,5 @@
|
||||
"""
|
||||
UI 组件模块
|
||||
包含所有可复用的 Streamlit 组件
|
||||
"""
|
||||
"""
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
前端配置管理模块
|
||||
集中管理所有配置项,支持环境变量覆盖
|
||||
需要类型转换的配置在此处理
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -12,6 +13,31 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ========== 辅助函数:类型转换 ==========
|
||||
def _get_str(key: str) -> str | None:
|
||||
"""获取字符串配置"""
|
||||
return os.getenv(key)
|
||||
|
||||
|
||||
def _get_int(key: str, default: int = 0) -> int:
|
||||
"""获取整数配置,自动转换"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _get_bool(key: str, default: bool = False) -> bool:
|
||||
"""获取布尔配置,自动转换"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
return default
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontendConfig:
|
||||
"""前端配置类 - 统一管理所有配置项"""
|
||||
@@ -19,51 +45,55 @@ class FrontendConfig:
|
||||
# ==================== API 配置 ====================
|
||||
api_base: str = ""
|
||||
|
||||
# ==================== 页面配置 ====================
|
||||
# ==================== 页面配置(固定值,无需环境变量) ====================
|
||||
page_title: str = "AI 个人助手"
|
||||
page_icon: str = "🤖"
|
||||
layout: str = "wide"
|
||||
|
||||
# ==================== 模型配置 ====================
|
||||
default_model: str = "local" # 更改为local作为默认模型
|
||||
# ==================== 模型配置(固定值,无需环境变量) ====================
|
||||
default_model: str = "local"
|
||||
model_options: Optional[dict] = None
|
||||
|
||||
# ==================== 用户配置 ====================
|
||||
# ==================== 用户配置(固定值,无需环境变量) ====================
|
||||
default_user_id: str = "default_user"
|
||||
|
||||
# ==================== 历史记录配置 ====================
|
||||
# ==================== 历史记录配置(固定值,无需环境变量) ====================
|
||||
history_limit: int = 50
|
||||
summary_max_length: int = 30
|
||||
|
||||
# ==================== 流式响应配置 ====================
|
||||
# ==================== 流式响应配置(固定值,无需环境变量) ====================
|
||||
stream_timeout: int = 120
|
||||
|
||||
# ==================== 日志配置 ====================
|
||||
log_level: str = ""
|
||||
debug: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化后处理 - 设置默认值和加载环境变量"""
|
||||
if self.model_options is None:
|
||||
self.model_options = {
|
||||
"local": "本地 llama.cpp(Gemma-4)", # 本地模型作为第一个
|
||||
"deepseek": "DeepSeek V3.2(在线)", # DeepSeek 作为中间
|
||||
"zhipu": "智谱 GLM-4.7-Flash(在线)" # GLM-4.7 作为最后一个
|
||||
"local": "本地 llama.cpp(Gemma-4)",
|
||||
"deepseek": "DeepSeek V3.2(在线)",
|
||||
"zhipu": "智谱 GLM-4.7-Flash(在线)"
|
||||
}
|
||||
|
||||
# 从环境变量加载配置
|
||||
# 从环境变量加载配置(优先级最高)
|
||||
self._load_from_env()
|
||||
|
||||
def _load_from_env(self):
|
||||
"""从环境变量加载配置(优先级最高)"""
|
||||
"""从环境变量加载配置(仅加载必要的配置项)"""
|
||||
# API 地址(移除 /chat 后缀)
|
||||
# 优先级:环境变量 API_URL > 默认值
|
||||
api_url = os.getenv("API_URL", "http://127.0.0.1:8079")
|
||||
self.api_base = api_url.replace("/chat", "").rstrip("/")
|
||||
|
||||
api_url = _get_str("API_URL")
|
||||
if api_url:
|
||||
self.api_base = api_url.replace("/chat", "").rstrip("/")
|
||||
|
||||
# 日志配置
|
||||
self.log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
self.debug = os.getenv("DEBUG", "false").lower() == "true"
|
||||
log_level = _get_str("LOG_LEVEL")
|
||||
if log_level:
|
||||
self.log_level = log_level.upper()
|
||||
|
||||
self.debug = _get_bool("DEBUG", False)
|
||||
|
||||
|
||||
# 日志配置
|
||||
self.log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
self.debug = os.getenv("DEBUG", "false").lower() == "true"
|
||||
# 全局配置实例(单例模式)
|
||||
config = FrontendConfig()
|
||||
config = FrontendConfig()
|
||||
|
||||
@@ -6,18 +6,25 @@ AI Agent 前端主入口
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到 Python 路径,支持绝对导入
|
||||
# 现在的结构: frontend/src/frontend_main.py,所以要获取 frontend/ 目录作为根
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
# 添加当前目录到路径,确保智能导入能工作
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# 使用相对导入
|
||||
from .config import config
|
||||
from .state import AppState
|
||||
from .components.sidebar import render_sidebar
|
||||
from .components.chat_area import render_chat_area
|
||||
from .components.info_panel import render_info_panel
|
||||
# 智能导入:作为 __main__ 被 Streamlit 运行时用绝对导入,否则用相对导入
|
||||
if __name__ == '__main__':
|
||||
from config import config
|
||||
from state import AppState
|
||||
from components.sidebar import render_sidebar
|
||||
from components.chat_area import render_chat_area
|
||||
from components.info_panel import render_info_panel
|
||||
else:
|
||||
from .config import config
|
||||
from .state import AppState
|
||||
from .components.sidebar import render_sidebar
|
||||
from .components.chat_area import render_chat_area
|
||||
from .components.info_panel import render_info_panel
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -26,9 +26,23 @@ Offline RAG Indexer module.
|
||||
from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
|
||||
from .loaders import DocumentLoader
|
||||
from .splitters import SplitterType, get_splitter
|
||||
from .config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
LLAMACPP_EMBEDDING_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
DB_URI,
|
||||
DOCSTORE_URI,
|
||||
RAG_OCR_LANGUAGES,
|
||||
RAG_DOC_LANGUAGES,
|
||||
)
|
||||
|
||||
# 从 rag_core 重新导出常用组件
|
||||
from rag_core import (
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
|
||||
|
||||
from backend.rag_core import (
|
||||
LlamaCppEmbedder,
|
||||
QdrantVectorStore,
|
||||
PostgresDocStore,
|
||||
@@ -39,7 +53,7 @@ __version__ = "2.0.0"
|
||||
|
||||
__all__ = [
|
||||
# 核心构建器与配置
|
||||
"index_builder",
|
||||
"IndexBuilder",
|
||||
"IndexBuilderConfig",
|
||||
"DocstoreConfig",
|
||||
|
||||
@@ -50,6 +64,16 @@ __all__ = [
|
||||
"SplitterType",
|
||||
"get_splitter",
|
||||
|
||||
# 配置
|
||||
"QDRANT_URL",
|
||||
"QDRANT_API_KEY",
|
||||
"LLAMACPP_EMBEDDING_URL",
|
||||
"LLAMACPP_API_KEY",
|
||||
"DB_URI",
|
||||
"DOCSTORE_URI",
|
||||
"RAG_OCR_LANGUAGES",
|
||||
"RAG_DOC_LANGUAGES",
|
||||
|
||||
# 嵌入与向量存储
|
||||
"LlamaCppEmbedder",
|
||||
"QdrantVectorStore",
|
||||
|
||||
@@ -6,13 +6,24 @@ import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载 .env 文件
|
||||
load_dotenv()
|
||||
|
||||
# 添加项目根目录和 backend 目录到 Python 路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
|
||||
|
||||
from .index_builder import IndexBuilder, IndexBuilderConfig
|
||||
from .splitters import SplitterType
|
||||
# 导入方式:条件导入,支持作为脚本运行和作为包导入
|
||||
if __name__ == "__main__":
|
||||
# 作为脚本直接运行时使用绝对导入
|
||||
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
else:
|
||||
# 作为包导入时使用相对导入
|
||||
from .index_builder import IndexBuilder, IndexBuilderConfig
|
||||
from .splitters import SplitterType
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
||||
@@ -1,32 +1,71 @@
|
||||
"""
|
||||
RAG Indexer 配置管理模块
|
||||
集中管理所有环境变量配置项,避免散落在各个文件中
|
||||
所有配置直接从环境变量读取,无默认值,避免配置混乱
|
||||
需要类型转换的配置在此处理
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# 尝试从 rag_core 导入配置(如果可用)
|
||||
try:
|
||||
from rag_core.config import (
|
||||
QDRANT_URL,
|
||||
QDRANT_API_KEY,
|
||||
LLAMACPP_EMBEDDING_URL,
|
||||
LLAMACPP_API_KEY,
|
||||
DB_URI,
|
||||
DOCSTORE_URI,
|
||||
)
|
||||
except ImportError:
|
||||
# 如果 rag_core 不可用,则直接读取环境变量
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
|
||||
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
|
||||
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "")
|
||||
DB_URI = os.getenv(
|
||||
"DB_URI",
|
||||
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
|
||||
)
|
||||
DOCSTORE_URI = os.getenv("DOCSTORE_URI", DB_URI)
|
||||
|
||||
# ========== 辅助函数:类型转换 ==========
|
||||
def _get_str(key: str) -> str | None:
|
||||
"""获取字符串配置"""
|
||||
return os.getenv(key)
|
||||
|
||||
|
||||
def _get_int(key: str) -> int | None:
|
||||
"""获取整数配置,自动转换"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _get_list_str(key: str, default: list[str] | None = None) -> list[str]:
|
||||
"""获取字符串列表配置,从逗号分隔的字符串解析"""
|
||||
value = os.getenv(key)
|
||||
if value is not None:
|
||||
return [item.strip() for item in value.split(",") if item.strip()]
|
||||
return default or []
|
||||
|
||||
|
||||
# ========== 向量数据库配置(URL + API密钥 配对) ==========
|
||||
QDRANT_URL = _get_str("QDRANT_URL")
|
||||
QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
|
||||
|
||||
|
||||
# ========== 嵌入服务配置(URL + API密钥 配对) ==========
|
||||
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
|
||||
LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
|
||||
|
||||
|
||||
# ========== 文档存储配置(分离配置 + 完整URI) ==========
|
||||
# 分离配置(优先使用)
|
||||
DB_HOST = _get_str("DB_HOST")
|
||||
DB_PORT = _get_int("DB_PORT")
|
||||
DB_USER = _get_str("DB_USER")
|
||||
DB_PASSWORD = _get_str("DB_PASSWORD")
|
||||
DB_NAME = _get_str("DB_NAME")
|
||||
|
||||
# 完整连接字符串(直接从环境变量读取)
|
||||
DB_URI = _get_str("DB_URI")
|
||||
|
||||
# 文档存储 URI(直接从环境变量读取,默认同 DB_URI)
|
||||
DOCSTORE_URI = _get_str("DOCSTORE_URI") or DB_URI
|
||||
|
||||
|
||||
# ========== 文档加载器配置(unstructured 库) ==========
|
||||
# OCR 语言列表(逗号分隔,如 "chi_sim,eng")
|
||||
RAG_OCR_LANGUAGES = _get_list_str("RAG_OCR_LANGUAGES", ["chi_sim", "eng"])
|
||||
|
||||
# 文档主语言列表(逗号分隔,如 "zh")
|
||||
RAG_DOC_LANGUAGES = _get_list_str("RAG_DOC_LANGUAGES", ["zh"])
|
||||
|
||||
|
||||
# ========== 索引器专用配置 ==========
|
||||
# 默认索引存储路径
|
||||
INDEX_STORAGE_PATH = os.getenv("INDEX_STORAGE_PATH", "./index_storage")
|
||||
INDEX_STORAGE_PATH = _get_str("INDEX_STORAGE_PATH")
|
||||
|
||||
@@ -23,6 +23,12 @@ from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
|
||||
from .loaders import DocumentLoader
|
||||
from .splitters import SplitterType, get_splitter
|
||||
|
||||
# 从 rag_core 导入
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
|
||||
|
||||
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,11 +37,10 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class DocstoreConfig:
|
||||
"""文档存储配置(用于父块存储)。"""
|
||||
connection_string: Optional[str] = None
|
||||
pool_config: Optional[Dict[str, Any]] = None
|
||||
max_concurrency: Optional[int] = None
|
||||
pool_config: Dict[str, Any] | None = None
|
||||
max_concurrency: int | None = None
|
||||
# 若要从外部注入已创建好的 docstore,可直接设置此字段
|
||||
instance: Optional[BaseStore] = None
|
||||
instance: BaseStore | None = None
|
||||
|
||||
@dataclass
|
||||
class IndexBuilderConfig:
|
||||
@@ -89,7 +94,6 @@ class IndexBuilder:
|
||||
# 初始化向量存储
|
||||
self.vector_store = QdrantVectorStore(
|
||||
collection_name=config.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
)
|
||||
|
||||
# 根据切分类型初始化相关组件
|
||||
@@ -141,7 +145,6 @@ class IndexBuilder:
|
||||
# 使用工厂函数创建检索器,避免重复代码
|
||||
self.retriever = create_parent_retriever(
|
||||
collection_name=cfg.collection_name,
|
||||
embeddings=self.embeddings,
|
||||
parent_splitter=self.parent_splitter,
|
||||
child_splitter=self.child_splitter,
|
||||
docstore=self.docstore,
|
||||
@@ -158,7 +161,6 @@ class IndexBuilder:
|
||||
|
||||
# 使用 create_docstore 创建 PostgreSQL 存储
|
||||
docstore, conn_info = create_docstore(
|
||||
connection_string=cfg.connection_string,
|
||||
pool_config=cfg.pool_config,
|
||||
max_concurrency=cfg.max_concurrency,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,9 @@ from langchain_core.documents import Document
|
||||
from unstructured.documents.elements import Element
|
||||
from unstructured.partition.auto import partition
|
||||
|
||||
# 相对导入配置
|
||||
from .config import RAG_OCR_LANGUAGES, RAG_DOC_LANGUAGES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模块加载时设置一次环境变量,避免重复设置
|
||||
@@ -47,8 +50,8 @@ class DocumentLoader:
|
||||
"""
|
||||
self.extract_images = extract_images
|
||||
self.strategy = strategy
|
||||
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
|
||||
self.languages = languages or ["zh"]
|
||||
self.ocr_languages = ocr_languages or RAG_OCR_LANGUAGES
|
||||
self.languages = languages or RAG_DOC_LANGUAGES
|
||||
self.include_page_breaks = include_page_breaks
|
||||
self.pdf_infer_table_structure = pdf_infer_table_structure
|
||||
self.partition_kwargs = partition_kwargs or {}
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试重构后的 IndexBuilder 和 RAGRetriever
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from ..index_builder import IndexBuilder
|
||||
from ..splitters import SplitterType
|
||||
|
||||
async def test_index_builder():
|
||||
"""测试索引构建功能"""
|
||||
print("测试索引构建功能...")
|
||||
|
||||
# 创建 IndexBuilder 实例
|
||||
builder = IndexBuilder(
|
||||
collection_name="test_collection",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200
|
||||
)
|
||||
|
||||
# 测试文档路径
|
||||
test_file = os.path.join(os.path.dirname(__file__), "..", "data", "corpus", "三国演义.txt")
|
||||
|
||||
if os.path.exists(test_file):
|
||||
# 构建索引
|
||||
print(f"正在为文件 {test_file} 构建索引...")
|
||||
processed = await builder.build_from_file(test_file)
|
||||
print(f"索引构建完成,处理了 {processed} 个文档")
|
||||
|
||||
# 获取集合信息
|
||||
info = builder.get_collection_info()
|
||||
print(f"集合信息: {info}")
|
||||
else:
|
||||
print(f"测试文件不存在: {test_file}")
|
||||
|
||||
# 测试搜索功能
|
||||
print("\n测试搜索功能...")
|
||||
try:
|
||||
results = builder.search("吕布", k=3)
|
||||
print(f"搜索结果数量: {len(results)}")
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n结果 {i+1}:")
|
||||
print(f"内容: {result.page_content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"搜索测试失败: {e}")
|
||||
|
||||
# 测试带父块上下文的搜索
|
||||
print("\n测试带父块上下文的搜索...")
|
||||
try:
|
||||
results = await builder.search_with_parent_context("吕布", k=3)
|
||||
print(f"搜索结果数量: {len(results)}")
|
||||
for i, result in enumerate(results):
|
||||
print(f"\n结果 {i+1}:")
|
||||
print(f"内容: {result.page_content[:100]}...")
|
||||
except Exception as e:
|
||||
print(f"带父块上下文的搜索测试失败: {e}")
|
||||
|
||||
# 测试统一检索接口
|
||||
print("\n测试统一检索接口...")
|
||||
try:
|
||||
# 返回父块
|
||||
results_parent = await builder.retrieve("吕布", return_parent=True)
|
||||
print(f"返回父块的结果数量: {len(results_parent)}")
|
||||
|
||||
# 返回子块
|
||||
results_child = await builder.retrieve("吕布", return_parent=False)
|
||||
print(f"返回子块的结果数量: {len(results_child)}")
|
||||
except Exception as e:
|
||||
print(f"统一检索接口测试失败: {e}")
|
||||
|
||||
# 关闭资源
|
||||
builder.close()
|
||||
print("\n测试完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_index_builder())
|
||||
@@ -1,188 +0,0 @@
|
||||
"""
|
||||
验证 RAG 索引完整性。
|
||||
|
||||
检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
||||
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
|
||||
COLLECTION_NAME = "rag_documents"
|
||||
TABLE_NAME = "parent_documents"
|
||||
|
||||
|
||||
def check_qdrant():
|
||||
"""检查 Qdrant 向量库。"""
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
print("=" * 60)
|
||||
print("Qdrant 向量库")
|
||||
print("=" * 60)
|
||||
|
||||
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
||||
|
||||
# 集合列表
|
||||
collections = client.get_collections().collections
|
||||
print(f"\n集合数: {len(collections)}")
|
||||
for c in collections:
|
||||
print(f" - {c.name}")
|
||||
|
||||
# 目标集合信息
|
||||
if not any(c.name == COLLECTION_NAME for c in collections):
|
||||
print(f"\n集合 '{COLLECTION_NAME}' 不存在")
|
||||
return
|
||||
|
||||
info = client.get_collection(COLLECTION_NAME)
|
||||
print(f"\n集合 '{COLLECTION_NAME}':")
|
||||
print(f" 状态: {info.status}")
|
||||
print(f" 向量数: {info.points_count}")
|
||||
|
||||
vectors_config = info.config.params.vectors
|
||||
if isinstance(vectors_config, dict):
|
||||
for name, vc in vectors_config.items():
|
||||
print(f" 向量 '{name}': 维度={vc.size}, 距离={vc.distance}")
|
||||
else:
|
||||
print(f" 向量维度: {vectors_config.size}")
|
||||
|
||||
# 抽样查看
|
||||
print(f"\n前 3 个向量:")
|
||||
points = client.scroll(
|
||||
collection_name=COLLECTION_NAME,
|
||||
limit=3,
|
||||
with_payload=True,
|
||||
with_vectors=False
|
||||
)
|
||||
for i, point in enumerate(points[0]):
|
||||
print(f"\n {i+1}. ID: {point.id}")
|
||||
payload = point.payload or {}
|
||||
print(f" 内容: {payload.get('page_content', '')[:100]}...")
|
||||
|
||||
|
||||
async def check_postgres():
|
||||
"""检查 PostgreSQL 文档存储。"""
|
||||
import asyncpg
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("PostgreSQL 文档存储")
|
||||
print("=" * 60)
|
||||
|
||||
conn = await asyncpg.connect(dsn=DB_URI)
|
||||
|
||||
try:
|
||||
# 表是否存在
|
||||
tables = await conn.fetch(
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
|
||||
)
|
||||
table_names = [t['table_name'] for t in tables]
|
||||
|
||||
if TABLE_NAME not in table_names:
|
||||
print(f"\n表 '{TABLE_NAME}' 不存在")
|
||||
return
|
||||
|
||||
# 统计
|
||||
count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
|
||||
print(f"\n表 '{TABLE_NAME}': {count} 条记录")
|
||||
|
||||
# 抽样
|
||||
print(f"\n前 3 个文档:")
|
||||
rows = await conn.fetch(
|
||||
f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
|
||||
)
|
||||
for i, row in enumerate(rows):
|
||||
print(f"\n {i+1}. Key: {row['key']}")
|
||||
val = row['value']
|
||||
if isinstance(val, dict) and 'page_content' in val:
|
||||
print(f" 内容: {val['page_content'][:100]}...")
|
||||
|
||||
# Key 前缀分布
|
||||
key_prefixes = await conn.fetch(
|
||||
f"""
|
||||
SELECT
|
||||
CASE
|
||||
WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
|
||||
ELSE 'no_prefix'
|
||||
END AS prefix,
|
||||
COUNT(*) AS cnt
|
||||
FROM {TABLE_NAME}
|
||||
GROUP BY prefix
|
||||
ORDER BY cnt DESC
|
||||
LIMIT 10
|
||||
"""
|
||||
)
|
||||
print(f"\nKey 前缀分布:")
|
||||
for row in key_prefixes:
|
||||
print(f" {row['prefix']}: {row['cnt']}")
|
||||
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
async def test_search():
|
||||
"""测试检索功能。"""
|
||||
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("检索测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用配置对象初始化(与默认构建方式一致)
|
||||
config = IndexBuilderConfig(
|
||||
collection_name=COLLECTION_NAME,
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
)
|
||||
builder = IndexBuilder(config)
|
||||
|
||||
# 确保检索器已初始化
|
||||
if builder.retriever is None:
|
||||
print("错误: 检索器未初始化,请检查切分策略")
|
||||
return
|
||||
|
||||
query = input("\n查询 (回车使用默认): ").strip() or "你好"
|
||||
print(f"\n查询: {query}")
|
||||
|
||||
# 标准检索(返回父块,因为 ParentDocumentRetriever 默认返回父块)
|
||||
print("\n--- 标准检索 (返回父块) ---")
|
||||
results = await builder.retriever.ainvoke(query)
|
||||
for i, doc in enumerate(results):
|
||||
content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
|
||||
print(f"\n {i+1}. {content}...")
|
||||
if hasattr(doc, 'metadata'):
|
||||
source = doc.metadata.get('source', '')
|
||||
if source:
|
||||
print(f" 来源: {source}")
|
||||
|
||||
# 若需要仅返回子块,可以临时修改检索器的 search_type
|
||||
# (注意:ParentDocumentRetriever 的 search_type 默认为 "similarity")
|
||||
print("\n--- 检索子块 (通过修改检索器参数) ---")
|
||||
# 创建一个新的检索器副本,设置为返回子块
|
||||
# 简单起见,直接调用 vectorstore 进行相似度搜索获取子块
|
||||
vectorstore = builder.vector_store.get_langchain_vectorstore()
|
||||
sub_results = await vectorstore.asimilarity_search(query, k=3)
|
||||
for i, doc in enumerate(sub_results):
|
||||
content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
|
||||
print(f"\n {i+1}. {content}...")
|
||||
if hasattr(doc, 'metadata'):
|
||||
parent_id = doc.metadata.get('parent_id', '')
|
||||
if parent_id:
|
||||
print(f" 父块 ID: {parent_id}")
|
||||
|
||||
|
||||
async def main():
|
||||
check_qdrant()
|
||||
await check_postgres()
|
||||
await test_search()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -290,9 +290,9 @@ start_backend() {
|
||||
source .env 2>/dev/null || true
|
||||
set +a
|
||||
|
||||
export PYTHONPATH="$PROJECT_DIR"
|
||||
export PYTHONPATH="$PROJECT_DIR:$PROJECT_DIR/backend"
|
||||
export BACKEND_PORT=8079
|
||||
python app/backend.py &
|
||||
python backend/app/backend.py &
|
||||
BACKEND_PID=$!
|
||||
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
|
||||
sleep 2
|
||||
@@ -307,7 +307,7 @@ start_frontend() {
|
||||
source .env 2>/dev/null || true
|
||||
set +a
|
||||
|
||||
export PYTHONPATH="$PROJECT_DIR"
|
||||
export PYTHONPATH="$PROJECT_DIR:$PROJECT_DIR/backend"
|
||||
streamlit run frontend/src/frontend_main.py &
|
||||
FRONTEND_PID=$!
|
||||
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"
|
||||
|
||||
@@ -6,20 +6,22 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from .config import DB_URI
|
||||
import sys
|
||||
import uuid
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录到 Python 路径 (现在文件在 backend/app/ 下,backend 就是根)
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
# 添加项目根目录和 backend 目录到 Python 路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
backend_dir = os.path.join(project_root, "backend")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, backend_dir)
|
||||
load_dotenv()
|
||||
|
||||
from backend.app.config import DB_URI
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
from ..agent import AIAgentService
|
||||
from ..agent.history import ThreadHistoryService
|
||||
from ..logger import info, warning, error
|
||||
from backend.app.agent.service import AIAgentService
|
||||
from backend.app.agent.history import ThreadHistoryService
|
||||
from backend.app.logger import info, warning, error
|
||||
|
||||
# PostgreSQL 连接字符串
|
||||
|
||||
@@ -5,10 +5,15 @@ import sys
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from qdrant_client import QdrantClient
|
||||
from backend.rag_core import LlamaCppEmbedder
|
||||
|
||||
# 添加项目根目录和 backend 目录到 Python 路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
backend_dir = os.path.join(project_root, "backend")
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, backend_dir)
|
||||
load_dotenv()
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
from rag_core import LlamaCppEmbedder
|
||||
|
||||
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
|
||||
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
||||
69
test/test_frontend.py
Normal file
69
test/test_frontend.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
前端快速测试脚本
|
||||
验证前端导入是否正常工作
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加必要的路径
|
||||
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||
frontend_src = os.path.join(project_root, "frontend", "src")
|
||||
backend_dir = os.path.join(project_root, "backend")
|
||||
|
||||
sys.path.insert(0, project_root)
|
||||
sys.path.insert(0, frontend_src)
|
||||
sys.path.insert(0, backend_dir)
|
||||
|
||||
print("=" * 60)
|
||||
print("前端导入测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 测试 1: 直接导入前端模块
|
||||
print("\n[测试 1] 直接导入前端模块...")
|
||||
try:
|
||||
from frontend.src.frontend_main import main
|
||||
print("✅ frontend_main 导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# 测试 2: 导入配置
|
||||
print("\n[测试 2] 导入配置...")
|
||||
try:
|
||||
from config import config
|
||||
print(f"✅ config 导入成功: page_title={config.page_title}")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
# 测试 3: 导入状态管理
|
||||
print("\n[测试 3] 导入状态管理...")
|
||||
try:
|
||||
from state import AppState
|
||||
print("✅ AppState 导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
# 测试 4: 导入 API 客户端
|
||||
print("\n[测试 4] 导入 API 客户端...")
|
||||
try:
|
||||
from api_client import api_client
|
||||
print("✅ api_client 导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
# 测试 5: 导入组件
|
||||
print("\n[测试 5] 导入组件...")
|
||||
try:
|
||||
from components.sidebar import render_sidebar
|
||||
from components.chat_area import render_chat_area
|
||||
from components.info_panel import render_info_panel
|
||||
print("✅ 所有组件导入成功")
|
||||
except Exception as e:
|
||||
print(f"❌ 导入失败: {e}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 所有前端导入测试通过!")
|
||||
print("=" * 60)
|
||||
print("\n现在可以使用 ./scripts/start.sh both 启动完整服务")
|
||||
@@ -18,18 +18,14 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
from pydantic import SecretStr
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag_indexer.index_builder import IndexBuilderConfig
|
||||
from rag_indexer.splitters import SplitterType
|
||||
from .pipeline import RAGPipeline
|
||||
from .tools import create_rag_tool_sync
|
||||
from pydantic import SecretStr
|
||||
# 使用本地 LLM(通过 OpenAI 兼容接口)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag_core.retriever_factory import create_parent_retriever
|
||||
|
||||
load_dotenv()
|
||||
from backend.app.rag.pipeline import RAGPipeline
|
||||
from backend.app.rag.tools import create_rag_tool_sync
|
||||
from backend.rag_core.retriever_factory import create_parent_retriever
|
||||
|
||||
def create_llm():
|
||||
"""创建本地 vLLM 服务 LLM"""
|
||||
@@ -113,7 +109,7 @@ async def demonstrate_tool_creation():
|
||||
collection_name="rag_documents",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
)
|
||||
retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
|
||||
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
|
||||
|
||||
# 2. 创建 LLM
|
||||
llm = create_llm()
|
||||
49
test/test_rag_indexer_result.py
Normal file
49
test/test_rag_indexer_result.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试重构后的 IndexBuilder 和 RAGRetriever
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
project_root = os.path.join(os.path.dirname(__file__), "..")
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from rag_indexer.index_builder import IndexBuilder
|
||||
from rag_indexer.splitters import SplitterType
|
||||
|
||||
async def test_index_builder():
|
||||
"""测试索引构建功能"""
|
||||
print("测试索引构建功能...")
|
||||
|
||||
# 创建 IndexBuilder 实例
|
||||
builder = IndexBuilder(
|
||||
collection_name="test_collection",
|
||||
splitter_type=SplitterType.PARENT_CHILD,
|
||||
parent_chunk_size=1000,
|
||||
child_chunk_size=200
|
||||
)
|
||||
|
||||
# 测试文档路径
|
||||
test_file = os.path.join(os.path.dirname(__file__), "..", "data", "user_docs", "a.txt")
|
||||
|
||||
if os.path.exists(test_file):
|
||||
# 构建索引
|
||||
print(f"正在为文件 {test_file} 构建索引...")
|
||||
processed = await builder.build_from_file(test_file)
|
||||
print(f"索引构建完成,处理了 {processed} 个文档")
|
||||
|
||||
# 获取集合信息
|
||||
info = builder.get_collection_info()
|
||||
print(f"集合信息: {info}")
|
||||
else:
|
||||
print(f"测试文件不存在: {test_file}")
|
||||
|
||||
# 关闭资源
|
||||
builder.close()
|
||||
print("\n测试完成")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_index_builder())
|
||||
Reference in New Issue
Block a user