2025-04-19 14:22:18 +08:00
|
|
|
import os
|
|
|
|
import yaml
|
|
|
|
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigManager:
|
|
|
|
"""配置管理器,负责加载和提供配置信息"""
|
|
|
|
|
|
|
|
_instance = None # 单例实例
|
|
|
|
|
|
|
|
def __new__(cls):
|
|
|
|
"""实现单例模式"""
|
|
|
|
if cls._instance is None:
|
|
|
|
cls._instance = super(ConfigManager, cls).__new__(cls)
|
|
|
|
cls._instance._config = None
|
|
|
|
return cls._instance
|
|
|
|
|
|
|
|
def load_config(self) -> Dict[str, Any]:
|
|
|
|
"""加载配置,如果配置文件不存在则创建默认配置"""
|
|
|
|
if self._config is not None:
|
|
|
|
return self._config
|
|
|
|
|
|
|
|
config_path = 'config.yaml'
|
|
|
|
|
|
|
|
if not os.path.exists(config_path):
|
|
|
|
# 如果配置文件不存在,创建默认配置
|
|
|
|
ConfigManager._create_default_config(config_path)
|
|
|
|
print(f"已创建新的配置文件 {config_path},请完善其中的配置信息后再运行。")
|
|
|
|
exit(1)
|
|
|
|
|
|
|
|
# 加载配置文件
|
|
|
|
with open(config_path, 'r', encoding='utf-8') as f:
|
|
|
|
self._config = yaml.safe_load(f)
|
|
|
|
|
|
|
|
return self._config
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
def _create_default_config(config_path: str) -> None:
|
|
|
|
"""创建默认配置文件"""
|
|
|
|
default_config = {
|
|
|
|
'tushare_token': 'xxxxxxxxxxx',
|
|
|
|
'sqlite': {
|
|
|
|
'path': './data/tushare_data.db',
|
2025-04-19 17:12:01 +08:00
|
|
|
'database_name': 'tushare_data'
|
2025-04-19 14:22:18 +08:00
|
|
|
},
|
|
|
|
'log': {
|
|
|
|
'level': 'INFO',
|
|
|
|
'store': True,
|
|
|
|
'path': 'logs'
|
2025-04-20 00:02:57 +08:00
|
|
|
},
|
|
|
|
'llm': {
|
|
|
|
'api_key': '', # 需要用户填写
|
|
|
|
'api_base': 'https://api.openai.com/v1', # API基础URL
|
|
|
|
'model': 'gpt-3.5-turbo', # 默认模型
|
|
|
|
'temperature': 0.7, # 温度参数
|
2025-04-19 14:22:18 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
# 确保目录存在
|
|
|
|
os.makedirs(os.path.dirname(config_path) or '.', exist_ok=True)
|
|
|
|
# 写入默认配置
|
|
|
|
with open(config_path, 'w', encoding='utf-8') as f:
|
|
|
|
yaml.dump(default_config, f, default_flow_style=False, allow_unicode=True)
|
|
|
|
|
|
|
|
def get_config(self) -> Dict[str, Any]:
|
|
|
|
"""获取配置,如果尚未加载则先加载"""
|
|
|
|
if self._config is None:
|
|
|
|
return self.load_config()
|
|
|
|
return self._config
|
|
|
|
|
|
|
|
def get(self, key: str, default=None) -> Any:
|
|
|
|
"""获取指定的配置项,支持使用点号访问嵌套配置"""
|
|
|
|
config = self.get_config()
|
|
|
|
|
|
|
|
# 处理嵌套键访问,如 "sqlite.path"
|
|
|
|
if '.' in key:
|
|
|
|
parts = key.split('.')
|
|
|
|
current = config
|
|
|
|
for part in parts:
|
|
|
|
if part not in current:
|
|
|
|
return default
|
|
|
|
current = current[part]
|
|
|
|
return current
|
|
|
|
|
|
|
|
return config.get(key, default)
|
|
|
|
|
|
|
|
|
|
|
|
# 提供简单的访问函数
|
|
|
|
def get_config_manager() -> ConfigManager:
|
|
|
|
"""获取配置管理器实例"""
|
|
|
|
return ConfigManager()
|