♻️ refactor(config): 重构配置管理系统,使用单例模式替代模块级函数
This commit is contained in:
parent
d324c62da0
commit
64de9058ff
89
config_manager.py
Normal file
89
config_manager.py
Normal file
@ -0,0 +1,89 @@
|
||||
import os
|
||||
import yaml
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器,负责加载和提供配置信息"""
|
||||
|
||||
_instance = None # 单例实例
|
||||
|
||||
def __new__(cls):
|
||||
"""实现单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(ConfigManager, cls).__new__(cls)
|
||||
cls._instance._config = None
|
||||
return cls._instance
|
||||
|
||||
def load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置,如果配置文件不存在则创建默认配置"""
|
||||
if self._config is not None:
|
||||
return self._config
|
||||
|
||||
config_path = 'config.yaml'
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
# 如果配置文件不存在,创建默认配置
|
||||
ConfigManager._create_default_config(config_path)
|
||||
print(f"已创建新的配置文件 {config_path},请完善其中的配置信息后再运行。")
|
||||
exit(1)
|
||||
|
||||
# 加载配置文件
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f)
|
||||
|
||||
return self._config
|
||||
|
||||
@staticmethod
|
||||
def _create_default_config(config_path: str) -> None:
|
||||
"""创建默认配置文件"""
|
||||
default_config = {
|
||||
'tushare_token': 'xxxxxxxxxxx',
|
||||
'sqlite': {
|
||||
'path': './data/tushare_data.db',
|
||||
'database_name': 'tushare_data',
|
||||
'table_name': {
|
||||
'moneyflow_ind_dc': 'moneyflow_ind_dc'
|
||||
}
|
||||
},
|
||||
'log': {
|
||||
'level': 'INFO',
|
||||
'store': True,
|
||||
'path': 'logs'
|
||||
}
|
||||
}
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(config_path) or '.', exist_ok=True)
|
||||
|
||||
# 写入默认配置
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(default_config, f, default_flow_style=False, allow_unicode=True)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
"""获取配置,如果尚未加载则先加载"""
|
||||
if self._config is None:
|
||||
return self.load_config()
|
||||
return self._config
|
||||
|
||||
def get(self, key: str, default=None) -> Any:
|
||||
"""获取指定的配置项,支持使用点号访问嵌套配置"""
|
||||
config = self.get_config()
|
||||
|
||||
# 处理嵌套键访问,如 "sqlite.path"
|
||||
if '.' in key:
|
||||
parts = key.split('.')
|
||||
current = config
|
||||
for part in parts:
|
||||
if part not in current:
|
||||
return default
|
||||
current = current[part]
|
||||
return current
|
||||
|
||||
return config.get(key, default)
|
||||
|
||||
|
||||
# 提供简单的访问函数
|
||||
def get_config_manager() -> ConfigManager:
|
||||
"""获取配置管理器实例"""
|
||||
return ConfigManager()
|
@ -1,12 +0,0 @@
|
||||
tushare_token: xxxxxxxxxxx
|
||||
|
||||
sqlite:
|
||||
path: ./data/tushare_data.db
|
||||
database_name: tushare_data
|
||||
table_name:
|
||||
moneyflow_ind_dc: moneyflow_ind_dc
|
||||
|
||||
log:
|
||||
level: INFO
|
||||
store: true
|
||||
path: logs
|
@ -5,9 +5,9 @@ import pandas as pd
|
||||
import tushare as ts
|
||||
from tqdm import tqdm
|
||||
|
||||
from config_manager import get_config_manager
|
||||
from database_manager import DatabaseManager
|
||||
from logger import get_logger
|
||||
from utils import load_config
|
||||
|
||||
|
||||
class DataFetcher:
|
||||
@ -17,11 +17,13 @@ class DataFetcher:
|
||||
|
||||
def __init__(self):
|
||||
# 加载配置并初始化tushare
|
||||
self.config = load_config()
|
||||
ts.set_token(self.config['tushare_token'])
|
||||
self.config = get_config_manager()
|
||||
ts.set_token(self.config.get('tushare_token'))
|
||||
self.pro = ts.pro_api()
|
||||
|
||||
# 初始化数据库管理器
|
||||
self.db_manager = DatabaseManager()
|
||||
|
||||
# 获取日志器
|
||||
self.logger = get_logger()
|
||||
|
||||
|
@ -4,8 +4,8 @@ import traceback
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine, text, inspect
|
||||
|
||||
from config_manager import get_config_manager
|
||||
from logger import get_logger
|
||||
from utils import load_config
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
@ -14,7 +14,7 @@ class DatabaseManager:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = load_config()
|
||||
self.config = get_config_manager()
|
||||
self._engine = None
|
||||
self.logger = get_logger()
|
||||
|
||||
@ -22,16 +22,19 @@ class DatabaseManager:
|
||||
"""获取SQLite数据库引擎,如果不存在则创建"""
|
||||
if self._engine is not None:
|
||||
return self._engine
|
||||
db_path = self.config['sqlite']['path']
|
||||
|
||||
db_path = self.config.get('sqlite.path')
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
# 创建SQLite数据库引擎
|
||||
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
|
||||
return self._engine
|
||||
|
||||
def get_table_name(self, key):
|
||||
"""根据表键名获取实际表名"""
|
||||
return self.config['sqlite']['table_name'].get(key, key)
|
||||
return self.config.get(f'sqlite.table_name.{key}', key)
|
||||
|
||||
def table_exists(self, table_key):
|
||||
"""
|
||||
|
16
logger.py
16
logger.py
@ -1,12 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from utils import load_config
|
||||
from config_manager import get_config_manager
|
||||
|
||||
|
||||
class Logger:
|
||||
"""日志管理类,负责创建和配置日志器"""
|
||||
|
||||
_instance = None
|
||||
|
||||
@classmethod
|
||||
@ -18,7 +17,7 @@ class Logger:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化日志器配置"""
|
||||
self.config = load_config()
|
||||
self.config = get_config_manager()
|
||||
self.logger = logging.getLogger('finance_data')
|
||||
|
||||
# 设置根日志级别为 DEBUG (最低级别),确保捕获所有级别的日志
|
||||
@ -30,8 +29,9 @@ class Logger:
|
||||
|
||||
# 控制台处理器 - 使用配置文件中的级别
|
||||
console_handler = logging.StreamHandler()
|
||||
|
||||
# 设置日志级别
|
||||
log_level = self.config.get('log', {}).get('level', 'INFO').upper()
|
||||
log_level = self.config.get('log.level', 'INFO').upper()
|
||||
level_mapping = {
|
||||
'DEBUG': logging.DEBUG,
|
||||
'INFO': logging.INFO,
|
||||
@ -40,7 +40,6 @@ class Logger:
|
||||
'CRITICAL': logging.CRITICAL
|
||||
}
|
||||
console_handler.setLevel(level_mapping.get(log_level, logging.INFO))
|
||||
|
||||
console_handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
@ -48,22 +47,21 @@ class Logger:
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器 - 总是保存所有级别的日志
|
||||
if self.config.get('log', {}).get('store', False):
|
||||
log_dir = self.config.get('log', {}).get('path', 'logs')
|
||||
if self.config.get('log.store', False):
|
||||
log_dir = self.config.get('log.path', 'logs')
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 使用当天日期作为文件名
|
||||
today = datetime.now().strftime('%Y_%m_%d')
|
||||
log_file = os.path.join(log_dir, f'{today}.log')
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
log_file,
|
||||
encoding='utf-8',
|
||||
mode='a' # 使用追加模式
|
||||
)
|
||||
|
||||
# 文件处理器始终设置为 DEBUG 级别,保存所有日志
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
|
||||
file_handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
|
32
utils.py
32
utils.py
@ -1,32 +0,0 @@
|
||||
import os
|
||||
|
||||
import yaml
|
||||
|
||||
# 模块级单例
|
||||
_config = None
|
||||
|
||||
|
||||
def load_config():
|
||||
global _config
|
||||
if _config is not None:
|
||||
return _config
|
||||
|
||||
config_path = 'config.yaml'
|
||||
template_path = 'config_template.yaml'
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
if os.path.exists(template_path):
|
||||
# 如果配置文件不存在但模板存在,则复制模板
|
||||
import shutil
|
||||
shutil.copy(template_path, config_path)
|
||||
print(f"已从 {template_path} 创建新的配置文件 {config_path},请完善其中的配置信息后再运行。")
|
||||
exit(1)
|
||||
else:
|
||||
print(f"错误:配置文件 {config_path} 和模板 {template_path} 均不存在。")
|
||||
print(f"请创建 {template_path} 文件后再运行。")
|
||||
exit(1)
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
_config = yaml.safe_load(f)
|
||||
return _config
|
||||
|
Loading…
Reference in New Issue
Block a user