Compare commits

...

6 Commits

Author SHA1 Message Date
efa8bbcd03 添加配置
Some checks failed
构建并部署 AI Agent 服务 / deploy (push) Failing after 5m11s
2026-04-21 22:07:20 +08:00
aa8072369c 添加配置 2026-04-21 21:55:31 +08:00
5e9bbd519f 测试修改 2026-04-21 20:49:10 +08:00
37e86f3bb1 参数配置统一 2026-04-21 19:06:34 +08:00
e2eaac9498 修改配置 2026-04-21 18:41:14 +08:00
08826c70a3 容器处理 2026-04-21 16:27:05 +08:00
37 changed files with 850 additions and 671 deletions

View File

@@ -1,19 +1,62 @@
# =============================================================================
# Docker Compose 服务器部署配置
# 用法: cp .env.docker .env 然后填入 API Key
# Docker Compose 服务器部署配置模板
# 用法: cp .env.docker .env 然后填入敏感密钥
# =============================================================================
# ⭐ 敏感密钥配置(必须在此配置)
# =============================================================================
# AI 模型 API 密钥
ZHIPUAI_API_KEY=your_zhipuai_api_key_here
DEEPSEEK_API_KEY=your_deepseek_api_key_here
# -----------------------------------------------------------------------------
# AI 模型 API 密钥(⭐ 敏感配置 - 必须填入真实值)
# -----------------------------------------------------------------------------
ZHIPUAI_API_KEY=your_zhipuai_api_key_here # ⭐ 敏感密钥配置
DEEPSEEK_API_KEY=your_deepseek_api_key_here # ⭐ 敏感密钥配置
LLAMACPP_API_KEY=your_llamacpp_api_key_here # ⭐ 敏感密钥配置
# llama.cpp 服务认证 Token与容器启动参数一致
LLAMACPP_API_KEY=token-abc123
# -----------------------------------------------------------------------------
# PostgreSQL 数据库配置(分离配置,易于管理)
# -----------------------------------------------------------------------------
DB_HOST=115.190.121.151
DB_PORT=5432
DB_USER=postgres
DB_PASSWORD=your_db_password_here # ⭐ 敏感密钥配置
DB_NAME=langgraph_db
# 完整连接字符串(也支持直接配置,优先使用分离配置)
DB_URI=postgresql://postgres:${DB_PASSWORD}@115.190.121.151:5432/langgraph_db?sslmode=disable
# ⭐ 日志调试配置(部署时可灵活调整)
# =============================================================================
# -----------------------------------------------------------------------------
# Qdrant 向量数据库配置URL + API密钥 配对)
# -----------------------------------------------------------------------------
QDRANT_URL=http://115.190.121.151:6333
QDRANT_API_KEY=your_qdrant_api_key_here # ⭐ 敏感密钥配置
QDRANT_COLLECTION_NAME=mem0_user_memories
# -----------------------------------------------------------------------------
# llama.cpp 服务配置URL + API密钥 配对)
# -----------------------------------------------------------------------------
# 主 LLM 服务 (Gemma-4-E2B GGUF) - 端口 18000 (Docker host 映射)
VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - 端口 18001
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
# LLAMACPP_API_KEY=your_llamacpp_api_key_here (已在上面配置)
# Reranker 服务 (bge-reranker-v2-m3) - 端口 18002
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
# -----------------------------------------------------------------------------
# RAG 索引构建配置(非敏感,可直接使用)
# -----------------------------------------------------------------------------
RAG_COLLECTION_NAME=rag_documents
RAG_CHUNK_SIZE=500
RAG_CHUNK_OVERLAP=50
RAG_PARENT_CHUNK_SIZE=1000
RAG_CHILD_CHUNK_SIZE=200
RAG_PARENT_CHUNK_OVERLAP=100
RAG_CHILD_CHUNK_OVERLAP=20
RAG_STRATEGY=parent-child
RAG_STORAGE_TYPE=postgres
# -----------------------------------------------------------------------------
# 日志调试配置(部署时可灵活调整)
# -----------------------------------------------------------------------------
# 日志级别：DEBUG, INFO, WARNING, ERROR, CRITICAL
# 生产环境推荐 WARNING，排查问题时改为 DEBUG
LOG_LEVEL=WARNING
@@ -28,53 +71,13 @@ DEBUG=false
# false: 关闭追踪,减少日志量
ENABLE_GRAPH_TRACE=false
# -----------------------------------------------------------------------------
# llama.cpp 服务配置
# -----------------------------------------------------------------------------
# 主 LLM 服务 (Gemma-4-E2B GGUF) - 端口 8081
VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding 服务 (embeddinggemma-300M GGUF) - 端口 8082
LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
# Reranker 服务 (bge-reranker-v2-m3) - 端口 8083
LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
# -----------------------------------------------------------------------------
# Mem0 记忆层配置
# -----------------------------------------------------------------------------
# Qdrant 向量数据库(远程服务器上的独立容器)
QDRANT_URL=http://115.190.121.151:6333
QDRANT_COLLECTION_NAME=mem0_user_memories
# -----------------------------------------------------------------------------
# 数据库配置
# -----------------------------------------------------------------------------
# PostgreSQL 连接字符串(远程服务器上的独立容器)
DB_URI=postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable
# -----------------------------------------------------------------------------
# 前端配置
# -----------------------------------------------------------------------------
# Docker Compose 内部网络,使用服务名 'backend'
API_URL=http://backend:8083/chat
# ⭐ 前端通信地址Docker 内部网络)
# 注意:这里只需要域名和端口,不需要 /chat 路径
- API_URL=http://backend:8083
# -----------------------------------------------------------------------------
# 应用行为配置
# -----------------------------------------------------------------------------
MEMORY_SUMMARIZE_INTERVAL=10
# -----------------------------------------------------------------------------
# unstructured 库 spaCy 模型配置
# 前端配置
# -----------------------------------------------------------------------------
# 指定文档解析使用的语言: eng (英语) 或 zho (中文)
UNSTRUCTURED_LANGUAGE=zho
# 指定 spaCy 模型名称(需与 UNSTRUCTURED_LANGUAGE 对应)
# eng -> en_core_web_sm
# zho -> zh_core_web_sm
SPACY_MODEL=zh_core_web_sm
# Docker Compose 内部网络,使用服务名 'backend'
API_URL=http://backend:8079/chat

2
.gitignore vendored
View File

@@ -17,6 +17,8 @@
!rag_indexer/**
!docker/
!docker/**
!test/
!test/**
!.gitea/
!.gitea/**

View File

@@ -2,7 +2,7 @@
AI Agent 应用模块
"""
from ..agent import AIAgentService
from ..graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
from .agent.service import AIAgentService
from .graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME
__all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"]

View File

@@ -1,50 +1,91 @@
"""
环境变量集中管理模块
所有配置项统一定义,避免散落在各个文件中
配置分组：相关配置放在一起，URL 和 API Key 配对
所有配置直接从环境变量读取,无默认值,避免配置混乱
需要类型转换的配置在此处理
"""
import os
from dotenv import load_dotenv
load_dotenv()
# ========== 辅助函数:类型转换 ==========
def _get_str(key: str) -> str | None:
"""获取字符串配置"""
return os.getenv(key)
def _get_int(key: str) -> int | None:
"""获取整数配置,自动转换"""
value = os.getenv(key)
if value is not None:
try:
return int(value)
except (ValueError, TypeError):
pass
return None
def _get_bool(key: str) -> bool | None:
"""获取布尔配置,自动转换"""
value = os.getenv(key)
if value is not None:
return value.lower() in ("true", "1", "yes", "on")
return None
# ========== 第三方 API 密钥 ==========
ZHIPUAI_API_KEY = _get_str("ZHIPUAI_API_KEY")
DEEPSEEK_API_KEY = _get_str("DEEPSEEK_API_KEY")
# ========== llama.cpp 服务配置URL + API密钥 配对) ==========
# 主 LLM 服务
VLLM_BASE_URL = _get_str("VLLM_BASE_URL")
LLM_API_KEY = _get_str("LLAMACPP_API_KEY")
# Embedding 服务 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
# Reranker 服务
LLAMACPP_RERANKER_URL = _get_str("LLAMACPP_RERANKER_URL")
# ========== Qdrant 向量数据库配置URL + API密钥 配对) ==========
QDRANT_URL = _get_str("QDRANT_URL")
QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
QDRANT_COLLECTION_NAME = _get_str("QDRANT_COLLECTION_NAME")
# ========== PostgreSQL 数据库配置(分离配置 + 完整URI ==========
# 分离配置(优先使用)
DB_HOST = _get_str("DB_HOST")
DB_PORT = _get_int("DB_PORT")
DB_USER = _get_str("DB_USER")
DB_PASSWORD = _get_str("DB_PASSWORD")
DB_NAME = _get_str("DB_NAME")
# 完整连接字符串(直接从环境变量读取)
DB_URI = _get_str("DB_URI")
# ========== 后端服务配置 ==========
BACKEND_PORT = _get_int("BACKEND_PORT")
# ========== Mem0 记忆层配置 ==========
# 记忆提取间隔:每 N 轮对话生成一次摘要
MEMORY_SUMMARIZE_INTERVAL = _get_int("MEMORY_SUMMARIZE_INTERVAL")
# ========== Graph 执行追踪配置 ==========
# 是否启用 Graph 流转追踪(通过环境变量控制)
ENABLE_GRAPH_TRACE = os.getenv("ENABLE_GRAPH_TRACE", "true").lower() == "true"
ENABLE_GRAPH_TRACE = _get_bool("ENABLE_GRAPH_TRACE")
# ========== 记忆提取配置 ==========
# 记忆提取间隔:每 N 轮对话生成一次摘要
MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10"))
# ========== Mem0 记忆层配置 ==========
# Qdrant 向量数据库地址
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key")
# ========== llm 配置 ==========
# LLM 模型配置
VLLM_BASE_URL = os.getenv("VLLM_BASE_URL", "http://127.0.0.1:8081/v1")
LLM_API_KEY = os.getenv("LLM_API_KEY", "your-ai-api-key")
# llama.cpp Embedding 服务地址 (用于 Mem0 的向量化)
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1")
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "your-llamacpp-api-key")
# ========== 后端服务配置 ==========
# 数据库连接字符串
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
# 后端服务端口
BACKEND_PORT = int(os.getenv("BACKEND_PORT", "8079"))
# ========== 日志配置 ==========
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
DEBUG = os.getenv("DEBUG", "false").lower() == "true"
# ========== Reranker 服务配置 ==========
LLAMACPP_RERANKER_URL = os.getenv("LLAMACPP_RERANKER_URL", "http://127.0.0.1:8083")
# ========== 第三方 API 密钥 ==========
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY", "")
DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY", "")
LOG_LEVEL = _get_str("LOG_LEVEL")
DEBUG = _get_bool("DEBUG")

View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
LangGraph 图结构可视化脚本
快速查看节点和边的连接关系
运行方式python backend/app/graph/visualize_graph.py
"""
import sys
from pathlib import Path
from dotenv import load_dotenv
# 确定项目根目录Agent1 目录)
# 当前文件位置backend/app/graph/visualize_graph.py
# 向上 4 级到 Agent1
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
BACKEND_DIR = PROJECT_ROOT / "backend"
# 关键:把 backend 目录加入 sys.path这样才能找到 rag_core
# 注意:这只对直接运行脚本有效,对 -m 方式无效(因为 -m 方式在脚本运行前就导入了)
if str(BACKEND_DIR) not in sys.path:
sys.path.insert(0, str(BACKEND_DIR))
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
load_dotenv(PROJECT_ROOT / ".env")
from app.agent.service import AIAgentService
from app.config import DB_URI
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
import asyncio
async def visualize_graph():
    """Print the structure of every LangGraph graph built by the agent service.

    Opens an ``AsyncPostgresSaver`` from ``DB_URI``, initializes an
    ``AIAgentService`` with it, and for each entry in ``service.graphs`` prints
    the node list, the edge list, an ASCII rendering (best effort) and the
    Mermaid source.
    """
    # Header banner followed by the resolved paths.
    for banner_line in ("=" * 80, " LangGraph 图结构可视化", "=" * 80):
        print(banner_line)
    print(f"项目根目录: {PROJECT_ROOT}")
    print(f"Backend 目录: {BACKEND_DIR}")

    async with AsyncPostgresSaver.from_conn_string(DB_URI) as saver:
        await saver.setup()

        # Build the service while the checkpointer connection is open.
        print("\n正在初始化 Agent 服务...")
        service = AIAgentService(saver)
        await service.initialize()

        for model_name, compiled in service.graphs.items():
            print(f"\n{'=' * 80}")
            print(f" 模型: {model_name}")
            print(f"{'=' * 80}")

            # Drawable view of the compiled graph.
            topology = compiled.get_graph()

            # 1. Nodes
            print("\n[1] 节点列表:")
            print("-" * 80)
            for node_id, node in topology.nodes.items():
                print(f" - {node_id}: {node.name}")

            # 2. Edges
            print("\n[2] 边列表:")
            print("-" * 80)
            for edge in topology.edges:
                print(f" {edge.source} --> {edge.target}")

            # 3. ASCII art; the original comment notes this needs grandalf,
            # so a failure here is reported and does not stop the script.
            print("\n[3] ASCII 字符画:")
            print("-" * 80)
            try:
                ascii_art = topology.draw_ascii()
                print(ascii_art)
            except Exception as e:
                print(f"⚠️ ASCII 绘制失败: {e}")

            # 4. Mermaid source
            print("\n[4] Mermaid 源码 (可复制到 https://mermaid.live/):")
            print("-" * 80)
            mermaid_source = topology.draw_mermaid()
            print(mermaid_source)
if __name__ == "__main__":
asyncio.run(visualize_graph())

View File

@@ -7,14 +7,6 @@ import os
from .config import LOG_LEVEL, DEBUG
import logging
from typing import Any
from dotenv import load_dotenv
# 先加载环境变量
load_dotenv()
# 从环境变量读取日志级别,默认 INFO
# 根据环境变量控制是否显示详细调试信息
DEBUG_MODE = DEBUG

View File

@@ -37,7 +37,6 @@ RAG 检索与生成模块
from .retriever import (
create_base_retriever,
create_hybrid_retriever,
create_qdrant_client,
)
from .reranker import LLaMaCPPReranker
from .query_transform import MultiQueryGenerator
@@ -50,7 +49,6 @@ __all__ = [
# 检索器工厂函数
"create_base_retriever",
"create_hybrid_retriever",
"create_qdrant_client",
# 重排序器
"LLaMaCPPReranker",

View File

@@ -25,66 +25,25 @@ Qdrant 向量检索器模块
>>> docs = retriever.invoke("什么是 RAG")
"""
from typing import Optional, Dict, Any
from typing import Dict, Any
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from langchain_qdrant import QdrantVectorStore
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from rag_core import QDRANT_URL, QDRANT_API_KEY
from rag_core import QDRANT_URL, QDRANT_API_KEY, LlamaCppEmbedder
from rag_core.client import create_qdrant_client as create_core_qdrant_client
# 模块级常量
DEFAULT_SEARCH_K = 20
DEFAULT_SCORE_THRESHOLD = 0.3
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 30,
) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端。
优先使用传入参数,若未提供则回退到环境变量 QDRANT_URL 和 QDRANT_API_KEY。
Args:
url: Qdrant 服务地址,例如 "http://localhost:6333"
默认从环境变量 QDRANT_URL 读取。
api_key: API 密钥(若 Qdrant 启用了认证)。
默认从环境变量 QDRANT_API_KEY 读取。
timeout: 请求超时时间(秒),默认 30 秒。
Returns:
配置好的 QdrantClient 实例。
Raises:
ValueError: 如果 url 为空且环境变量也未设置。
"""
effective_url = url or QDRANT_URL
if not effective_url:
raise ValueError(
"Qdrant URL 未提供,请设置参数 url 或环境变量 QDRANT_URL"
)
effective_api_key = api_key or QDRANT_API_KEY
client_kwargs = {
"url": effective_url,
"timeout": timeout,
}
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
return QdrantClient(**client_kwargs)
def create_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
创建基础向量检索器(仅稠密向量检索)。
@@ -94,7 +53,6 @@ def create_base_retriever(
Args:
collection_name: Qdrant 集合名称(需预先创建并索引)。
embeddings: LangChain 兼容的嵌入模型实例。
search_kwargs: 搜索参数,可包含:
- k (int): 返回的文档数量,默认 20。
- score_threshold (float): 相似度阈值,仅返回高于此分数的文档。
@@ -108,6 +66,10 @@ def create_base_retriever(
Raises:
ValueError: 如果集合不存在或嵌入模型无效。
"""
# 嵌入模型
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 合并默认搜索参数
merged_search_kwargs = {"k": DEFAULT_SEARCH_K}
if search_kwargs:
@@ -115,7 +77,7 @@ def create_base_retriever(
# 创建或复用 Qdrant 客户端
if client is None:
client = create_qdrant_client()
client = create_core_qdrant_client()
# 验证集合是否存在(可选,便于提前发现问题)
try:
@@ -140,11 +102,10 @@ def create_base_retriever(
def create_hybrid_retriever(
collection_name: str,
embeddings: Embeddings,
dense_k: int = 10,
sparse_k: int = 10,
score_threshold: Optional[float] = DEFAULT_SCORE_THRESHOLD,
client: Optional[QdrantClient] = None,
score_threshold: float | None = DEFAULT_SCORE_THRESHOLD,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
创建混合检索器(稠密向量 + BM25 稀疏向量)。
@@ -157,7 +118,6 @@ def create_hybrid_retriever(
Args:
collection_name: Qdrant 集合名称。
embeddings: 嵌入模型(用于稠密向量)。
dense_k: 稠密向量检索返回数量,默认 10。
sparse_k: 稀疏向量检索返回数量,默认 10。
score_threshold: 相似度阈值,默认 0.3。
@@ -177,7 +137,6 @@ def create_hybrid_retriever(
# 复用基础检索器创建逻辑,只需调整搜索参数
return create_base_retriever(
collection_name=collection_name,
embeddings=embeddings,
search_kwargs=search_kwargs,
client=client,
)
@@ -186,9 +145,8 @@ def create_hybrid_retriever(
# 可选:提供异步友好的辅助函数
async def acreate_base_retriever(
collection_name: str,
embeddings: Embeddings,
search_kwargs: Optional[Dict[str, Any]] = None,
client: Optional[QdrantClient] = None,
search_kwargs: Dict[str, Any] | None = None,
client: QdrantClient | None = None,
) -> BaseRetriever:
"""
异步创建基础向量检索器(与同步版本功能相同)。
@@ -196,4 +154,4 @@ async def acreate_base_retriever(
适用于需要异步初始化的场景(例如在 FastAPI 启动事件中)。
"""
# 由于 QdrantVectorStore 初始化本身是同步的,这里直接调用同步版本即可
return create_base_retriever(collection_name, embeddings, search_kwargs, client)
return create_base_retriever(collection_name, search_kwargs, client)

View File

@@ -5,9 +5,17 @@ RAG Core - 公共 RAG 组件包
"""
from .embedders import LlamaCppEmbedder
from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY
from .vector_store import QdrantVectorStore
from .store import PostgresDocStore, create_docstore
from .retriever_factory import create_parent_retriever
from .config import (
QDRANT_URL,
QDRANT_API_KEY,
LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY,
DB_URI,
DOCSTORE_URI,
)
__all__ = [
@@ -15,6 +23,10 @@ __all__ = [
"QdrantVectorStore",
"QDRANT_URL",
"QDRANT_API_KEY",
"LLAMACPP_EMBEDDING_URL",
"LLAMACPP_API_KEY",
"DB_URI",
"DOCSTORE_URI",
"PostgresDocStore",
"create_docstore",
"create_parent_retriever",

View File

@@ -1,27 +1,30 @@
# rag_core/client.py
import os
from .config import QDRANT_URL, QDRANT_API_KEY
from typing import Optional
from qdrant_client import QdrantClient
def create_qdrant_client(timeout: int = 300) -> QdrantClient:
"""
创建并返回一个配置好的 Qdrant 客户端。
def create_qdrant_client(
url: Optional[str] = None,
api_key: Optional[str] = None,
timeout: int = 300, # 索引构建需要较长超时
) -> QdrantClient:
effective_url = url or QDRANT_URL
effective_api_key = api_key or QDRANT_API_KEY
Args:
timeout: 请求超时时间(秒),默认 300 秒(索引构建需要较长超时)。
if not effective_url:
Returns:
配置好的 QdrantClient 实例。
Raises:
ValueError: 如果 QDRANT_URL 未配置。
"""
if not QDRANT_URL:
raise ValueError("Qdrant URL 未配置")
client_kwargs = {
"url": effective_url,
"url": QDRANT_URL,
"timeout": timeout,
}
if effective_api_key:
client_kwargs["api_key"] = effective_api_key
if QDRANT_API_KEY:
client_kwargs["api_key"] = QDRANT_API_KEY
return QdrantClient(**client_kwargs)
return QdrantClient(**client_kwargs)

View File

@@ -1,24 +1,55 @@
"""
RAG Core 配置管理模块
集中管理所有环境变量配置项,避免散落在各个文件中
所有配置直接从环境变量读取,无默认值,避免配置混乱
需要类型转换的配置在此处理
"""
import os
import dotenv
dotenv.load_dotenv()
# ========== 向量数据库配置 ==========
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
# ========== 辅助函数:类型转换 ==========
def _get_str(key: str) -> str | None:
"""获取字符串配置"""
return os.getenv(key)
# ========== 嵌入服务配置 ==========
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "")
# ========== 文档存储配置 ==========
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable"
)
DOCSTORE_URI = os.getenv("DOCSTORE_URI", DB_URI)
def _get_int(key: str) -> int | None:
"""获取整数配置,自动转换"""
value = os.getenv(key)
if value is not None:
try:
return int(value)
except (ValueError, TypeError):
pass
return None
# ========== 向量数据库配置URL + API密钥 配对) ==========
QDRANT_URL = _get_str("QDRANT_URL")
QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
# ========== 嵌入服务配置URL + API密钥 配对) ==========
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
# ========== 文档存储配置(分离配置 + 完整URI ==========
# 分离配置(优先使用)
DB_HOST = _get_str("DB_HOST")
DB_PORT = _get_int("DB_PORT")
DB_USER = _get_str("DB_USER")
DB_PASSWORD = _get_str("DB_PASSWORD")
DB_NAME = _get_str("DB_NAME")
# 完整连接字符串(直接从环境变量读取)
DB_URI = _get_str("DB_URI")
# 文档存储 URI（直接从环境变量读取，默认同 DB_URI）
DOCSTORE_URI = _get_str("DOCSTORE_URI") or DB_URI
# ========== 其他配置 ==========
# 可以在此添加其他 RAG Core 专用的配置项
# 可以在此添加其他 RAG Core 专用的配置项

View File

@@ -5,22 +5,25 @@
import os
from .config import LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY
import httpx
from typing import List, Optional
from typing import List
from langchain_core.embeddings import Embeddings
class LlamaCppEmbedder:
"""通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。"""
def __init__(
self,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "Qwen3-Embedding-0.6B-Q8_0",
):
self.base_url = base_url or LLAMACPP_EMBEDDING_URL
self.api_key = api_key or LLAMACPP_API_KEY
def __init__(self, model: str = "Qwen3-Embedding-0.6B-Q8_0"):
"""
Args:
model: 嵌入模型名称,默认 "Qwen3-Embedding-0.6B-Q8_0"
"""
self.base_url = LLAMACPP_EMBEDDING_URL
self.api_key = LLAMACPP_API_KEY
self.model = model
print(f"初始化 base_url: { self.base_url}")
def as_langchain_embeddings(self) -> Embeddings:
"""创建 LangChain 兼容的嵌入实例。"""
@@ -30,7 +33,7 @@ class LlamaCppEmbedder:
"""嵌入一批文档。"""
return self._call_embedding_api(texts)
def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str) -> List[float]:
    """Embed a single query string.

    The text is sent as a one-element batch through ``_call_embedding_api``
    and the first (only) vector of the result is returned, so the return type
    is one flat vector, not a list of vectors.
    """
    return self._call_embedding_api([text])[0]
@@ -41,13 +44,14 @@ class LlamaCppEmbedder:
def _call_embedding_api(self, texts: List[str]) -> List[List[float]]:
"""直接调用 llama.cpp 嵌入 API。"""
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
base = self.base_url.rstrip("/")
if not base.endswith("/v1"):
base = base + "/v1"
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
payload = {
"input": texts,
@@ -70,6 +74,7 @@ class LlamaCppEmbedder:
else:
raise ValueError(f"未知的嵌入 API 响应格式: {data}")
class _LlamaCppLangchainAdapter(Embeddings):
"""将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。"""
@@ -79,5 +84,5 @@ class _LlamaCppLangchainAdapter(Embeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._embedder.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
return self._embedder.embed_query(text)
def embed_query(self, text: str) -> List[float]:
    """Embed a single query by delegating to the wrapped embedder.

    Returns whatever ``self._embedder.embed_query`` returns for ``text``; the
    annotation is one flat vector, as the LangChain ``Embeddings`` interface
    expects from ``embed_query``.
    """
    return self._embedder.embed_query(text)

View File

@@ -1,38 +1,46 @@
# rag_core/retriever_factory.py
# rag_core/retriever_factory.py
from langchain_core.embeddings import Embeddings
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_text_splitters import RecursiveCharacterTextSplitter
from typing import Optional
from langchain_core.embeddings import Embeddings
from langchain_core.stores import BaseStore
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langchain_classic.retrievers import ParentDocumentRetriever
from langchain_core.stores import BaseStore
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore
def create_parent_retriever(
collection_name: str = "rag_documents",
embeddings: Optional[Embeddings] = None,
parent_splitter: Optional[TextSplitter] = None,
child_splitter: Optional[TextSplitter] = None,
docstore: Optional[BaseStore] = None,
parent_splitter: TextSplitter | None = None,
child_splitter: TextSplitter | None = None,
docstore: BaseStore | None = None,
search_k: int = 5,
# 若未传入切分器,则用以下参数创建默认切分器
parent_chunk_size: int = 1000,
parent_chunk_overlap: int = 100,
child_chunk_size: int = 200,
child_chunk_overlap: int = 20,
) -> ParentDocumentRetriever:
"""
创建 ParentDocumentRetriever 实例。
Args:
collection_name: Qdrant 集合名称,默认 "rag_documents"
parent_splitter: 父文档切分器,默认 None使用默认参数创建
child_splitter: 子文档切分器,默认 None使用默认参数创建
docstore: 文档存储实例,默认 None使用默认参数创建
search_k: 检索时返回的结果数,默认 5
parent_chunk_size: 父文档块大小,默认 1000
parent_chunk_overlap: 父文档块重叠大小,默认 100
child_chunk_size: 子文档块大小,默认 200
child_chunk_overlap: 子文档块重叠大小,默认 20
Returns:
ParentDocumentRetriever 实例
"""
# 嵌入模型
if embeddings is None:
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
embedder = LlamaCppEmbedder()
embeddings = embedder.as_langchain_embeddings()
# 向量存储(只读)
vector_store = QdrantVectorStore(
collection_name=collection_name,
embeddings=embeddings,
)
vector_store = QdrantVectorStore(collection_name=collection_name)
# 切分器(若未提供则创建默认)
if parent_splitter is None:
@@ -48,7 +56,7 @@ def create_parent_retriever(
# 文档存储
if docstore is None:
docstore, _ = create_docstore() # 从环境变量读取连接
docstore, _ = create_docstore()
return ParentDocumentRetriever(
vectorstore=vector_store.get_langchain_vectorstore(),
@@ -56,4 +64,4 @@ def create_parent_retriever(
child_splitter=child_splitter,
parent_splitter=parent_splitter,
search_kwargs={"k": search_k},
)
)

View File

@@ -9,14 +9,13 @@
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs"
... )
"""
from .postgres import PostgresDocStore
from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI
from .factory import create_docstore, get_docstore_uri
__version__ = "2.0.0"
@@ -27,5 +26,4 @@ __all__ = [
# 工厂函数
"create_docstore",
"get_docstore_uri",
"DEFAULT_DB_URI",
]

View File

@@ -5,17 +5,14 @@
"""
import os
from ..config import DB_URI, DOCSTORE_URI
from ..config import DOCSTORE_URI
import logging
from typing import Optional, Tuple
from typing import Tuple
from langchain_core.stores import BaseStore
from .postgres import PostgresDocStore
logger = logging.getLogger(__name__)
# 默认连接字符串(从环境变量读取)
DEFAULT_DB_URI = DB_URI
logger = logging.getLogger(__name__)
def get_docstore_uri() -> str:
@@ -24,48 +21,36 @@ def get_docstore_uri() -> str:
def create_docstore(
store_type: str = "postgres",
connection_string: Optional[str] = None,
table_name: str = "parent_documents",
pool_config: Optional[dict] = None,
max_concurrency: Optional[int] = None
) -> Tuple[BaseStore, Optional[str]]:
pool_config: dict | None = None,
max_concurrency: int | None = None
) -> Tuple[BaseStore, str]:
"""
工厂函数,创建 PostgreSQL 文档存储。
Args:
store_type: 存储类型,目前仅支持 "postgres"(默认)
connection_string: PostgreSQL 连接字符串
table_name: PostgreSQL 表名默认parent_documents
pool_config: 连接池配置
max_concurrency: 最大并发操作数,如果为 None 则不限制
Returns:
元组 (存储实例, 连接字符串)
Raises:
ValueError: 不支持的存储类型
ImportError: 缺少必要的依赖
Example:
>>> # 创建 PostgreSQL 存储
>>> store, conn = create_docstore(
... connection_string="postgresql://user:pass@host:5432/db",
... table_name="parent_docs",
... max_concurrency=10
... )
"""
store_type = store_type.lower()
if store_type == "postgres":
conn_str = connection_string or get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
return store, conn_str
else:
raise ValueError(f"不支持的存储类型: {store_type}。目前仅支持: postgres")
conn_str = get_docstore_uri()
store = PostgresDocStore(
connection_string=conn_str,
table_name=table_name,
pool_config=pool_config,
max_concurrency=max_concurrency
)
return store, conn_str

View File

@@ -4,7 +4,6 @@ Qdrant 向量数据库包装器。
import logging
import os
from .config import QDRANT_URL, QDRANT_API_KEY
import time
from typing import List, Optional, Dict, Any
@@ -14,31 +13,28 @@ from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from httpx import RemoteProtocolError
from qdrant_client.http.exceptions import ResponseHandlingException
from .client import create_qdrant_client
from .embedders import LlamaCppEmbedder
logger = logging.getLogger(__name__)
class QdrantVectorStore:
"""Qdrant 向量数据库操作包装器。"""
def __init__(
self,
collection_name: str,
embeddings: Optional[Any] = None,
):
def __init__(self, collection_name: str):
"""
Args:
collection_name: Qdrant 集合名称。
"""
self.collection_name = collection_name
self._client: Optional[QdrantClient] = None
self._connection_attempts = 0
self._last_connection_time: Optional[float] = None
if embeddings is None:
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
else:
self.embeddings = embeddings
embedder = LlamaCppEmbedder()
self.embeddings = embedder.as_langchain_embeddings()
self.create_collection()
@@ -92,12 +88,10 @@ class QdrantVectorStore:
"client_initialized": self._client is not None,
}
def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False):
def create_collection(self, force_recreate: bool = False):
"""创建集合,设置合适的向量维度。"""
if vector_size is None:
from rag_core.embedders import LlamaCppEmbedder
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
embedder = LlamaCppEmbedder()
vector_size = embedder.get_embedding_dimension()
max_retries = 3
base_delay = 2
@@ -177,4 +171,4 @@ class QdrantVectorStore:
def get_qdrant_client(self):
"""返回原生 Qdrant 客户端(如需手动管理 collection"""
return self.get_client()
return self.get_client()

View File

@@ -3,39 +3,69 @@ FROM python:3.11-slim
WORKDIR /app
# =============================================================================
# 非敏感环境变量(固化在镜像中,无需通过 .env 配置
# 非敏感环境变量(固化在镜像中,通过 .env 覆盖
# =============================================================================
ENV PYTHONPATH=/app
# llama.cpp 服务配置(本地部署标准端口)
# =============================================================================
# llama.cpp 服务配置Docker 部署标准端口映射)
# =============================================================================
# 主 LLM 服务 - Docker host 端口 18000
ENV VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding 服务 - Docker host 端口 18001
ENV LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
ENV LLAMACPP_RERENT_URL=http://host.docker.internal:18002/v1
# Reranker 服务 - Docker host 端口 18002
ENV LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
# Mem0 记忆层配置
# =============================================================================
# 数据库 & 向量库配置(非敏感部分)
# =============================================================================
# PostgreSQL敏感信息通过 .env 注入)
ENV DB_HOST=115.190.121.151
ENV DB_PORT=5432
ENV DB_USER=postgres
ENV DB_NAME=langgraph_db
# Qdrant敏感信息通过 .env 注入)
ENV QDRANT_URL=http://115.190.121.151:6333
ENV QDRANT_COLLECTION_NAME=mem0_user_memories
# 应用行为配置(可通过 .env 覆盖)
# =============================================================================
# RAG 索引构建配置(非敏感)
# =============================================================================
ENV RAG_COLLECTION_NAME=rag_documents
ENV RAG_CHUNK_SIZE=500
ENV RAG_CHUNK_OVERLAP=50
ENV RAG_PARENT_CHUNK_SIZE=1000
ENV RAG_CHILD_CHUNK_SIZE=200
ENV RAG_PARENT_CHUNK_OVERLAP=100
ENV RAG_CHILD_CHUNK_OVERLAP=20
ENV RAG_STRATEGY=parent-child
ENV RAG_STORAGE_TYPE=postgres
# =============================================================================
# 应用行为配置
# =============================================================================
ENV MEMORY_SUMMARIZE_INTERVAL=10
ENV ENABLE_GRAPH_TRACE=false
# unstructured 库 spaCy 模型配置
ENV UNSTRUCTURED_LANGUAGE=eng
ENV SPACY_MODEL=en_core_web_sm
# 日志配置
# =============================================================================
# 日志配置(生产环境默认值)
# =============================================================================
ENV LOG_LEVEL=WARNING
ENV DEBUG=false
# =============================================================================
# 安装依赖
# =============================================================================
# 复制本地模型文件到镜像
COPY docker/models/*.whl /tmp/models/
# Copy the local model wheel directory into the image.
# NOTE: COPY is not executed by a shell, so "2>/dev/null || true" is parsed as
# extra source paths plus a destination named "true" and breaks the build.
# Copying the directory itself avoids a failing "*.whl" glob when no wheel is
# present; the directory docker/models/ must exist in the build context.
COPY docker/models/ /tmp/models/
# 安装
RUN pip install --no-cache-dir /tmp/models/*.whl && \
rm -rf /tmp/models
# Install the local model wheels only when at least one .whl file is present
# (testing for any file would make pip fail on an unmatched glob), then always
# remove the staging directory so it is cleaned up in both cases.
RUN if ls /tmp/models/*.whl >/dev/null 2>&1; then \
        pip install --no-cache-dir /tmp/models/*.whl; \
    fi && \
    rm -rf /tmp/models
# 设置 pip 国内镜像源
RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
@@ -54,7 +84,6 @@ COPY backend/ ./
# =============================================================================
EXPOSE 8079
# =============================================================================
# 启动命令
# =============================================================================

View File

@@ -1,38 +1,79 @@
services:
# ⭐ PostgreSQL 和 Qdrant 已迁移到远程服务器 (115.190.121.151)
# 不再需要在本地 Docker Compose 中运行这些服务
backend:
build:
context: .. # 构建上下文为项目根目录
dockerfile: docker/backend/Dockerfile
container_name: ai-backend
environment:
# ⭐ 敏感密钥:通过 .env 注入
- ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY}
- DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY}
- LLAMACPP_API_KEY=${LLAMACPP_API_KEY}
# =========================================================================
# ⭐ 敏感密钥配置 - 必须通过 .env 文件注入
# =========================================================================
- ZHIPUAI_API_KEY=${ZHIPUAI_API_KEY:?请在 .env 中配置 ZHIPUAI_API_KEY} # ⭐ 敏感密钥配置
- DEEPSEEK_API_KEY=${DEEPSEEK_API_KEY:?请在 .env 中配置 DEEPSEEK_API_KEY} # ⭐ 敏感密钥配置
- LLAMACPP_API_KEY=${LLAMACPP_API_KEY:?请在 .env 中配置 LLAMACPP_API_KEY} # ⭐ 敏感密钥配置
# ⭐ 日志调试配置:通过 .env 注入(支持灵活调整)
# =========================================================================
# PostgreSQL 数据库配置
# =========================================================================
- DB_HOST=115.190.121.151
- DB_PORT=5432
- DB_USER=postgres
- DB_PASSWORD=${DB_PASSWORD:?请在 .env 中配置 DB_PASSWORD} # ⭐ 敏感密钥配置
- DB_NAME=langgraph_db
# =========================================================================
# Qdrant 向量数据库配置URL + API密钥 配对)
# =========================================================================
- QDRANT_URL=http://115.190.121.151:6333
- QDRANT_API_KEY=${QDRANT_API_KEY:?请在 .env 中配置 QDRANT_API_KEY} # ⭐ 敏感密钥配置
- QDRANT_COLLECTION_NAME=mem0_user_memories
# =========================================================================
# llama.cpp 服务配置URL + API密钥 配对)
# =========================================================================
# 主 LLM 服务 (Gemma-4-E2B GGUF) - Docker host 端口 18000
- VLLM_BASE_URL=http://host.docker.internal:18000/v1
# Embedding 服务 (Qwen3-Embedding-0.6B GGUF) - Docker host 端口 18001
- LLAMACPP_EMBEDDING_URL=http://host.docker.internal:18001/v1
# Reranker 服务 (bge-reranker-v2-m3) - Docker host 端口 18002
- LLAMACPP_RERANKER_URL=http://host.docker.internal:18002/v1
# =========================================================================
# RAG 索引构建配置(非敏感)
# =========================================================================
- RAG_COLLECTION_NAME=rag_documents
- RAG_CHUNK_SIZE=500
- RAG_CHUNK_OVERLAP=50
- RAG_PARENT_CHUNK_SIZE=1000
- RAG_CHILD_CHUNK_SIZE=200
- RAG_PARENT_CHUNK_OVERLAP=100
- RAG_CHILD_CHUNK_OVERLAP=20
- RAG_STRATEGY=parent-child
- RAG_STORAGE_TYPE=postgres
# =========================================================================
# 日志调试配置(可通过 .env 覆盖)
# =========================================================================
- LOG_LEVEL=${LOG_LEVEL:-WARNING}
- DEBUG=${DEBUG:-false}
- ENABLE_GRAPH_TRACE=${ENABLE_GRAPH_TRACE:-false}
# ⭐ 基础设施配置:固化在 compose 文件中
# PostgreSQL 连接(远程服务器)
- DB_URI=postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable
# Qdrant 向量数据库(远程服务器)
- QDRANT_URL=http://115.190.121.151:6333
# =========================================================================
# 应用行为配置
# =========================================================================
- MEMORY_SUMMARIZE_INTERVAL=${MEMORY_SUMMARIZE_INTERVAL:-10}
# =========================================================================
# 前端通信地址Docker 内部网络)
# =========================================================================
- API_URL=http://backend:8079/chat
volumes:
- ../data/user_docs:/app/data/user_docs # 挂载文档目录
- ../logs:/app/logs
networks:
- ai-network
# ⭐ 移除对 postgres 和 qdrant 的依赖
# ⭐ 移除对 postgres 和 qdrant 的依赖(使用远程服务)
restart: unless-stopped
ports:
- "8079:8079"

View File

@@ -12,10 +12,10 @@ COPY frontend/requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# 复制前端代码
COPY frontend/src/ ./frontend/
COPY frontend/src/ ./src/
# 暴露端口
EXPOSE 8501
# 启动命令
CMD ["streamlit", "run", "frontend/frontend_main.py", "--server.port", "8501", "--server.address", "0.0.0.0", "--server.baseUrlPath", "/ai"]
CMD ["streamlit", "run", "src/frontend_main.py", "--server.port", "8501", "--server.address", "0.0.0.0", "--server.baseUrlPath", "/ai"]

30
frontend/run.py Normal file
View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
"""
Frontend launch wrapper.

Puts the project root and the backend directory on ``sys.path`` and then starts
Streamlit in-process on ``frontend/src/frontend_main.py``.
"""
import sys
import os

# frontend/run.py -> the project root is the parent of this file's directory.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
backend_dir = os.path.join(project_root, "backend")
frontend_main = os.path.join(project_root, "frontend", "src", "frontend_main.py")

# Inserted at position 0 in this order, so backend_dir ends up ahead of project_root.
sys.path.insert(0, project_root)
sys.path.insert(0, backend_dir)


def main():
    """Switch to the project root and hand control to the Streamlit CLI."""
    # Imported lazily so that importing this module does not pull in Streamlit.
    from streamlit.web import cli as stcli

    # Run from the project root so relative paths resolve the same way everywhere.
    os.chdir(project_root)

    # Built-in defaults first; any extra command-line arguments given to this
    # wrapper are forwarded to Streamlit instead of being discarded.
    sys.argv = [
        "streamlit", "run", frontend_main,
        "--server.port", "8501",
        "--server.address", "0.0.0.0",
        *sys.argv[1:],
    ]
    stcli.main()


# Only launch (and only touch cwd / sys.argv) when executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -1,4 +1,5 @@
"""
UI 组件模块
包含所有可复用的 Streamlit 组件
"""
"""

View File

@@ -1,6 +1,7 @@
"""
前端配置管理模块
集中管理所有配置项,支持环境变量覆盖
需要类型转换的配置在此处理
"""
import os
@@ -12,6 +13,31 @@ from dotenv import load_dotenv
load_dotenv()
# ========== 辅助函数:类型转换 ==========
def _get_str(key: str) -> str | None:
"""获取字符串配置"""
return os.getenv(key)
def _get_int(key: str, default: int = 0) -> int:
"""获取整数配置,自动转换"""
value = os.getenv(key)
if value is not None:
try:
return int(value)
except (ValueError, TypeError):
pass
return default
def _get_bool(key: str, default: bool = False) -> bool:
"""获取布尔配置,自动转换"""
value = os.getenv(key)
if value is not None:
return value.lower() in ("true", "1", "yes", "on")
return default
@dataclass
class FrontendConfig:
"""前端配置类 - 统一管理所有配置项"""
@@ -19,51 +45,55 @@ class FrontendConfig:
# ==================== API 配置 ====================
api_base: str = ""
# ==================== 页面配置 ====================
# ==================== 页面配置(固定值,无需环境变量) ====================
page_title: str = "AI 个人助手"
page_icon: str = "🤖"
layout: str = "wide"
# ==================== 模型配置 ====================
default_model: str = "local" # 更改为local作为默认模型
# ==================== 模型配置(固定值,无需环境变量) ====================
default_model: str = "local"
model_options: Optional[dict] = None
# ==================== 用户配置 ====================
# ==================== 用户配置(固定值,无需环境变量) ====================
default_user_id: str = "default_user"
# ==================== 历史记录配置 ====================
# ==================== 历史记录配置(固定值,无需环境变量) ====================
history_limit: int = 50
summary_max_length: int = 30
# ==================== 流式响应配置 ====================
# ==================== 流式响应配置(固定值,无需环境变量) ====================
stream_timeout: int = 120
# ==================== 日志配置 ====================
log_level: str = ""
debug: bool = False
def __post_init__(self):
"""初始化后处理 - 设置默认值和加载环境变量"""
if self.model_options is None:
self.model_options = {
"local": "本地 llama.cppGemma-4", # 本地模型作为第一个
"deepseek": "DeepSeek V3.2(在线)", # DeepSeek 作为中间
"zhipu": "智谱 GLM-4.7-Flash在线" # GLM-4.7 作为最后一个
"local": "本地 llama.cppGemma-4",
"deepseek": "DeepSeek V3.2(在线)",
"zhipu": "智谱 GLM-4.7-Flash在线"
}
# 从环境变量加载配置
# 从环境变量加载配置(优先级最高)
self._load_from_env()
def _load_from_env(self):
    """Load the API base URL and logging settings from environment variables.

    Precedence for ``api_base`` and ``log_level`` is: environment variable,
    then a non-empty value already set on the instance, then the built-in
    default. This keeps both fields from being left as empty strings when the
    corresponding variable is unset.
    """
    # API base: API_URL may point at the /chat endpoint. Drop only that
    # trailing path segment; a plain replace() would also remove "/chat"
    # occurring elsewhere in the URL.
    api_url = _get_str("API_URL") or self.api_base or "http://127.0.0.1:8079"
    self.api_base = api_url.rstrip("/").removesuffix("/chat").rstrip("/")

    # Logging
    log_level = _get_str("LOG_LEVEL") or self.log_level or "INFO"
    self.log_level = log_level.upper()
    self.debug = _get_bool("DEBUG", False)
# 全局配置实例(单例模式)
config = FrontendConfig()
config = FrontendConfig()

View File

@@ -6,18 +6,25 @@ AI Agent 前端主入口
import sys
import os
# 添加项目根目录到 Python 路径,支持绝对导入
# 现在的结构: frontend/src/frontend_main.py所以要获取 frontend/ 目录作为根
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# 添加当前目录到路径,确保智能导入能工作
src_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, src_dir)
import streamlit as st
# 使用相对导入
from .config import config
from .state import AppState
from .components.sidebar import render_sidebar
from .components.chat_area import render_chat_area
from .components.info_panel import render_info_panel
# 智能导入:作为 __main__ 被 Streamlit 运行时用绝对导入,否则用相对导入
if __name__ == '__main__':
from config import config
from state import AppState
from components.sidebar import render_sidebar
from components.chat_area import render_chat_area
from components.info_panel import render_info_panel
else:
from .config import config
from .state import AppState
from .components.sidebar import render_sidebar
from .components.chat_area import render_chat_area
from .components.info_panel import render_info_panel
# =============================================================================

View File

@@ -26,9 +26,23 @@ Offline RAG Indexer module.
from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
from .config import (
QDRANT_URL,
QDRANT_API_KEY,
LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY,
DB_URI,
DOCSTORE_URI,
RAG_OCR_LANGUAGES,
RAG_DOC_LANGUAGES,
)
# 从 rag_core 重新导出常用组件
from rag_core import (
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
from backend.rag_core import (
LlamaCppEmbedder,
QdrantVectorStore,
PostgresDocStore,
@@ -39,7 +53,7 @@ __version__ = "2.0.0"
__all__ = [
# 核心构建器与配置
"index_builder",
"IndexBuilder",
"IndexBuilderConfig",
"DocstoreConfig",
@@ -50,6 +64,16 @@ __all__ = [
"SplitterType",
"get_splitter",
# 配置
"QDRANT_URL",
"QDRANT_API_KEY",
"LLAMACPP_EMBEDDING_URL",
"LLAMACPP_API_KEY",
"DB_URI",
"DOCSTORE_URI",
"RAG_OCR_LANGUAGES",
"RAG_DOC_LANGUAGES",
# 嵌入与向量存储
"LlamaCppEmbedder",
"QdrantVectorStore",

View File

@@ -6,13 +6,24 @@ import asyncio
import logging
import sys
from pathlib import Path
from dotenv import load_dotenv
# 加载 .env 文件
load_dotenv()
# 添加项目根目录和 backend 目录到 Python 路径
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
from .index_builder import IndexBuilder, IndexBuilderConfig
from .splitters import SplitterType
# 导入方式:条件导入,支持作为脚本运行和作为包导入
if __name__ == "__main__":
# 作为脚本直接运行时使用绝对导入
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
from rag_indexer.splitters import SplitterType
else:
# 作为包导入时使用相对导入
from .index_builder import IndexBuilder, IndexBuilderConfig
from .splitters import SplitterType
logging.basicConfig(
level=logging.INFO,

View File

@@ -1,32 +1,71 @@
"""
RAG Indexer 配置管理模块
集中管理所有环境变量配置项,避免散落在各个文件中
所有配置直接从环境变量读取,无默认值,避免配置混乱
需要类型转换的配置在此处理
"""
import os
# 尝试从 rag_core 导入配置(如果可用)
try:
from rag_core.config import (
QDRANT_URL,
QDRANT_API_KEY,
LLAMACPP_EMBEDDING_URL,
LLAMACPP_API_KEY,
DB_URI,
DOCSTORE_URI,
)
except ImportError:
# 如果 rag_core 不可用,则直接读取环境变量
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082")
LLAMACPP_API_KEY = os.getenv("LLAMACPP_API_KEY", "")
DB_URI = os.getenv(
"DB_URI",
"postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable"
)
DOCSTORE_URI = os.getenv("DOCSTORE_URI", DB_URI)
# ========== 辅助函数:类型转换 ==========
def _get_str(key: str) -> str | None:
"""获取字符串配置"""
return os.getenv(key)
def _get_int(key: str) -> int | None:
"""获取整数配置,自动转换"""
value = os.getenv(key)
if value is not None:
try:
return int(value)
except (ValueError, TypeError):
pass
return None
def _get_list_str(key: str, default: list[str] | None = None) -> list[str]:
"""获取字符串列表配置,从逗号分隔的字符串解析"""
value = os.getenv(key)
if value is not None:
return [item.strip() for item in value.split(",") if item.strip()]
return default or []
# ========== 向量数据库配置URL + API密钥 配对) ==========
QDRANT_URL = _get_str("QDRANT_URL")
QDRANT_API_KEY = _get_str("QDRANT_API_KEY")
# ========== 嵌入服务配置URL + API密钥 配对) ==========
LLAMACPP_EMBEDDING_URL = _get_str("LLAMACPP_EMBEDDING_URL")
LLAMACPP_API_KEY = _get_str("LLAMACPP_API_KEY")
# ========== 文档存储配置(分离配置 + 完整URI ==========
# 分离配置(优先使用)
DB_HOST = _get_str("DB_HOST")
DB_PORT = _get_int("DB_PORT")
DB_USER = _get_str("DB_USER")
DB_PASSWORD = _get_str("DB_PASSWORD")
DB_NAME = _get_str("DB_NAME")
# 完整连接字符串(直接从环境变量读取)
DB_URI = _get_str("DB_URI")
# 文档存储 URI直接从环境变量读取默认同 DB_URI
DOCSTORE_URI = _get_str("DOCSTORE_URI") or DB_URI
# ========== 文档加载器配置unstructured 库) ==========
# OCR 语言列表(逗号分隔,如 "chi_sim,eng"
RAG_OCR_LANGUAGES = _get_list_str("RAG_OCR_LANGUAGES", ["chi_sim", "eng"])
# 文档主语言列表(逗号分隔,如 "zh"
RAG_DOC_LANGUAGES = _get_list_str("RAG_DOC_LANGUAGES", ["zh"])
# ========== 索引器专用配置 ==========
# 默认索引存储路径
INDEX_STORAGE_PATH = os.getenv("INDEX_STORAGE_PATH", "./index_storage")
INDEX_STORAGE_PATH = _get_str("INDEX_STORAGE_PATH")

View File

@@ -23,6 +23,12 @@ from qdrant_client.http.exceptions import ResponseHandlingException
from .loaders import DocumentLoader
from .splitters import SplitterType, get_splitter
# 从 rag_core 导入
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))
from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever
logger = logging.getLogger(__name__)
@@ -31,11 +37,10 @@ logger = logging.getLogger(__name__)
@dataclass
class DocstoreConfig:
"""文档存储配置(用于父块存储)。"""
connection_string: Optional[str] = None
pool_config: Optional[Dict[str, Any]] = None
max_concurrency: Optional[int] = None
pool_config: Dict[str, Any] | None = None
max_concurrency: int | None = None
# 若要从外部注入已创建好的 docstore可直接设置此字段
instance: Optional[BaseStore] = None
instance: BaseStore | None = None
@dataclass
class IndexBuilderConfig:
@@ -89,7 +94,6 @@ class IndexBuilder:
# 初始化向量存储
self.vector_store = QdrantVectorStore(
collection_name=config.collection_name,
embeddings=self.embeddings,
)
# 根据切分类型初始化相关组件
@@ -141,7 +145,6 @@ class IndexBuilder:
# 使用工厂函数创建检索器,避免重复代码
self.retriever = create_parent_retriever(
collection_name=cfg.collection_name,
embeddings=self.embeddings,
parent_splitter=self.parent_splitter,
child_splitter=self.child_splitter,
docstore=self.docstore,
@@ -158,7 +161,6 @@ class IndexBuilder:
# 使用 create_docstore 创建 PostgreSQL 存储
docstore, conn_info = create_docstore(
connection_string=cfg.connection_string,
pool_config=cfg.pool_config,
max_concurrency=cfg.max_concurrency,
)

View File

@@ -11,6 +11,9 @@ from langchain_core.documents import Document
from unstructured.documents.elements import Element
from unstructured.partition.auto import partition
# 相对导入配置
from .config import RAG_OCR_LANGUAGES, RAG_DOC_LANGUAGES
logger = logging.getLogger(__name__)
# 模块加载时设置一次环境变量,避免重复设置
@@ -47,8 +50,8 @@ class DocumentLoader:
"""
self.extract_images = extract_images
self.strategy = strategy
self.ocr_languages = ocr_languages or ["chi_sim", "eng"]
self.languages = languages or ["zh"]
self.ocr_languages = ocr_languages or RAG_OCR_LANGUAGES
self.languages = languages or RAG_DOC_LANGUAGES
self.include_page_breaks = include_page_breaks
self.pdf_infer_table_structure = pdf_infer_table_structure
self.partition_kwargs = partition_kwargs or {}

View File

@@ -1,83 +0,0 @@
#!/usr/bin/env python3
"""
测试重构后的 IndexBuilder 和 RAGRetriever
"""
import asyncio
import os
import sys
# 添加项目根目录到 Python 路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from ..index_builder import IndexBuilder
from ..splitters import SplitterType
async def test_index_builder():
"""测试索引构建功能"""
print("测试索引构建功能...")
# 创建 IndexBuilder 实例
builder = IndexBuilder(
collection_name="test_collection",
splitter_type=SplitterType.PARENT_CHILD,
parent_chunk_size=1000,
child_chunk_size=200
)
# 测试文档路径
test_file = os.path.join(os.path.dirname(__file__), "..", "data", "corpus", "三国演义.txt")
if os.path.exists(test_file):
# 构建索引
print(f"正在为文件 {test_file} 构建索引...")
processed = await builder.build_from_file(test_file)
print(f"索引构建完成,处理了 {processed} 个文档")
# 获取集合信息
info = builder.get_collection_info()
print(f"集合信息: {info}")
else:
print(f"测试文件不存在: {test_file}")
# 测试搜索功能
print("\n测试搜索功能...")
try:
results = builder.search("吕布", k=3)
print(f"搜索结果数量: {len(results)}")
for i, result in enumerate(results):
print(f"\n结果 {i+1}:")
print(f"内容: {result.page_content[:100]}...")
except Exception as e:
print(f"搜索测试失败: {e}")
# 测试带父块上下文的搜索
print("\n测试带父块上下文的搜索...")
try:
results = await builder.search_with_parent_context("吕布", k=3)
print(f"搜索结果数量: {len(results)}")
for i, result in enumerate(results):
print(f"\n结果 {i+1}:")
print(f"内容: {result.page_content[:100]}...")
except Exception as e:
print(f"带父块上下文的搜索测试失败: {e}")
# 测试统一检索接口
print("\n测试统一检索接口...")
try:
# 返回父块
results_parent = await builder.retrieve("吕布", return_parent=True)
print(f"返回父块的结果数量: {len(results_parent)}")
# 返回子块
results_child = await builder.retrieve("吕布", return_parent=False)
print(f"返回子块的结果数量: {len(results_child)}")
except Exception as e:
print(f"统一检索接口测试失败: {e}")
# 关闭资源
builder.close()
print("\n测试完成")
if __name__ == "__main__":
asyncio.run(test_index_builder())

View File

@@ -1,188 +0,0 @@
"""
验证 RAG 索引完整性。
检查 Qdrant 向量库、PostgreSQL 文档存储及检索功能。
"""
import asyncio
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from dotenv import load_dotenv
load_dotenv()
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
DB_URI = os.getenv("DB_URI", "postgresql://postgres:huang1998@115.190.121.151:5432/langgraph_db?sslmode=disable")
COLLECTION_NAME = "rag_documents"
TABLE_NAME = "parent_documents"
def check_qdrant():
"""检查 Qdrant 向量库。"""
from qdrant_client import QdrantClient
print("=" * 60)
print("Qdrant 向量库")
print("=" * 60)
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
# 集合列表
collections = client.get_collections().collections
print(f"\n集合数: {len(collections)}")
for c in collections:
print(f" - {c.name}")
# 目标集合信息
if not any(c.name == COLLECTION_NAME for c in collections):
print(f"\n集合 '{COLLECTION_NAME}' 不存在")
return
info = client.get_collection(COLLECTION_NAME)
print(f"\n集合 '{COLLECTION_NAME}':")
print(f" 状态: {info.status}")
print(f" 向量数: {info.points_count}")
vectors_config = info.config.params.vectors
if isinstance(vectors_config, dict):
for name, vc in vectors_config.items():
print(f" 向量 '{name}': 维度={vc.size}, 距离={vc.distance}")
else:
print(f" 向量维度: {vectors_config.size}")
# 抽样查看
print(f"\n前 3 个向量:")
points = client.scroll(
collection_name=COLLECTION_NAME,
limit=3,
with_payload=True,
with_vectors=False
)
for i, point in enumerate(points[0]):
print(f"\n {i+1}. ID: {point.id}")
payload = point.payload or {}
print(f" 内容: {payload.get('page_content', '')[:100]}...")
async def check_postgres():
"""检查 PostgreSQL 文档存储。"""
import asyncpg
print("\n" + "=" * 60)
print("PostgreSQL 文档存储")
print("=" * 60)
conn = await asyncpg.connect(dsn=DB_URI)
try:
# 表是否存在
tables = await conn.fetch(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'"
)
table_names = [t['table_name'] for t in tables]
if TABLE_NAME not in table_names:
print(f"\n'{TABLE_NAME}' 不存在")
return
# 统计
count = await conn.fetchval(f"SELECT COUNT(*) FROM {TABLE_NAME}")
print(f"\n'{TABLE_NAME}': {count} 条记录")
# 抽样
print(f"\n前 3 个文档:")
rows = await conn.fetch(
f"SELECT key, value FROM {TABLE_NAME} ORDER BY key LIMIT 3"
)
for i, row in enumerate(rows):
print(f"\n {i+1}. Key: {row['key']}")
val = row['value']
if isinstance(val, dict) and 'page_content' in val:
print(f" 内容: {val['page_content'][:100]}...")
# Key 前缀分布
key_prefixes = await conn.fetch(
f"""
SELECT
CASE
WHEN key LIKE '%:%' THEN split_part(key, ':', 1)
ELSE 'no_prefix'
END AS prefix,
COUNT(*) AS cnt
FROM {TABLE_NAME}
GROUP BY prefix
ORDER BY cnt DESC
LIMIT 10
"""
)
print(f"\nKey 前缀分布:")
for row in key_prefixes:
print(f" {row['prefix']}: {row['cnt']}")
finally:
await conn.close()
async def test_search():
"""测试检索功能。"""
from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig
from rag_indexer.splitters import SplitterType
print("\n" + "=" * 60)
print("检索测试")
print("=" * 60)
# 使用配置对象初始化(与默认构建方式一致)
config = IndexBuilderConfig(
collection_name=COLLECTION_NAME,
splitter_type=SplitterType.PARENT_CHILD,
)
builder = IndexBuilder(config)
# 确保检索器已初始化
if builder.retriever is None:
print("错误: 检索器未初始化,请检查切分策略")
return
query = input("\n查询 (回车使用默认): ").strip() or "你好"
print(f"\n查询: {query}")
# 标准检索(返回父块,因为 ParentDocumentRetriever 默认返回父块)
print("\n--- 标准检索 (返回父块) ---")
results = await builder.retriever.ainvoke(query)
for i, doc in enumerate(results):
content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
print(f"\n {i+1}. {content}...")
if hasattr(doc, 'metadata'):
source = doc.metadata.get('source', '')
if source:
print(f" 来源: {source}")
# 若需要仅返回子块,可以临时修改检索器的 search_type
# 注意ParentDocumentRetriever 的 search_type 默认为 "similarity"
print("\n--- 检索子块 (通过修改检索器参数) ---")
# 创建一个新的检索器副本,设置为返回子块
# 简单起见,直接调用 vectorstore 进行相似度搜索获取子块
vectorstore = builder.vector_store.get_langchain_vectorstore()
sub_results = await vectorstore.asimilarity_search(query, k=3)
for i, doc in enumerate(sub_results):
content = doc.page_content[:200] if hasattr(doc, 'page_content') else str(doc)[:200]
print(f"\n {i+1}. {content}...")
if hasattr(doc, 'metadata'):
parent_id = doc.metadata.get('parent_id', '')
if parent_id:
print(f" 父块 ID: {parent_id}")
async def main():
check_qdrant()
await check_postgres()
await test_search()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -290,9 +290,9 @@ start_backend() {
source .env 2>/dev/null || true
set +a
export PYTHONPATH="$PROJECT_DIR"
export PYTHONPATH="$PROJECT_DIR:$PROJECT_DIR/backend"
export BACKEND_PORT=8079
python app/backend.py &
python backend/app/backend.py &
BACKEND_PID=$!
echo -e "${GREEN}✓ 后端服务已启动 (PID: $BACKEND_PID)${NC}"
sleep 2
@@ -307,7 +307,7 @@ start_frontend() {
source .env 2>/dev/null || true
set +a
export PYTHONPATH="$PROJECT_DIR"
export PYTHONPATH="$PROJECT_DIR:$PROJECT_DIR/backend"
streamlit run frontend/src/frontend_main.py &
FRONTEND_PID=$!
echo -e "${GREEN}✓ 前端服务已启动 (PID: $FRONTEND_PID)${NC}"

View File

@@ -6,20 +6,22 @@
import asyncio
import os
from .config import DB_URI
import sys
import uuid
from dotenv import load_dotenv
# 添加项目根目录到 Python 路径 (现在文件在 backend/app/ 下backend 就是根)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
# 添加项目根目录和 backend 目录到 Python 路径
project_root = os.path.join(os.path.dirname(__file__), "..")
backend_dir = os.path.join(project_root, "backend")
sys.path.insert(0, project_root)
sys.path.insert(0, backend_dir)
load_dotenv()
from backend.app.config import DB_URI
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from ..agent import AIAgentService
from ..agent.history import ThreadHistoryService
from ..logger import info, warning, error
from backend.app.agent.service import AIAgentService
from backend.app.agent.history import ThreadHistoryService
from backend.app.logger import info, warning, error
# PostgreSQL 连接字符串

View File

@@ -5,10 +5,15 @@ import sys
import numpy as np
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from backend.rag_core import LlamaCppEmbedder
# 添加项目根目录和 backend 目录到 Python 路径
project_root = os.path.join(os.path.dirname(__file__), "..")
backend_dir = os.path.join(project_root, "backend")
sys.path.insert(0, project_root)
sys.path.insert(0, backend_dir)
load_dotenv()
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
from rag_core import LlamaCppEmbedder
QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

69
test/test_frontend.py Normal file
View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
"""
Quick frontend smoke test.

Verifies that the frontend modules can be imported.
"""
import sys
import os

# This file lives in test/, so the project root is the parent of its directory
# (using only one dirname() would point at test/ itself).
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
frontend_src = os.path.join(project_root, "frontend", "src")
backend_dir = os.path.join(project_root, "backend")
sys.path.insert(0, project_root)
sys.path.insert(0, frontend_src)
sys.path.insert(0, backend_dir)

# Number of failed import checks; drives the final summary and the exit code.
failures = 0

print("=" * 60)
print("前端导入测试")
print("=" * 60)

# Check 1: import the frontend entry module (fatal on failure).
print("\n[测试 1] 直接导入前端模块...")
try:
    from frontend.src.frontend_main import main
    print("✅ frontend_main 导入成功")
except Exception as e:
    print(f"❌ 导入失败: {e}")
    sys.exit(1)

# Check 2: import the configuration object.
print("\n[测试 2] 导入配置...")
try:
    from config import config
    print(f"✅ config 导入成功: page_title={config.page_title}")
except Exception as e:
    failures += 1
    print(f"❌ 导入失败: {e}")

# Check 3: import the state management module.
print("\n[测试 3] 导入状态管理...")
try:
    from state import AppState
    print("✅ AppState 导入成功")
except Exception as e:
    failures += 1
    print(f"❌ 导入失败: {e}")

# Check 4: import the API client.
print("\n[测试 4] 导入 API 客户端...")
try:
    from api_client import api_client
    print("✅ api_client 导入成功")
except Exception as e:
    failures += 1
    print(f"❌ 导入失败: {e}")

# Check 5: import the UI components.
print("\n[测试 5] 导入组件...")
try:
    from components.sidebar import render_sidebar
    from components.chat_area import render_chat_area
    from components.info_panel import render_info_panel
    print("✅ 所有组件导入成功")
except Exception as e:
    failures += 1
    print(f"❌ 导入失败: {e}")

print("\n" + "=" * 60)
if failures:
    # Do not announce success (or exit 0) when any check above failed.
    print(f"❌ {failures} 项前端导入测试失败")
    print("=" * 60)
    sys.exit(1)
print("🎉 所有前端导入测试通过!")
print("=" * 60)
print("\n现在可以使用 ./scripts/start.sh both 启动完整服务")

View File

@@ -18,18 +18,14 @@ from dotenv import load_dotenv
load_dotenv()
# 添加项目根目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from pydantic import SecretStr
from langchain_openai import ChatOpenAI
from rag_indexer.index_builder import IndexBuilderConfig
from rag_indexer.splitters import SplitterType
from .pipeline import RAGPipeline
from .tools import create_rag_tool_sync
from pydantic import SecretStr
# 使用本地 LLM通过 OpenAI 兼容接口)
from langchain_openai import ChatOpenAI
from rag_core.retriever_factory import create_parent_retriever
load_dotenv()
from backend.app.rag.pipeline import RAGPipeline
from backend.app.rag.tools import create_rag_tool_sync
from backend.rag_core.retriever_factory import create_parent_retriever
def create_llm():
"""创建本地 vLLM 服务 LLM"""
@@ -113,7 +109,7 @@ async def demonstrate_tool_creation():
collection_name="rag_documents",
splitter_type=SplitterType.PARENT_CHILD,
)
retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
retriever = create_parent_retriever(collection_name="rag_documents", search_k=5)
# 2. 创建 LLM
llm = create_llm()

View File

@@ -0,0 +1,49 @@
#!/usr/bin/env python3
"""
Smoke test for the refactored IndexBuilder.
"""
import asyncio
import os
import sys

# Add the project root (parent of this file's directory) to the Python path.
project_root = os.path.join(os.path.dirname(__file__), "..")
sys.path.insert(0, project_root)

from rag_indexer.index_builder import IndexBuilder
from rag_indexer.splitters import SplitterType


async def test_index_builder():
    """Build an index from a sample document and print the collection info."""
    print("测试索引构建功能...")

    # NOTE(review): other callers appear to construct IndexBuilder from an
    # IndexBuilderConfig object — confirm that keyword construction is supported.
    builder = IndexBuilder(
        collection_name="test_collection",
        splitter_type=SplitterType.PARENT_CHILD,
        parent_chunk_size=1000,
        child_chunk_size=200
    )

    try:
        # Sample document path, relative to the project root.
        test_file = os.path.join(os.path.dirname(__file__), "..", "data", "user_docs", "a.txt")

        if os.path.exists(test_file):
            # Build the index.
            print(f"正在为文件 {test_file} 构建索引...")
            processed = await builder.build_from_file(test_file)
            print(f"索引构建完成,处理了 {processed} 个文档")

            # Report collection info.
            info = builder.get_collection_info()
            print(f"集合信息: {info}")
        else:
            print(f"测试文件不存在: {test_file}")
    finally:
        # Release resources even when indexing raises.
        builder.close()

    print("\n测试完成")


if __name__ == "__main__":
    asyncio.run(test_index_builder())