This commit is contained in:
@@ -1,73 +0,0 @@
|
||||
# app/rag_initializer.py
|
||||
from ...rag.tools import create_rag_tool
|
||||
from ...rag.retriever import create_parent_hybrid_retriever
|
||||
from ...model_services import get_embedding_service
|
||||
from backend.app.logger import info, warning
|
||||
import sys
|
||||
|
||||
# 全局 RAG 工具
|
||||
_rag_tool = None
|
||||
_initialized = False
|
||||
|
||||
|
||||
def get_rag_tool() -> callable:
    """Return the RAG tool held in the module-level ``_rag_tool`` variable."""
    current_tool = _rag_tool
    return current_tool
|
||||
|
||||
|
||||
def is_initialized() -> bool:
    """Return the module-level ``_initialized`` flag."""
    initialized_flag = _initialized
    return initialized_flag
|
||||
|
||||
|
||||
async def init_rag_tool(force: bool = False):
    """
    Build the RAG tool and register it in the module-level globals.

    The embedding service, the retriever and the chat model used for query
    rewriting are all obtained inside this function.

    Args:
        force: When True, rebuild the tool even if one is already registered.

    Returns:
        The object returned by ``create_rag_tool`` (or the previously
        registered tool when initialization is skipped), or None if any
        step of the initialization raised.
    """
    global _rag_tool, _initialized

    # Avoid repeated initialization unless the caller explicitly forces it.
    already_done = _initialized and not force
    if already_done:
        info("[RAG] 已初始化,跳过")
        return _rag_tool

    try:
        # Imported inside the guarded region so that an import failure is
        # reported like any other initialization error.
        from backend.app.model_services.chat_services import get_chat_service

        info("🔄 正在初始化 RAG 检索系统...")
        hybrid_retriever = create_parent_hybrid_retriever(
            collection_name="rag_documents",
            search_k=5,
            embeddings=get_embedding_service(),
        )
        tool = create_rag_tool(
            retriever=hybrid_retriever,
            llm=get_chat_service(),
            num_queries=3,
            rerank_top_n=5,
        )

        # Publish the tool only after every construction step succeeded.
        _rag_tool, _initialized = tool, True
        info(f"✅ RAG 检索工具初始化成功 (id={id(tool)})")
        return tool

    except Exception as exc:
        # Best-effort: a failure is logged and reported as None, not raised.
        warning(f"⚠️ RAG 检索工具初始化失败: {exc}")
        return None
|
||||
|
||||
|
||||
def reset():
    """Clear the registered RAG tool and the initialized flag (for tests)."""
    global _rag_tool, _initialized
    _rag_tool, _initialized = None, False
|
||||
@@ -3,12 +3,11 @@
|
||||
Web Search Public Utility - Free, no API Key, using DuckDuckGo
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import requests
|
||||
import warnings
|
||||
import re
|
||||
|
||||
from backend.app.logger import info
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,47 +43,31 @@ class WebSearchTool:
|
||||
"""
|
||||
num_results = max_results or self.max_results
|
||||
|
||||
# 方式 1: Tavily (需要 API Key,质量最高)
|
||||
# 尝试搜索方式,按优先级
|
||||
result = self._try_tavily(query, num_results)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
result = self._try_ddgs(query, num_results)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# 兜底方案
|
||||
return self._get_mock_results(query, num_results)
|
||||
|
||||
def _try_tavily(self, query: str, max_results: int) -> "Optional[List[SearchResult]]":
    """
    Attempt a Tavily search and report failure as None instead of raising.

    Args:
        query: Search keywords, forwarded unchanged to ``_search_tavily``.
        max_results: Maximum number of results to request.

    Returns:
        Whatever ``_search_tavily`` returns on success, or None when the
        import of the Tavily client failed or the search raised.
    """
    try:
        return self._search_tavily(query, max_results)
    except ImportError:
        info("[WebSearch] tavily 未安装")
    except Exception as e:
        # Best-effort backend: every failure is logged and turned into None
        # rather than propagated.
        error_msg = str(e)
        if "API_KEY" in error_msg or "未配置" in error_msg:
            info("[WebSearch] Tavily API Key 未配置")
        else:
            info(f"[WebSearch] Tavily 搜索失败: {e}")
    return None
|
||||
|
||||
def _search_tavily(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""使用 Tavily API 搜索"""
|
||||
@@ -111,56 +94,40 @@ class WebSearchTool:
|
||||
source="Tavily"
|
||||
))
|
||||
|
||||
print(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||
info(f"[WebSearch] Tavily 返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
def _search_http(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""用简单 HTTP 请求搜索(备用方案)- 尝试多个国内源"""
|
||||
print(f"[WebSearch] 尝试 HTTP 搜索")
|
||||
|
||||
# 方式 1: 尝试百度搜索(简单方式)
|
||||
def _try_ddgs(self, query: str, max_results: int) -> Optional[List[SearchResult]]:
|
||||
"""尝试 DuckDuckGo 搜索"""
|
||||
try:
|
||||
return self._search_baidu(query, max_results)
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] 百度搜索失败: {e}")
|
||||
|
||||
# 方式 2: 返回模拟数据
|
||||
return self._search_mock(query, max_results)
|
||||
from ddgs import DDGS
|
||||
|
||||
def _search_baidu(self, query: str, max_results: int) -> List[SearchResult]:
|
||||
"""尝试百度搜索"""
|
||||
import requests
|
||||
from urllib.parse import quote
|
||||
|
||||
url = f"https://www.baidu.com/s?wd={quote(query)}"
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
# 简单解析百度搜索结果(简化版)
|
||||
results = []
|
||||
# 这里只是示意,真实百度搜索需要更复杂的解析
|
||||
results.append(SearchResult(
|
||||
title=f"百度搜索: {query}",
|
||||
url=url,
|
||||
snippet="如需要真实搜索结果,请考虑使用百度搜索 API",
|
||||
source="百度"
|
||||
))
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[WebSearch] 百度搜索也失败: {e}")
|
||||
raise
|
||||
with DDGS() as ddgs:
|
||||
for r in ddgs.text(query, max_results=max_results):
|
||||
results.append(SearchResult(
|
||||
title=r.get("title", ""),
|
||||
url=r.get("href", ""),
|
||||
snippet=r.get("body", ""),
|
||||
source="DuckDuckGo"
|
||||
))
|
||||
|
||||
def _search_mock(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""模拟搜索结果(兜底方案)"""
|
||||
print(f"[WebSearch] 使用模拟搜索结果 (查询: {query})")
|
||||
|
||||
# 根据查询内容生成更有意义的模拟结果
|
||||
mock_templates = [
|
||||
if results:
|
||||
info(f"[WebSearch] ddgs 返回 {len(results)} 条结果")
|
||||
return results
|
||||
|
||||
except ImportError:
|
||||
info("[WebSearch] ddgs 未安装")
|
||||
except Exception as e:
|
||||
info(f"[WebSearch] ddgs 搜索失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _get_mock_results(self, query: str, max_results: Optional[int] = None) -> List[SearchResult]:
|
||||
"""获取模拟搜索结果(兜底方案)"""
|
||||
info(f"[WebSearch] 使用模拟搜索结果")
|
||||
|
||||
templates = [
|
||||
{
|
||||
"title": f"关于「{query}」的相关介绍",
|
||||
"snippet": "这是模拟结果。如需真实搜索,请检查容器网络连接或配置代理。",
|
||||
@@ -177,50 +144,48 @@ class WebSearchTool:
|
||||
"url": "https://example.com/more"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
num = max_results or self.max_results
|
||||
results = []
|
||||
|
||||
for i, template in enumerate(mock_templates[:num]):
|
||||
|
||||
for template in templates[:num]:
|
||||
results.append(SearchResult(
|
||||
title=template["title"],
|
||||
url=template["url"],
|
||||
snippet=template["snippet"],
|
||||
source="模拟数据"
|
||||
))
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def format_search_results(self, results: "List[SearchResult]") -> str:
    """
    Render search results as Markdown with numbered, traceable citations.

    Args:
        results: Search results to render; each item's ``title``, ``url``,
            ``snippet`` and ``timestamp`` attributes are read.

    Returns:
        The Markdown text, or a fixed "no results" message when ``results``
        is empty.
    """
    if not results:
        return "未找到相关搜索结果"

    # The heading literal ends in "\n" so that joining leaves a blank line
    # between the heading and the first entry.
    lines = ["## 🔍 联网搜索结果\n"]

    for idx, result in enumerate(results, 1):
        lines.extend((
            f"### [{idx}] {result.title}",
            f"- 🔗 来源:[{result.url}]({result.url})",
            f"- 📝 摘要:{result.snippet}",
            f"- 📅 时间:{result.timestamp.strftime('%Y-%m-%d %H:%M:%S')}",
            "",
        ))

    # Footer explaining how the bracketed numbers map back to the sources.
    lines.extend((
        "---",
        "💡 **引用溯源说明**:",
        "- 以上搜索结果均标注了来源链接",
        "- 使用方括号数字标识引用(如 [1]、[2])",
        "- 可通过链接追溯原始信息",
    ))

    return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -239,11 +204,11 @@ def get_web_search_tool() -> WebSearchTool:
|
||||
def web_search(query: str, max_results: int = 5) -> str:
|
||||
"""
|
||||
便捷函数:联网搜索并返回格式化结果
|
||||
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
max_results: 返回结果数量
|
||||
|
||||
|
||||
Returns:
|
||||
格式化后的搜索结果文本
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user