backtrader/llm_manager.py

167 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional
import requests
import time
from config_manager import get_config_manager
from logger_manager import get_logger
config_manager = get_config_manager()
logger = get_logger()
class LLMManager:
"""LLM管理器负责加载API配置并提供与LLM交互的功能"""
_instance = None # 单例实例
def __init__(self):
self._api_key = None
self._api_base = None
self._model = None
self._temperature = None
self._config = None
def __new__(cls):
"""实现单例模式"""
if cls._instance is None:
cls._instance = super(LLMManager, cls).__new__(cls)
cls._instance._config = None
cls._instance._api_key = None
cls._instance._api_base = None
cls._instance._model = None
cls._instance._provider = None
return cls._instance
def initialize(self):
"""初始化LLM配置"""
llm_config = config_manager.get('llm', {})
# 加载LLM配置
self._api_key = llm_config.get('api_key', '')
self._api_base = llm_config.get('api_base', 'https://api.openai.com/v1')
self._model = llm_config.get('model', 'gpt-3.5-turbo')
self._temperature = llm_config.get('temperature', 0.7)
self._config = llm_config
if not self._api_key:
logger.warning("警告: LLM API密钥未配置请在config.yaml中设置。")
def chat(self, content: str, prompt: Optional[str] = None, max_retries: int = 5) -> str:
"""
与LLM进行对话包含自动重试机制
Args:
content: 用户输入内容
prompt: 可选的系统提示词
max_retries: 最大重试次数默认为5
Returns:
LLM的回复内容
"""
if self._config is None:
self.initialize()
headers = {
"Authorization": f"Bearer {self._api_key}",
"Content-Type": "application/json"
}
messages = []
# 添加系统提示(如果有)
if prompt:
messages.append({"role": "system", "content": prompt})
# 添加用户消息
messages.append({"role": "user", "content": content})
payload = {
"model": self._model,
"messages": messages,
"temperature": self._temperature,
}
logger.debug(f"请求数据: {payload}")
# 实现重试机制
attempts = 0
last_exception = None
while attempts < max_retries:
try:
# 尝试直接不使用代理
if attempts > 0:
logger.info(f"尝试不使用代理进行请求 (尝试 {attempts + 1}/{max_retries})")
response = requests.post(
f"{self._api_base}/chat/completions",
headers=headers,
json=payload,
proxies=None, # 明确不使用代理
timeout=30 # 设置超时时间
)
else:
# 第一次尝试使用默认设置
response = requests.post(
f"{self._api_base}/chat/completions",
headers=headers,
json=payload,
timeout=30 # 设置超时时间
)
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
elif response.status_code == 429:
# 限流错误,需要等待更长时间
logger.warning(
f"API请求限流 (尝试 {attempts + 1}/{max_retries}): {response.status_code}, {response.text}")
attempts += 1
wait_time = 10 * (2 ** attempts) # 指数退避10s, 20s, 40s...
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
else:
# 其他服务器错误
logger.warning(
f"请求失败 (尝试 {attempts + 1}/{max_retries}): {response.status_code}, {response.text}")
attempts += 1
wait_time = 2 * (2 ** attempts) # 指数退避4s, 8s, 16s...
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
except requests.exceptions.ProxyError as e:
# 代理错误处理
last_exception = e
attempts += 1
logger.warning(f"代理连接错误 (尝试 {attempts}/{max_retries}): {str(e)}")
if attempts >= max_retries:
break
wait_time = 2 * attempts # 线性退避2s, 4s, 6s...
logger.info(f"等待 {wait_time} 秒后尝试不使用代理重试...")
time.sleep(wait_time)
except (requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
requests.exceptions.ReadTimeout) as e:
# 连接错误和超时
last_exception = e
attempts += 1
logger.warning(f"连接错误或超时 (尝试 {attempts}/{max_retries}): {str(e)}")
if attempts >= max_retries:
break
wait_time = 3 * (2 ** attempts) # 指数退避6s, 12s, 24s...
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
except Exception as e:
# 其他未预期的错误
last_exception = e
attempts += 1
logger.error(f"未预期的错误 (尝试 {attempts}/{max_retries}): {str(e)}")
if attempts >= max_retries:
break
wait_time = 2 * attempts # 线性退避
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
# 所有重试都失败后
logger.error(f"达到最大重试次数 ({max_retries}),请求失败。最后错误: {str(last_exception)}")
return f"抱歉,请求遇到网络问题,无法获取分析结果。请检查网络设置或稍后再试。\n错误信息: {str(last_exception)}"
# 提供简单的访问函数
def get_llm_manager() -> LLMManager:
"""获取LLM管理器实例"""
return LLMManager()