From 3143e0e4e6f3269a27a348e5b5643c30a84338b5 Mon Sep 17 00:00:00 2001 From: root <953994191@qq.com> Date: Mon, 20 Apr 2026 15:55:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=BC=95=E7=94=A8=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E4=BF=AE=E6=94=B9=E9=95=BF=E6=9C=9F=E8=AE=B0?= =?UTF-8?q?=E5=BF=86bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/__init__.py | 4 +- app/agent/__init__.py | 7 + app/agent/history.py | 4 +- app/agent/prompts.py | 3 +- app/agent/{agent.py => service.py} | 20 +- app/backend.py | 16 +- app/config.py | 1 + app/graph/__init__.py | 8 + app/graph/graph_builder.py | 14 +- app/graph/graph_tools.py | 8 - app/graph/retrieve_memory.py | 4 +- app/graph/state.py | 4 +- app/graph_builder.py | 79 -------- app/memory/mem0_client.py | 10 +- app/nodes/finalize.py | 4 +- app/nodes/llm_call.py | 10 +- app/nodes/memory_trigger.py | 38 ++++ app/nodes/summarize.py | 4 +- app/nodes/tool_call.py | 4 +- app/rag/__init__.py | 15 +- app/rag/fusion.py | 2 +- app/rag/pipeline.py | 10 +- app/rag/reranker.py | 3 +- app/rag/test.py | 21 +- app/rag/tools.py | 4 +- app/test_backend.py | 307 +++++++++++++++++++++++++++++ frontend/__init__.py | 2 +- frontend/components/sidebar.py | 12 -- frontend/state.py | 2 +- rag_core/__init__.py | 8 +- rag_core/embedders.py | 5 +- rag_core/retriever_factory.py | 10 +- rag_core/store/__init__.py | 4 +- rag_core/store/factory.py | 2 +- rag_core/store/postgres.py | 5 +- rag_core/vector_store.py | 6 +- rag_indexer/__init__.py | 6 +- rag_indexer/index_builder.py | 11 +- scripts/start.sh | 13 +- 39 files changed, 444 insertions(+), 246 deletions(-) create mode 100644 app/agent/__init__.py rename app/agent/{agent.py => service.py} (90%) create mode 100644 app/graph/__init__.py delete mode 100644 app/graph_builder.py create mode 100644 app/nodes/memory_trigger.py create mode 100644 app/test_backend.py diff --git a/app/__init__.py b/app/__init__.py index 0df91f3..2d07ab9 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -2,7 +2,7 @@ AI Agent 应用模块 """ -from .agent import AIAgentService -from .graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from app.agent import AIAgentService +from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME __all__ = ["AIAgentService", "AVAILABLE_TOOLS", "TOOLS_BY_NAME"] diff --git a/app/agent/__init__.py b/app/agent/__init__.py new file mode 100644 index 0000000..055f494 --- /dev/null +++ b/app/agent/__init__.py @@ -0,0 +1,7 @@ +""" +Agent 子模块 +""" + +from app.agent.service import AIAgentService + +__all__ = ["AIAgentService"] diff --git a/app/agent/history.py b/app/agent/history.py index d814aad..09f7124 100644 --- a/app/agent/history.py +++ b/app/agent/history.py @@ -3,11 +3,9 @@ 利用 LangGraph 的 checkpointer 获取对话历史和摘要 """ -from typing import List, Dict, Any, Optional -import logging +from typing import List, Dict, Any from app.logger import error # 保持兼容,或者替换为 logger - class ThreadHistoryService: """线程历史查询服务""" diff --git a/app/agent/prompts.py b/app/agent/prompts.py index 990e634..8b05050 100644 --- a/app/agent/prompts.py +++ b/app/agent/prompts.py @@ -1,6 +1,5 @@ # app/prompts.py from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain_core.tools import BaseTool def create_system_prompt(tools: list = None) -> ChatPromptTemplate: """ @@ -11,7 +10,7 @@ def create_system_prompt(tools: list = None) -> ChatPromptTemplate: tool_descs = [] for tool in tools: # 提取工具名称和描述的第一行 - name = getattr(tool, 'name', tool.__name__) + name = getattr(tool, 'name', None) or getattr(tool, '__name__', 'unknown_tool') desc = (tool.description or "").split('\n')[0] tool_descs.append(f"- {name}: {desc}") tools_section = "\n".join(tool_descs) diff --git a/app/agent/agent.py b/app/agent/service.py similarity index 90% rename from app/agent/agent.py rename to app/agent/service.py index cefb43f..1d6cc14 100644 --- a/app/agent/agent.py +++ b/app/agent/service.py @@ -3,27 +3,17 @@ AI Agent 服务类 - 支持多模型动态切换 接收外部传入的 checkpointer,不负责管理连接生命周期 """ -import os import json from dotenv import load_dotenv -from langchain_community.chat_models import ChatZhipuAI -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from pydantic import SecretStr -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - # 本地模块 -from app.graph_builder import GraphBuilder, GraphContext -from app.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME -from app.rag import RAGPipeline -from app.rag.tools import create_rag_tool_sync -from rag_core import create_parent_retriever -from app.llm_factory import LLMFactory -from app.rag_initializer import init_rag_tool -from app.logger import debug, info, warning, error +from app.graph.graph_builder import GraphBuilder, GraphContext +from app.graph.graph_tools import AVAILABLE_TOOLS, TOOLS_BY_NAME +from app.agent.llm_factory import LLMFactory +from app.agent.rag_initializer import init_rag_tool +from app.logger import info, warning load_dotenv() - class AIAgentService: def __init__(self, checkpointer): self.checkpointer = checkpointer diff --git a/app/backend.py b/app/backend.py index 20da307..b60269e 100644 --- a/app/backend.py +++ b/app/backend.py @@ -16,7 +16,7 @@ from pydantic import BaseModel from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from app.agent import AIAgentService from app.history import ThreadHistoryService -from app.logger import debug, info, warning, error +from app.logger import info, error # 加载 .env 文件 load_dotenv() @@ -28,7 +28,6 @@ DB_URI = os.getenv( "postgresql://postgres:huang1998@ai-postgres:5432/langgraph_db?sslmode=disable" ) - @asynccontextmanager async def lifespan(app: FastAPI): """应用生命周期管理:创建并注入全局服务""" @@ -53,7 +52,6 @@ async def lifespan(app: FastAPI): # 5. 关闭时自动清理数据库连接(async with 负责) info("🛑 应用关闭,数据库连接池已释放") - app = FastAPI(lifespan=lifespan) # CORS 中间件(允许前端跨域) @@ -65,14 +63,12 @@ app.add_middleware( allow_headers=["*"], ) - # ========== 健康检查端点 ========== @app.get("/health") async def health_check(): """健康检查端点,用于 Docker 和 CI/CD 监控""" return {"status": "ok", "service": "ai-agent-backend"} - # ========== Pydantic 模型 ========== class ChatRequest(BaseModel): message: str @@ -80,7 +76,6 @@ class ChatRequest(BaseModel): model: str = "zhipu" user_id: str = "default_user" - class ChatResponse(BaseModel): reply: str thread_id: str @@ -90,18 +85,15 @@ class ChatResponse(BaseModel): total_tokens: int = 0 elapsed_time: float = 0.0 - # ========== 依赖注入函数 ========== def get_agent_service(request: Request) -> AIAgentService: """从 app.state 中获取全局 AIAgentService 实例""" return request.app.state.agent_service - def get_history_service(request: Request) -> ThreadHistoryService: """从 app.state 中获取全局 ThreadHistoryService 实例""" return request.app.state.history_service - # ========== HTTP 端点 ========== @app.post("/chat", response_model=ChatResponse) async def chat_endpoint( @@ -135,7 +127,6 @@ async def chat_endpoint( elapsed_time=elapsed_time ) - # ========== 历史查询接口 ========== @app.get("/threads") async def list_threads( @@ -147,7 +138,6 @@ async def list_threads( threads = await history_service.get_user_threads(user_id, limit) return {"threads": threads} - @app.get("/thread/{thread_id}/messages") async def get_thread_messages( thread_id: str, @@ -158,7 +148,6 @@ async def get_thread_messages( messages = await history_service.get_thread_messages(thread_id) return {"messages": messages} - @app.get("/thread/{thread_id}/summary") async def get_thread_summary( thread_id: str, @@ -169,7 +158,6 @@ async def get_thread_summary( summary = await history_service.get_thread_summary(thread_id) return summary - # ========== 流式对话接口 ========== @app.post("/chat/stream") async def chat_stream_endpoint( @@ -204,7 +192,6 @@ async def chat_stream_endpoint( } ) - # ========== WebSocket 端点(可选) ========== @app.websocket("/ws") async def websocket_endpoint( @@ -228,7 +215,6 @@ async def websocket_endpoint( except WebSocketDisconnect: pass - if __name__ == "__main__": import uvicorn # 使用环境变量或默认端口 8079(避免与 llama.cpp 的 8081 端口冲突) diff --git a/app/config.py b/app/config.py index d35e467..ad70ab2 100644 --- a/app/config.py +++ b/app/config.py @@ -18,6 +18,7 @@ MEMORY_SUMMARIZE_INTERVAL = int(os.getenv("MEMORY_SUMMARIZE_INTERVAL", "10")) # Qdrant 向量数据库地址 QDRANT_URL = os.getenv("QDRANT_URL", "http://127.0.0.1:6333") QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "mem0_user_memories") +QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "your-qdrant-api-key") # llama.cpp Embedding 服务地址 (用于 Mem0 的向量化) LLAMACPP_EMBEDDING_URL = os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082/v1") diff --git a/app/graph/__init__.py b/app/graph/__init__.py new file mode 100644 index 0000000..3b9caf2 --- /dev/null +++ b/app/graph/__init__.py @@ -0,0 +1,8 @@ +""" +Graph 子模块 +""" + +from app.graph.graph_builder import GraphBuilder +from app.graph.state import MessagesState, GraphContext + +__all__ = ["GraphBuilder", "MessagesState", "GraphContext"] diff --git a/app/graph/graph_builder.py b/app/graph/graph_builder.py index 11f9c9d..c7f5d99 100644 --- a/app/graph/graph_builder.py +++ b/app/graph/graph_builder.py @@ -5,18 +5,17 @@ LangGraph 状态图构建模块 - 精简版,仅负责组装图 from langchain_core.language_models import BaseLLM from langgraph.graph import StateGraph, START, END - -# 本地模块 from app.graph.state import MessagesState, GraphContext from app.nodes import ( + should_continue, create_llm_call_node, create_tool_call_node, create_retrieve_memory_node, create_summarize_node, - should_continue + finalize_node, ) +from app.nodes.memory_trigger import memory_trigger_node, set_mem0_client from app.memory import Mem0Client -from app.nodes.finalize import finalize_node class GraphBuilder: @@ -45,6 +44,9 @@ class GraphBuilder: Returns: StateGraph 实例 """ + # 注入全局客户端 + set_mem0_client(self.mem0_client) + builder = StateGraph(MessagesState, context_schema=GraphContext) # ⭐ 通过工厂函数创建节点(依赖注入) @@ -55,6 +57,7 @@ class GraphBuilder: # 添加节点 builder.add_node("retrieve_memory", retrieve_memory_node) + builder.add_node("memory_trigger", memory_trigger_node) builder.add_node("llm_call", llm_call_node) builder.add_node("tool_node", tool_call_node) builder.add_node("summarize", summarize_node) @@ -62,7 +65,8 @@ class GraphBuilder: # 添加边 builder.add_edge(START, "retrieve_memory") - builder.add_edge("retrieve_memory", "llm_call") + builder.add_edge("retrieve_memory", "memory_trigger") + builder.add_edge("memory_trigger", "llm_call") builder.add_conditional_edges( "llm_call", should_continue, diff --git a/app/graph/graph_tools.py b/app/graph/graph_tools.py index 6db7668..1cc1e17 100644 --- a/app/graph/graph_tools.py +++ b/app/graph/graph_tools.py @@ -3,7 +3,6 @@ """ # 标准库 -import os from pathlib import Path # 第三方库 @@ -13,7 +12,6 @@ import requests from bs4 import BeautifulSoup from langchain_core.tools import tool - def _file_allow_check(filename: str) -> Path: """检查用户文件名是否位于允许目录 './user_docs' 下,防止路径遍历攻击。""" allowed_dir = Path("./user_docs").resolve() @@ -28,13 +26,11 @@ def _file_allow_check(filename: str) -> Path: return file_path - @tool def get_current_temperature(location: str) -> str: """获取指定地点的当前温度。""" return f'当前{location}的温度为25℃' - @tool def read_local_file(filename: str) -> str: """读取用户指定名称的本地文本文件内容并返回摘要。""" @@ -46,7 +42,6 @@ def read_local_file(filename: str) -> str: except Exception as e: return f"读取文件时出错:{str(e)}" - @tool def read_pdf_summary(filename: str) -> str: """读取PDF文件并返回内容文本摘要。""" @@ -61,7 +56,6 @@ def read_pdf_summary(filename: str) -> str: except Exception as e: return f"读取PDF出错:{e}" - @tool def read_excel_as_markdown(filename: str) -> str: """读取Excel文件,并将其主要数据转换为Markdown表格格式。""" @@ -73,7 +67,6 @@ def read_excel_as_markdown(filename: str) -> str: except Exception as e: return f"读取Excel出错:{e}" - @tool def fetch_webpage_content(url: str) -> str: """抓取给定URL的网页正文内容,并返回清晰的纯文本。""" @@ -91,7 +84,6 @@ def fetch_webpage_content(url: str) -> str: except Exception as e: return f"抓取网页时出错:{str(e)}" - # 工具列表和映射(全局常量) AVAILABLE_TOOLS = [ get_current_temperature, diff --git a/app/graph/retrieve_memory.py b/app/graph/retrieve_memory.py index 61434d0..ef419d3 100644 --- a/app/graph/retrieve_memory.py +++ b/app/graph/retrieve_memory.py @@ -4,15 +4,13 @@ """ from typing import Any, Dict -from langgraph.runtime import Runtime # 本地模块 -from app.graph.state import MessagesState, GraphContext +from app.graph.state import MessagesState from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug - def create_retrieve_memory_node(mem0_client: Mem0Client): """ 工厂函数:创建记忆检索节点 diff --git a/app/graph/state.py b/app/graph/state.py index 61463fe..2fd214e 100644 --- a/app/graph/state.py +++ b/app/graph/state.py @@ -4,12 +4,11 @@ LangGraph 状态定义模块 """ import operator -from typing import Annotated, Any +from typing import Annotated from typing_extensions import TypedDict from dataclasses import dataclass from langchain_core.messages import AnyMessage - class MessagesState(TypedDict): """对话状态类型定义""" messages: Annotated[list[AnyMessage], operator.add] @@ -19,7 +18,6 @@ class MessagesState(TypedDict): last_elapsed_time: float # 本次调用耗时(秒) turns_since_last_summary: int # 距离上次生成摘要的轮数 - @dataclass class GraphContext: """图执行上下文""" diff --git a/app/graph_builder.py b/app/graph_builder.py deleted file mode 100644 index 11f9c9d..0000000 --- a/app/graph_builder.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -LangGraph 状态图构建模块 - 精简版,仅负责组装图 -所有节点逻辑已拆分到独立模块 -""" - -from langchain_core.language_models import BaseLLM -from langgraph.graph import StateGraph, START, END - -# 本地模块 -from app.graph.state import MessagesState, GraphContext -from app.nodes import ( - create_llm_call_node, - create_tool_call_node, - create_retrieve_memory_node, - create_summarize_node, - should_continue -) -from app.memory import Mem0Client -from app.nodes.finalize import finalize_node - - -class GraphBuilder: - """LangGraph 状态图构建器 - 仅负责组装图""" - - def __init__(self, llm: BaseLLM, tools: list, tools_by_name: dict): - """ - 初始化构建器 - - Args: - llm: 大语言模型实例 - tools: 工具列表 - tools_by_name: 名称到工具函数的映射 - """ - self.llm = llm - self.tools = tools - self.tools_by_name = tools_by_name - - # ⭐ 创建 Mem0 客户端(懒加载,首次使用时初始化) - self.mem0_client = Mem0Client(llm) - - def build(self) -> StateGraph: - """ - 构建未编译的状态图 - - Returns: - StateGraph 实例 - """ - builder = StateGraph(MessagesState, context_schema=GraphContext) - - # ⭐ 通过工厂函数创建节点(依赖注入) - retrieve_memory_node = create_retrieve_memory_node(self.mem0_client) - llm_call_node = create_llm_call_node(self.llm, self.tools) - tool_call_node = create_tool_call_node(self.tools_by_name) - summarize_node = create_summarize_node(self.mem0_client) - - # 添加节点 - builder.add_node("retrieve_memory", retrieve_memory_node) - builder.add_node("llm_call", llm_call_node) - builder.add_node("tool_node", tool_call_node) - builder.add_node("summarize", summarize_node) - builder.add_node("finalize", finalize_node) - - # 添加边 - builder.add_edge(START, "retrieve_memory") - builder.add_edge("retrieve_memory", "llm_call") - builder.add_conditional_edges( - "llm_call", - should_continue, - { - "tool_node": "tool_node", - "summarize": "summarize", - "finalize": "finalize" - } - ) - builder.add_edge("tool_node", "llm_call") - builder.add_edge("summarize", "finalize") - builder.add_edge("finalize", END) - - return builder \ No newline at end of file diff --git a/app/memory/mem0_client.py b/app/memory/mem0_client.py index 60a4274..f003610 100644 --- a/app/memory/mem0_client.py +++ b/app/memory/mem0_client.py @@ -4,13 +4,12 @@ Mem0 记忆层客户端封装模块 """ import asyncio -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict from mem0 import AsyncMemory -from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY +from app.config import QDRANT_URL, QDRANT_COLLECTION_NAME, QDRANT_API_KEY, LLAMACPP_EMBEDDING_URL, LLAMACPP_API_KEY from app.logger import info, warning, error - class Mem0Client: """Mem0 异步客户端封装类""" @@ -37,8 +36,9 @@ class Mem0Client: "provider": "qdrant", "config": { "url": QDRANT_URL, # 直接使用完整 URL + "api_key": QDRANT_API_KEY, "collection_name": QDRANT_COLLECTION_NAME, - "embedding_model_dims": 768, + "embedding_model_dims": 1024, } }, "llm": { @@ -50,7 +50,7 @@ class Mem0Client: "embedder": { "provider": "openai", "config": { - "model": "embeddinggemma-300M-Q8_0", + "model": "Qwen3-Embedding-0.6B-Q8_0", "api_key": LLAMACPP_API_KEY, "openai_base_url": LLAMACPP_EMBEDDING_URL, }, diff --git a/app/nodes/finalize.py b/app/nodes/finalize.py index ea6e1dc..87bd746 100644 --- a/app/nodes/finalize.py +++ b/app/nodes/finalize.py @@ -4,15 +4,13 @@ """ from typing import Any, Dict -from langgraph.runtime import Runtime from langgraph.config import get_stream_writer # 本地模块 -from app.graph.state import MessagesState, GraphContext +from app.graph.state import MessagesState from app.utils.logging import log_state_change from app.logger import info, error - from langchain_core.runnables.config import RunnableConfig async def finalize_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: diff --git a/app/nodes/llm_call.py b/app/nodes/llm_call.py index 22abfed..f61cd51 100644 --- a/app/nodes/llm_call.py +++ b/app/nodes/llm_call.py @@ -3,21 +3,17 @@ LLM 调用节点模块 负责调用大语言模型并处理响应 """ -import asyncio import time from typing import Any, Dict from langchain_core.language_models import BaseLLM from langchain_core.messages import AIMessage -from langchain_core.runnables import RunnableLambda -from langgraph.runtime import Runtime # 本地模块 -from app.graph.state import MessagesState, GraphContext -from app.prompts import create_system_prompt -from app.utils.logging import log_state_change, print_llm_input +from app.graph.state import MessagesState +from app.agent.prompts import create_system_prompt +from app.utils.logging import log_state_change from app.logger import debug, info, error - def create_llm_call_node(llm: BaseLLM, tools: list): """ 工厂函数:创建 LLM 调用节点 diff --git a/app/nodes/memory_trigger.py b/app/nodes/memory_trigger.py new file mode 100644 index 0000000..77078ed --- /dev/null +++ b/app/nodes/memory_trigger.py @@ -0,0 +1,38 @@ +from typing import Any, Dict +from langchain_core.runnables.config import RunnableConfig +from app.graph.state import MessagesState +from app.memory.mem0_client import Mem0Client +from app.logger import info + +# 全局变量,在 GraphBuilder 中注入 +_mem0_client: Mem0Client = None + +def set_mem0_client(client: Mem0Client): + global _mem0_client + _mem0_client = client + +async def memory_trigger_node(state: MessagesState, config: RunnableConfig) -> Dict[str, Any]: + """检测用户消息中的记忆指令,若命中则主动调用 Mem0 存储""" + if _mem0_client is None: + return {} + + messages = state.get("messages", []) + if not messages: + return {} + + last_msg = messages[-1] + content = last_msg.content if hasattr(last_msg, 'content') else str(last_msg) + + # 触发词(可自行扩展) + trigger_words = ["记住", "记下", "保存", "备忘", "记录下", "别忘了"] + if any(word in content for word in trigger_words): + user_id = config.get("metadata", {}).get("user_id", "default_user") + # 确保 Mem0 已初始化 + if not _mem0_client._initialized: + await _mem0_client.initialize() + # 将用户消息作为事实来源提交给 Mem0 + mem0_messages = [{"role": "user", "content": content}] + await _mem0_client.add_memories(mem0_messages, user_id=user_id) + info(f"📌 检测到记忆指令,已主动触发 Mem0 存储") + + return {} # 不修改状态 \ No newline at end of file diff --git a/app/nodes/summarize.py b/app/nodes/summarize.py index 8b39baa..5c3dd6c 100644 --- a/app/nodes/summarize.py +++ b/app/nodes/summarize.py @@ -4,15 +4,13 @@ """ from typing import Any, Dict -from langgraph.runtime import Runtime # 本地模块 -from app.graph.state import MessagesState, GraphContext +from app.graph.state import MessagesState from app.memory.mem0_client import Mem0Client from app.utils.logging import log_state_change from app.logger import debug, info, error, warning - def create_summarize_node(mem0_client: Mem0Client): """ 工厂函数:创建记忆存储节点 diff --git a/app/nodes/tool_call.py b/app/nodes/tool_call.py index 1c9d55f..5aa5bdf 100644 --- a/app/nodes/tool_call.py +++ b/app/nodes/tool_call.py @@ -6,15 +6,13 @@ import asyncio from typing import Any, Dict from langchain_core.messages import AIMessage, ToolMessage -from langgraph.runtime import Runtime from langgraph.config import get_stream_writer # 本地模块 -from app.graph.state import MessagesState, GraphContext +from app.graph.state import MessagesState from app.utils.logging import log_state_change from app.logger import debug, info - def create_tool_call_node(tools_by_name: Dict[str, Any]): """ 工厂函数:创建工具执行节点 diff --git a/app/rag/__init__.py b/app/rag/__init__.py index 1438604..dca5fed 100644 --- a/app/rag/__init__.py +++ b/app/rag/__init__.py @@ -13,7 +13,7 @@ RAG 检索与生成模块 用户查询 → 多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档 示例用法: - >>> from app.rag import RAGPipeline, create_rag_tool + >>> from app.rag.rag import RAGPipeline, create_rag_tool >>> from rag_indexer.builder import IndexBuilder, IndexBuilderConfig >>> from langchain_openai import ChatOpenAI >>> @@ -34,16 +34,16 @@ RAG 检索与生成模块 >>> rag_tool = create_rag_tool(retriever=retriever, llm=llm) """ -from .retriever import ( +from app.rag.retriever import ( create_base_retriever, create_hybrid_retriever, create_qdrant_client, ) -from .reranker import LLaMaCPPReranker -from .query_transform import MultiQueryGenerator -from .fusion import reciprocal_rank_fusion -from .pipeline import RAGPipeline -from .tools import create_rag_tool, create_rag_tool_sync +from app.rag.reranker import LLaMaCPPReranker +from app.rag.query_transform import MultiQueryGenerator +from app.rag.fusion import reciprocal_rank_fusion +from app.rag.pipeline import RAGPipeline +from app.rag.tools import create_rag_tool_sync __all__ = [ @@ -65,6 +65,5 @@ __all__ = [ "RAGPipeline", # 工具创建(供 Agent 使用) - "create_rag_tool", "create_rag_tool_sync", ] \ No newline at end of file diff --git a/app/rag/fusion.py b/app/rag/fusion.py index 777cc24..ddf8f42 100644 --- a/app/rag/fusion.py +++ b/app/rag/fusion.py @@ -1,6 +1,6 @@ # rag/fusion.py -from typing import List, Dict, Tuple +from typing import List, Dict from langchain_core.documents import Document def reciprocal_rank_fusion( diff --git a/app/rag/pipeline.py b/app/rag/pipeline.py index c8d95d6..5adab4a 100644 --- a/app/rag/pipeline.py +++ b/app/rag/pipeline.py @@ -2,15 +2,13 @@ import asyncio import os -from typing import List, Optional +from typing import List from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel -from .retriever import create_qdrant_client # 可能不需要直接使用 -from .reranker import LLaMaCPPReranker -from .query_transform import MultiQueryGenerator -from .fusion import reciprocal_rank_fusion - +from app.rag.reranker import LLaMaCPPReranker +from app.rag.query_transform import MultiQueryGenerator +from app.rag.fusion import reciprocal_rank_fusion class RAGPipeline: """ diff --git a/app/rag/reranker.py b/app/rag/reranker.py index b6f7d4a..925e283 100644 --- a/app/rag/reranker.py +++ b/app/rag/reranker.py @@ -2,9 +2,8 @@ 重排序器模块 (适配版) 使用远程 llama.cpp 服务 (兼容 OpenAI Rerank API) 替代本地 Cross-Encoder """ -import os import requests -from typing import List, Optional +from typing import List from langchain_core.documents import Document class LLaMaCPPReranker: diff --git a/app/rag/test.py b/app/rag/test.py index 80d2255..ff9817a 100644 --- a/app/rag/test.py +++ b/app/rag/test.py @@ -11,7 +11,6 @@ RAG 系统使用示例(重构版) import asyncio import sys import os -from pathlib import Path from dotenv import load_dotenv @@ -19,12 +18,12 @@ from dotenv import load_dotenv load_dotenv() # 添加项目根目录到路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) -from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig +from rag_indexer.index_builder import IndexBuilderConfig from rag_indexer.splitters import SplitterType -from rag.pipeline import RAGPipeline -from rag.tools import create_rag_tool +from app.rag.pipeline import RAGPipeline +from app.rag.tools import create_rag_tool_sync from pydantic import SecretStr # 使用本地 LLM(通过 OpenAI 兼容接口) from langchain_openai import ChatOpenAI @@ -32,7 +31,6 @@ from rag_core.retriever_factory import create_parent_retriever load_dotenv() - def create_llm(): """创建本地 vLLM 服务 LLM""" vllm_base_url = os.getenv( @@ -60,8 +58,7 @@ async def demonstrate_full_pipeline(): print("演示:固定流程 RAG 检索(多路改写 + RRF + 重排序 + 父文档)") print("=" * 60) - - retriever = retriever = create_parent_retriever(collection_name="my_docs", search_k=5) + retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) if retriever is None: print("错误:检索器未初始化,请确保索引已构建。") @@ -103,7 +100,6 @@ async def demonstrate_full_pipeline(): import traceback traceback.print_exc() - async def demonstrate_tool_creation(): """ 演示创建 RAG 工具(供 Agent 使用) @@ -119,12 +115,11 @@ async def demonstrate_tool_creation(): ) retriever = retriever = create_parent_retriever(collection_name="rag_documents", search_k=5) - # 2. 创建 LLM llm = create_llm() # 3. 创建工具 - rag_tool = create_rag_tool( + rag_tool = create_rag_tool_sync( retriever=retriever, llm=llm, num_queries=3, @@ -136,18 +131,16 @@ async def demonstrate_tool_creation(): print(f"工具描述: {rag_tool.description[:100]}...") # 4. 模拟 Agent 调用工具 - query = "请告诉我 RAG 系统的核心组件有哪些?" + query = "请告诉我 打虎英雄是谁?" print(f"\n模拟调用: {query}") print("-" * 40) result = await rag_tool.ainvoke({"query": query}) print(result[:800] + "..." if len(result) > 800 else result) - async def main(): await demonstrate_full_pipeline() await demonstrate_tool_creation() - if __name__ == "__main__": asyncio.run(main()) \ No newline at end of file diff --git a/app/rag/tools.py b/app/rag/tools.py index 4934101..2343709 100644 --- a/app/rag/tools.py +++ b/app/rag/tools.py @@ -4,11 +4,11 @@ RAG 工具模块 将检索功能封装为 LangChain Tool,供 Agent 调用。 采用固定流水线:多路改写 → 并行检索 → RRF 融合 → 重排序 → 返回父文档。 """ -from typing import Optional, Callable +from typing import Callable from langchain_core.tools import tool from langchain_core.language_models import BaseLanguageModel from langchain_core.retrievers import BaseRetriever -from .pipeline import RAGPipeline +from app.rag.pipeline import RAGPipeline def create_rag_tool_sync( retriever: BaseRetriever, diff --git a/app/test_backend.py b/app/test_backend.py new file mode 100644 index 0000000..63e0afb --- /dev/null +++ b/app/test_backend.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +""" +完整后端测试 - 验证 Agent 所有功能 +包括:短期记忆、长期记忆、工具调用、流式对话、历史查询 +""" + +import asyncio +import os +import sys +import uuid +from dotenv import load_dotenv + +# 添加项目根目录到 Python 路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +load_dotenv() + +from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from app.agent import AIAgentService +from app.agent.history import ThreadHistoryService +from app.logger import info, warning, error + +# PostgreSQL 连接字符串 +DB_URI = os.getenv( + "DB_URI", + "postgresql://postgres:***@ai-postgres:5432/langgraph_db?sslmode=disable" +) + +async def print_section(title): + """打印测试区块标题""" + print("\n" + "=" * 70) + print(f" {title}") + print("=" * 70) + +async def test_short_term_memory(agent_service): + """测试短期记忆(同一 thread_id 继续对话)""" + await print_section("测试 1: 短期记忆(Short-term Memory)") + + thread_id = str(uuid.uuid4()) + user_id = "test_user_memory" + + print(f"\n使用 thread_id: {thread_id[:8]}...") + print(f"使用 user_id: {user_id}") + + # 第一轮对话 + print("\n[第一轮] 发送消息: '我叫张三,今年28岁'") + result1 = await agent_service.process_message( + "我叫张三,今年28岁", thread_id, "local", user_id + ) + print(f"回复: {result1['reply'][:100]}...") + + # 第二轮对话 - 测试记忆 + print("\n[第二轮] 发送消息: '我叫什么名字?今年多大?'") + result2 = await agent_service.process_message( + "我叫什么名字?今年多大?", thread_id, "local", user_id + ) + print(f"回复: {result2['reply']}") + + # 验证记忆是否存在 + if "张三" in result2['reply'] or "28" in result2['reply']: + print("\n✅ 短期记忆测试通过!") + return True + else: + print("\n❌ 短期记忆测试失败!") + return False + +async def test_tool_calling(agent_service): + """测试工具调用(RAG 搜索)""" + await print_section("测试 2: 工具调用(Tool Calling)") + + thread_id = str(uuid.uuid4()) + user_id = "test_user_tools" + + print(f"\n使用 thread_id: {thread_id[:8]}...") + print(f"使用 user_id: {user_id}") + + # 发送需要 RAG 搜索的问题 + print("\n发送消息: '请告诉我,打虎英雄是谁?'") + result = await agent_service.process_message( + "请告诉我,打虎英雄是谁?", thread_id, "local", user_id + ) + print(f"回复: {result['reply'][:200]}...") + + # 检查是否调用了 RAG 工具(回复中会有水浒传相关内容) + if "武松" in result['reply'] or "李忠" in result['reply'] or "水浒传" in result['reply']: + print("\n✅ 工具调用测试通过!") + return True + else: + print("\n⚠️ 工具调用测试结果不确定,需要手动验证") + return None + +async def test_streaming(agent_service): + """测试流式对话""" + await print_section("测试 3: 流式对话(Streaming)") + + thread_id = str(uuid.uuid4()) + user_id = "test_user_stream" + + print(f"\n使用 thread_id: {thread_id[:8]}...") + print(f"使用 user_id: {user_id}") + + print("\n发送消息: '用100字介绍一下AI人工智能' (流式)...") + print("流式输出: ", end="", flush=True) + + full_reply = "" + chunk_count = 0 + + try: + async for chunk in agent_service.process_message_stream( + "用100字介绍一下AI人工智能", thread_id, "local", user_id + ): + chunk_count += 1 + if chunk.get("type") == "llm_token": + token = chunk.get("token", "") + print(token, end="", flush=True) + full_reply += token + elif chunk.get("type") == "state_update": + pass # 状态更新不显示 + + print(f"\n\n共收到 {chunk_count} 个 chunk") + print(f"完整回复长度: {len(full_reply)} 字") + + if chunk_count > 0 and len(full_reply) > 10: + print("\n✅ 流式对话测试通过!") + return True + else: + print("\n❌ 流式对话测试失败!") + return False + + except Exception as e: + print(f"\n❌ 流式对话异常: {e}") + return False + +async def test_history_service(agent_service, history_service): + """测试历史查询服务""" + await print_section("测试 4: 历史查询服务(History Service)") + + user_id = "test_user_history" + + # 先创建几个对话 + print(f"\n为 user_id={user_id} 创建测试对话...") + + thread_ids = [] + for i in range(3): + thread_id = str(uuid.uuid4()) + thread_ids.append(thread_id) + + await agent_service.process_message( + f"这是第 {i+1} 个测试对话", thread_id, "local", user_id + ) + print(f" 创建线程 {i+1}: {thread_id[:8]}...") + + # 1. 测试获取用户线程列表 + print("\n[4.1] 测试获取用户线程列表...") + threads = await history_service.get_user_threads(user_id, limit=10) + print(f" 找到 {len(threads)} 个线程") + + if len(threads) >= 3: + print(" ✅ 线程列表查询通过") + else: + print(" ⚠️ 线程数量少于预期") + + # 2. 测试获取单个线程的消息历史 + if thread_ids: + test_thread_id = thread_ids[0] + print(f"\n[4.2] 测试获取线程消息历史 (thread_id={test_thread_id[:8]}...)") + messages = await history_service.get_thread_messages(test_thread_id) + print(f" 找到 {len(messages)} 条消息") + + if len(messages) >= 2: # 至少有一问一答 + print(" ✅ 消息历史查询通过") + else: + print(" ⚠️ 消息数量少于预期") + + # 3. 测试获取线程摘要 + print(f"\n[4.3] 测试获取线程摘要...") + summary = await history_service.get_thread_summary(test_thread_id) + print(f" 摘要: {summary.get('summary', '')[:50]}...") + print(f" 消息数: {summary.get('message_count', 0)}") + + if summary.get('message_count', 0) > 0: + print(" ✅ 线程摘要查询通过") + else: + print(" ⚠️ 摘要查询结果不确定") + + return len(threads) >= 3 + +async def test_long_term_memory(agent_service): + """测试长期记忆(mem0)""" + await print_section("测试 5: 长期记忆(Long-term Memory - mem0)") + + thread_id1 = str(uuid.uuid4()) + thread_id2 = str(uuid.uuid4()) # 不同的线程 + user_id = "test_user_longterm" + + print(f"\n使用 user_id: {user_id}") + print(f"线程 1: {thread_id1[:8]}...") + print(f"线程 2: {thread_id2[:8]}...") + + # 在第一个线程中保存信息 + print("\n[线程 1] 发送消息: '记住,我的宠物名字叫小白,是一只猫'") + result1 = await agent_service.process_message( + "记住,我的宠物名字叫小白,是一只猫", thread_id1, "local", user_id + ) + print(f"回复: {result1['reply'][:100]}...") + + # 等待一下,让 mem0 保存 + await asyncio.sleep(1) + + # 在第二个线程中询问(不同的 thread_id) + print("\n[线程 2] 发送消息: '我的宠物叫什么名字?是什么动物?'") + result2 = await agent_service.process_message( + "我的宠物叫什么名字?是什么动物?", thread_id2, "local", user_id + ) + print(f"回复: {result2['reply']}") + + # 验证长期记忆 + if "小白" in result2['reply'] or "猫" in result2['reply']: + print("\n✅ 长期记忆测试通过!") + return True + else: + print("\n⚠️ 长期记忆可能未启用,或需要手动验证") + return None + +async def main(): + """主测试函数""" + print("\n" + "=" * 70) + print(" 后端完整功能测试") + print("=" * 70) + + results = {} + + try: + # 创建数据库连接和服务 + print("\n正在初始化数据库连接...") + async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: + await checkpointer.setup() + print("✅ 数据库连接成功") + + # 创建服务实例 + print("\n正在初始化 Agent 服务...") + agent_service = AIAgentService(checkpointer) + await agent_service.initialize() + print("✅ Agent 服务初始化成功") + + history_service = ThreadHistoryService(checkpointer) + print("✅ 历史服务初始化成功") + + print(f"\n可用模型: {list(agent_service.graphs.keys())}") + + # 运行测试 + results["短期记忆"] = await test_short_term_memory(agent_service) + await asyncio.sleep(1) + + results["工具调用"] = await test_tool_calling(agent_service) + await asyncio.sleep(1) + + results["流式对话"] = await test_streaming(agent_service) + await asyncio.sleep(1) + + results["历史查询"] = await test_history_service(agent_service, history_service) + await asyncio.sleep(1) + + results["长期记忆"] = await test_long_term_memory(agent_service) + await asyncio.sleep(1) + + # 打印总结 + await print_section("测试总结") + print("\n测试结果:") + print("-" * 40) + + pass_count = 0 + fail_count = 0 + skip_count = 0 + + for test_name, result in results.items(): + if result is True: + status = "✅ 通过" + pass_count += 1 + elif result is False: + status = "❌ 失败" + fail_count += 1 + else: + status = "⚠️ 待验证" + skip_count += 1 + print(f" {test_name:12s}: {status}") + + print("-" * 40) + print(f"总计: {len(results)} 个测试") + print(f"通过: {pass_count}, 失败: {fail_count}, 待验证: {skip_count}") + + if fail_count == 0: + print("\n🎉 所有核心测试通过!") + else: + print(f"\n⚠️ 有 {fail_count} 个测试失败") + + except Exception as e: + error(f"\n❌ 测试运行异常: {e}") + import traceback + traceback.print_exc() + return 1 + + return 0 if fail_count == 0 else 1 + +if __name__ == "__main__": + exit_code = asyncio.run(main()) + sys.exit(exit_code) diff --git a/frontend/__init__.py b/frontend/__init__.py index 29e32df..52f6743 100644 --- a/frontend/__init__.py +++ b/frontend/__init__.py @@ -3,7 +3,7 @@ AI Agent 前端模块 采用分层架构设计,包含配置、状态、API客户端和UI组件 """ -from .logger import debug, info, warning, error +from frontend.logger import debug, info, warning, error __version__ = "2.0.0" __all__ = ["debug", "info", "warning", "error"] \ No newline at end of file diff --git a/frontend/components/sidebar.py b/frontend/components/sidebar.py index c23aa3e..174d62c 100644 --- a/frontend/components/sidebar.py +++ b/frontend/components/sidebar.py @@ -9,8 +9,6 @@ from datetime import datetime # 使用绝对导入 from frontend.state import AppState from frontend.api_client import api_client -from frontend.config import config - def render_sidebar(): """渲染左侧栏""" @@ -25,7 +23,6 @@ def render_sidebar(): st.divider() _render_user_section() - def _render_user_section(): """渲染用户登录区域""" # st.header("👤 用户") # 移除显眼的标题,改用更柔和的 caption @@ -36,7 +33,6 @@ def _render_user_section(): else: _render_user_info() - def _render_login_form(): """渲染登录表单""" username = st.text_input( @@ -54,7 +50,6 @@ def _render_login_form(): # st.info("💡 建议登录以隔离对话历史") # 移除多余色块 - def _render_user_info(): """渲染用户信息""" st.markdown(f"**当前用户**: `{AppState.get_user_id()}`") @@ -64,7 +59,6 @@ def _render_user_info(): _refresh_threads() st.rerun() - def _render_history_section(): """渲染历史对话列表""" col1, col2 = st.columns([3, 1]) @@ -76,7 +70,6 @@ def _render_history_section(): _render_thread_list() - def _render_history_actions(): """渲染历史操作按钮""" # 移除了 type="primary",让它变成普通的线框按钮,不再是大红块 @@ -84,7 +77,6 @@ def _render_history_actions(): AppState.start_new_thread() st.rerun() - def _render_thread_list(): """渲染线程列表""" # 仅在初次加载时拉取,或由外部主动调用 _refresh_threads() 更新 @@ -101,7 +93,6 @@ def _render_thread_list(): for thread in threads: _render_thread_item(thread) - def _render_thread_item(thread: dict): """ 渲染单个线程项 @@ -130,7 +121,6 @@ def _render_thread_item(thread: dict): ): _load_thread(thread_id) - def _format_time(time_str: str) -> str: """ 格式化时间字符串 @@ -150,13 +140,11 @@ def _format_time(time_str: str) -> str: except Exception: return time_str[:10] - def _refresh_threads(): """刷新历史线程列表""" threads = api_client.get_user_threads(AppState.get_user_id()) AppState.set_threads(threads) - def _load_thread(thread_id: str): """ 加载指定线程的消息历史 diff --git a/frontend/state.py b/frontend/state.py index e1d32bb..10e16b1 100644 --- a/frontend/state.py +++ b/frontend/state.py @@ -7,7 +7,7 @@ import uuid from typing import List, Dict, Any import streamlit as st -from .config import config +from frontend.config import config class AppState: diff --git a/rag_core/__init__.py b/rag_core/__init__.py index 318a066..a19afb2 100644 --- a/rag_core/__init__.py +++ b/rag_core/__init__.py @@ -4,10 +4,10 @@ RAG Core - 公共 RAG 组件包 提供嵌入模型、向量存储和文档存储的公共功能,被 rag_indexer 和 app/rag 共用。 """ -from .embedders import LlamaCppEmbedder -from .vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY -from .store import PostgresDocStore, create_docstore -from .retriever_factory import create_parent_retriever +from rag_core.embedders import LlamaCppEmbedder +from rag_core.vector_store import QdrantVectorStore, QDRANT_URL, QDRANT_API_KEY +from rag_core.store import PostgresDocStore, create_docstore +from rag_core.retriever_factory import create_parent_retriever __all__ = [ diff --git a/rag_core/embedders.py b/rag_core/embedders.py index e9a87a3..66ffa6e 100644 --- a/rag_core/embedders.py +++ b/rag_core/embedders.py @@ -5,11 +5,9 @@ import os import httpx from typing import List, Optional -from urllib.parse import urljoin from langchain_core.embeddings import Embeddings - class LlamaCppEmbedder: """通过 OpenAI 兼容 API 封装 llama.cpp 嵌入服务。""" @@ -17,7 +15,7 @@ class LlamaCppEmbedder: self, base_url: Optional[str] = None, api_key: Optional[str] = None, - model: str = "embeddinggemma-300M-Q8_0", + model: str = "Qwen3-Embedding-0.6B-Q8_0", ): self.base_url = base_url or os.getenv("LLAMACPP_EMBEDDING_URL", "http://127.0.0.1:8082") self.api_key = api_key or os.getenv("LLAMACPP_API_KEY", "") @@ -71,7 +69,6 @@ class LlamaCppEmbedder: else: raise ValueError(f"未知的嵌入 API 响应格式: {data}") - class _LlamaCppLangchainAdapter(Embeddings): """将 LlamaCppEmbedder 适配为 LangChain Embeddings 接口。""" diff --git a/rag_core/retriever_factory.py b/rag_core/retriever_factory.py index 24a77af..25dab1c 100644 --- a/rag_core/retriever_factory.py +++ b/rag_core/retriever_factory.py @@ -2,14 +2,7 @@ from langchain_core.embeddings import Embeddings from langchain_classic.retrievers import ParentDocumentRetriever from langchain_text_splitters import RecursiveCharacterTextSplitter -from rag_indexer.splitters import SplitterType, get_splitter -import asyncio -import logging -from dataclasses import dataclass, field -from pathlib import Path -from typing import List, Union, Optional, Any, Dict, Tuple -from httpx import RemoteProtocolError -from langchain_core.documents import Document +from typing import Optional from langchain_core.embeddings import Embeddings from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter @@ -17,7 +10,6 @@ from langchain_classic.retrievers import ParentDocumentRetriever from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore - def create_parent_retriever( collection_name: str = "rag_documents", embeddings: Optional[Embeddings] = None, diff --git a/rag_core/store/__init__.py b/rag_core/store/__init__.py index 359db76..b4aab75 100644 --- a/rag_core/store/__init__.py +++ b/rag_core/store/__init__.py @@ -15,8 +15,8 @@ """ -from .postgres import PostgresDocStore -from .factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI +from rag_core.store.postgres import PostgresDocStore +from rag_core.store.factory import create_docstore, get_docstore_uri, DEFAULT_DB_URI __version__ = "2.0.0" diff --git a/rag_core/store/factory.py b/rag_core/store/factory.py index c32c2c7..391b077 100644 --- a/rag_core/store/factory.py +++ b/rag_core/store/factory.py @@ -9,7 +9,7 @@ import logging from typing import Optional, Tuple from langchain_core.stores import BaseStore -from .postgres import PostgresDocStore +from rag_core.store.postgres import PostgresDocStore logger = logging.getLogger(__name__) diff --git a/rag_core/store/postgres.py b/rag_core/store/postgres.py index 5132355..23b7153 100644 --- a/rag_core/store/postgres.py +++ b/rag_core/store/postgres.py @@ -4,12 +4,10 @@ 使用 asyncpg 实现真正的异步 PostgreSQL 文档存储,支持高并发访问。 """ -from __future__ import annotations - import asyncio import json import logging -from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence, cast +from typing import List, Dict, Any, Optional, Iterator, Tuple, Sequence from langchain_core.documents import Document from langchain_core.stores import BaseStore @@ -18,7 +16,6 @@ import asyncpg logger = logging.getLogger(__name__) - class PostgresDocStore(BaseStore[str, Any]): """ 异步 PostgreSQL 文档存储实现。 diff --git a/rag_core/vector_store.py b/rag_core/vector_store.py index 13cbfbb..b92e113 100644 --- a/rag_core/vector_store.py +++ b/rag_core/vector_store.py @@ -13,7 +13,7 @@ from qdrant_client import QdrantClient from qdrant_client.http.models import Distance, VectorParams from httpx import RemoteProtocolError from qdrant_client.http.exceptions import ResponseHandlingException -from .client import create_qdrant_client +from rag_core.client import create_qdrant_client logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ class QdrantVectorStore: self._last_connection_time: Optional[float] = None if embeddings is None: - from .embedders import LlamaCppEmbedder + from rag_core.embedders import LlamaCppEmbedder embedder = LlamaCppEmbedder() self.embeddings = embedder.as_langchain_embeddings() else: @@ -96,7 +96,7 @@ class QdrantVectorStore: def create_collection(self, vector_size: Optional[int] = None, force_recreate: bool = False): """创建集合,设置合适的向量维度。""" if vector_size is None: - from .embedders import LlamaCppEmbedder + from rag_core.embedders import LlamaCppEmbedder embedder = LlamaCppEmbedder() vector_size = embedder.get_embedding_dimension() diff --git a/rag_indexer/__init__.py b/rag_indexer/__init__.py index 2a0117f..21ca58c 100644 --- a/rag_indexer/__init__.py +++ b/rag_indexer/__init__.py @@ -23,9 +23,9 @@ Offline RAG Indexer module. >>> await builder.build_from_file("document.pdf") """ -from .index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig -from .loaders import DocumentLoader -from .splitters import SplitterType, get_splitter +from rag_indexer.index_builder import IndexBuilder, IndexBuilderConfig, DocstoreConfig +from rag_indexer.loaders import DocumentLoader +from rag_indexer.splitters import SplitterType, get_splitter # 从 rag_core 重新导出常用组件 from rag_core import ( diff --git a/rag_indexer/index_builder.py b/rag_indexer/index_builder.py index 137a674..582f8d5 100644 --- a/rag_indexer/index_builder.py +++ b/rag_indexer/index_builder.py @@ -8,24 +8,21 @@ import asyncio import logging from dataclasses import dataclass, field from pathlib import Path -from typing import List, Union, Optional, Any, Dict, Tuple +from typing import List, Union, Optional, Any, Dict from httpx import RemoteProtocolError from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.stores import BaseStore from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter -from langchain_classic.retrievers import ParentDocumentRetriever from qdrant_client.http.exceptions import ResponseHandlingException -from .loaders import DocumentLoader -from .splitters import SplitterType, get_splitter, SemanticChunkerAdapter +from rag_indexer.loaders import DocumentLoader +from rag_indexer.splitters import SplitterType, get_splitter from rag_core import LlamaCppEmbedder, QdrantVectorStore, create_docstore, create_parent_retriever - logger = logging.getLogger(__name__) - # ---------- 配置数据类 ---------- @dataclass class DocstoreConfig: @@ -36,7 +33,6 @@ class DocstoreConfig: # 若要从外部注入已创建好的 docstore,可直接设置此字段 instance: Optional[BaseStore] = None - @dataclass class IndexBuilderConfig: """索引构建器配置。""" @@ -60,7 +56,6 @@ class IndexBuilderConfig: # 其他切分器参数(当 splitter_type 非父子块时使用) extra_splitter_kwargs: Dict[str, Any] = field(default_factory=dict) - # ---------- 索引构建器 ---------- class IndexBuilder: """RAG 索引构建主流水线,支持单块切分与父子块切分。""" diff --git a/scripts/start.sh b/scripts/start.sh index 3243bc2..1d308d3 100755 --- a/scripts/start.sh +++ b/scripts/start.sh @@ -250,7 +250,7 @@ start_embedding() { echo -e "${BLUE}🚀 启动 llama.cpp Embedding 容器...${NC}" # 检查模型文件 - if [ ! -f "/home/huang/Study/AIModel/GGUF/embeddinggemma-300M-Q8_0.gguf" ]; then + if [ ! -f "/home/huang/Study/AIModel/GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf" ]; then echo -e "${RED}✗ 错误:Embedding 模型文件不存在${NC}" exit 1 fi @@ -263,13 +263,16 @@ start_embedding() { --device=/dev/dri \ -v /home/huang/Study/AIModel/GGUF:/models \ -p 8082:8080 \ + -e LLAMA_ARG_CTX_SIZE=16384 \ + -e LLAMA_ARG_N_PARALLEL=1 \ + -e LLAMA_ARG_BATCH=512 \ + -e LLAMA_ARG_N_GPU_LAYERS=99 \ + -e LLAMA_ARG_API_KEY=huang1998 \ ghcr.io/ggml-org/llama.cpp:server-rocm \ - -m /models/embeddinggemma-300M-Q8_0.gguf \ + -m /models/Qwen3-Embedding-0.6B-Q8_0.gguf \ --host 0.0.0.0 \ --port 8080 \ - -ngl 99 \ - --embeddings \ - -c 512 + --embeddings echo -e "${GREEN}✓ llama.cpp Embedding 容器已启动 (端口 8082)${NC}" sleep 5