backtrader/llm_manager.py
Qihang Zhang 47dc538c67 feat(llm): 添加LLM集成及配置支持
这个提交实现了一个完整的LLM管理器,并增加了相关配置项。主要变更包括:
- 在配置文件中添加LLM相关参数设置(API密钥、基础URL等)
- 创建新的LLM管理器类,实现单例模式和与LLM的通信功能
- 优化数据库查询日志记录
2025-04-20 00:02:57 +08:00

93 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import yaml
import requests
from typing import Dict, Any, Optional
class LLMManager:
    """Singleton manager that loads LLM API settings and talks to an
    OpenAI-compatible chat-completions endpoint.

    Configuration is read lazily (on first use) from the project's
    config manager under the ``llm`` key.
    """

    _instance = None  # the one shared instance (singleton)

    def __new__(cls):
        """Create the singleton on first call; return the cached instance after.

        All instance state is initialized here, exactly once. There is
        deliberately no ``__init__``: Python calls ``__init__`` on every
        ``LLMManager()`` invocation, and a resetting ``__init__`` would wipe
        the already-loaded configuration each time the singleton is fetched.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._config = None        # raw llm config dict; None = not yet initialized
            instance._api_key = None       # bearer token for the API
            instance._api_base = None      # base URL of the OpenAI-compatible API
            instance._model = None         # model identifier sent in requests
            instance._temperature = None   # sampling temperature sent in requests
            instance._provider = None      # reserved; not read in the visible code
            cls._instance = instance
        return cls._instance

    def initialize(self):
        """Load LLM settings from the application config.

        Reads the ``llm`` section via the project config manager and fills in
        api_key / api_base / model / temperature, falling back to OpenAI
        defaults. Prints a warning (but does not raise) when no API key is set.
        """
        # Imported here to avoid a module-level dependency cycle with config_manager.
        from config_manager import get_config_manager

        config_manager = get_config_manager()
        llm_config = config_manager.get('llm', {})

        # Record the loaded config so chat() knows initialization has run;
        # the original never set this, forcing a re-initialize on every chat.
        self._config = llm_config
        self._api_key = llm_config.get('api_key', '')
        self._api_base = llm_config.get('api_base', 'https://api.openai.com/v1')
        self._model = llm_config.get('model', 'gpt-3.5-turbo')
        self._temperature = llm_config.get('temperature', 0.7)

        if not self._api_key:
            print("警告: LLM API密钥未配置请在config.yaml中设置。")

    def chat(self, content: str, prompt: Optional[str] = None) -> str:
        """Send one chat turn to the LLM and return its reply text.

        Args:
            content: The user message.
            prompt: Optional system prompt prepended to the conversation.

        Returns:
            The assistant message content from the first choice.

        Raises:
            Exception: If the API responds with a non-200 status.
            requests.RequestException: On network failure or timeout.
        """
        # Lazy one-time initialization (initialize() sets _config).
        if self._config is None:
            self.initialize()

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        messages = []
        if prompt:
            messages.append({"role": "system", "content": prompt})
        messages.append({"role": "user", "content": content})

        payload = {
            "model": self._model,
            "messages": messages,
            "temperature": self._temperature,
        }

        # timeout prevents an unresponsive server from blocking the caller forever.
        response = requests.post(
            f"{self._api_base}/chat/completions",
            headers=headers,
            json=payload,
            timeout=60,
        )
        if response.status_code != 200:
            raise Exception(f"API请求失败: {response.text}")
        return response.json()["choices"][0]["message"]["content"]
# 提供简单的访问函数
def get_llm_manager() -> LLMManager:
    """Convenience accessor for the process-wide LLMManager singleton."""
    manager = LLMManager()
    return manager