♻️ refactor(config): 重构配置管理系统,使用单例模式替代模块级函数
This commit is contained in:
parent
d324c62da0
commit
64de9058ff
89
config_manager.py
Normal file
89
config_manager.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigManager:
|
||||||
|
"""配置管理器,负责加载和提供配置信息"""
|
||||||
|
|
||||||
|
_instance = None # 单例实例
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""实现单例模式"""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(ConfigManager, cls).__new__(cls)
|
||||||
|
cls._instance._config = None
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def load_config(self) -> Dict[str, Any]:
|
||||||
|
"""加载配置,如果配置文件不存在则创建默认配置"""
|
||||||
|
if self._config is not None:
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
config_path = 'config.yaml'
|
||||||
|
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
# 如果配置文件不存在,创建默认配置
|
||||||
|
ConfigManager._create_default_config(config_path)
|
||||||
|
print(f"已创建新的配置文件 {config_path},请完善其中的配置信息后再运行。")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
self._config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_default_config(config_path: str) -> None:
|
||||||
|
"""创建默认配置文件"""
|
||||||
|
default_config = {
|
||||||
|
'tushare_token': 'xxxxxxxxxxx',
|
||||||
|
'sqlite': {
|
||||||
|
'path': './data/tushare_data.db',
|
||||||
|
'database_name': 'tushare_data',
|
||||||
|
'table_name': {
|
||||||
|
'moneyflow_ind_dc': 'moneyflow_ind_dc'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'log': {
|
||||||
|
'level': 'INFO',
|
||||||
|
'store': True,
|
||||||
|
'path': 'logs'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.dirname(config_path) or '.', exist_ok=True)
|
||||||
|
|
||||||
|
# 写入默认配置
|
||||||
|
with open(config_path, 'w', encoding='utf-8') as f:
|
||||||
|
yaml.dump(default_config, f, default_flow_style=False, allow_unicode=True)
|
||||||
|
|
||||||
|
def get_config(self) -> Dict[str, Any]:
|
||||||
|
"""获取配置,如果尚未加载则先加载"""
|
||||||
|
if self._config is None:
|
||||||
|
return self.load_config()
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def get(self, key: str, default=None) -> Any:
|
||||||
|
"""获取指定的配置项,支持使用点号访问嵌套配置"""
|
||||||
|
config = self.get_config()
|
||||||
|
|
||||||
|
# 处理嵌套键访问,如 "sqlite.path"
|
||||||
|
if '.' in key:
|
||||||
|
parts = key.split('.')
|
||||||
|
current = config
|
||||||
|
for part in parts:
|
||||||
|
if part not in current:
|
||||||
|
return default
|
||||||
|
current = current[part]
|
||||||
|
return current
|
||||||
|
|
||||||
|
return config.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
# 提供简单的访问函数
|
||||||
|
def get_config_manager() -> ConfigManager:
|
||||||
|
"""获取配置管理器实例"""
|
||||||
|
return ConfigManager()
|
@ -1,12 +0,0 @@
|
|||||||
tushare_token: xxxxxxxxxxx
|
|
||||||
|
|
||||||
sqlite:
|
|
||||||
path: ./data/tushare_data.db
|
|
||||||
database_name: tushare_data
|
|
||||||
table_name:
|
|
||||||
moneyflow_ind_dc: moneyflow_ind_dc
|
|
||||||
|
|
||||||
log:
|
|
||||||
level: INFO
|
|
||||||
store: true
|
|
||||||
path: logs
|
|
@ -5,9 +5,9 @@ import pandas as pd
|
|||||||
import tushare as ts
|
import tushare as ts
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from config_manager import get_config_manager
|
||||||
from database_manager import DatabaseManager
|
from database_manager import DatabaseManager
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from utils import load_config
|
|
||||||
|
|
||||||
|
|
||||||
class DataFetcher:
|
class DataFetcher:
|
||||||
@ -17,11 +17,13 @@ class DataFetcher:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 加载配置并初始化tushare
|
# 加载配置并初始化tushare
|
||||||
self.config = load_config()
|
self.config = get_config_manager()
|
||||||
ts.set_token(self.config['tushare_token'])
|
ts.set_token(self.config.get('tushare_token'))
|
||||||
self.pro = ts.pro_api()
|
self.pro = ts.pro_api()
|
||||||
|
|
||||||
# 初始化数据库管理器
|
# 初始化数据库管理器
|
||||||
self.db_manager = DatabaseManager()
|
self.db_manager = DatabaseManager()
|
||||||
|
|
||||||
# 获取日志器
|
# 获取日志器
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import traceback
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sqlalchemy import create_engine, text, inspect
|
from sqlalchemy import create_engine, text, inspect
|
||||||
|
|
||||||
|
from config_manager import get_config_manager
|
||||||
from logger import get_logger
|
from logger import get_logger
|
||||||
from utils import load_config
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
@ -14,7 +14,7 @@ class DatabaseManager:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.config = load_config()
|
self.config = get_config_manager()
|
||||||
self._engine = None
|
self._engine = None
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
|
|
||||||
@ -22,16 +22,19 @@ class DatabaseManager:
|
|||||||
"""获取SQLite数据库引擎,如果不存在则创建"""
|
"""获取SQLite数据库引擎,如果不存在则创建"""
|
||||||
if self._engine is not None:
|
if self._engine is not None:
|
||||||
return self._engine
|
return self._engine
|
||||||
db_path = self.config['sqlite']['path']
|
|
||||||
|
db_path = self.config.get('sqlite.path')
|
||||||
|
|
||||||
# 确保数据库目录存在
|
# 确保数据库目录存在
|
||||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||||
|
|
||||||
# 创建SQLite数据库引擎
|
# 创建SQLite数据库引擎
|
||||||
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
|
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
|
||||||
return self._engine
|
return self._engine
|
||||||
|
|
||||||
def get_table_name(self, key):
|
def get_table_name(self, key):
|
||||||
"""根据表键名获取实际表名"""
|
"""根据表键名获取实际表名"""
|
||||||
return self.config['sqlite']['table_name'].get(key, key)
|
return self.config.get(f'sqlite.table_name.{key}', key)
|
||||||
|
|
||||||
def table_exists(self, table_key):
|
def table_exists(self, table_key):
|
||||||
"""
|
"""
|
||||||
|
16
logger.py
16
logger.py
@ -1,12 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from utils import load_config
|
from config_manager import get_config_manager
|
||||||
|
|
||||||
|
|
||||||
class Logger:
|
class Logger:
|
||||||
"""日志管理类,负责创建和配置日志器"""
|
"""日志管理类,负责创建和配置日志器"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -18,7 +17,7 @@ class Logger:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化日志器配置"""
|
"""初始化日志器配置"""
|
||||||
self.config = load_config()
|
self.config = get_config_manager()
|
||||||
self.logger = logging.getLogger('finance_data')
|
self.logger = logging.getLogger('finance_data')
|
||||||
|
|
||||||
# 设置根日志级别为 DEBUG (最低级别),确保捕获所有级别的日志
|
# 设置根日志级别为 DEBUG (最低级别),确保捕获所有级别的日志
|
||||||
@ -30,8 +29,9 @@ class Logger:
|
|||||||
|
|
||||||
# 控制台处理器 - 使用配置文件中的级别
|
# 控制台处理器 - 使用配置文件中的级别
|
||||||
console_handler = logging.StreamHandler()
|
console_handler = logging.StreamHandler()
|
||||||
|
|
||||||
# 设置日志级别
|
# 设置日志级别
|
||||||
log_level = self.config.get('log', {}).get('level', 'INFO').upper()
|
log_level = self.config.get('log.level', 'INFO').upper()
|
||||||
level_mapping = {
|
level_mapping = {
|
||||||
'DEBUG': logging.DEBUG,
|
'DEBUG': logging.DEBUG,
|
||||||
'INFO': logging.INFO,
|
'INFO': logging.INFO,
|
||||||
@ -40,7 +40,6 @@ class Logger:
|
|||||||
'CRITICAL': logging.CRITICAL
|
'CRITICAL': logging.CRITICAL
|
||||||
}
|
}
|
||||||
console_handler.setLevel(level_mapping.get(log_level, logging.INFO))
|
console_handler.setLevel(level_mapping.get(log_level, logging.INFO))
|
||||||
|
|
||||||
console_handler.setFormatter(logging.Formatter(
|
console_handler.setFormatter(logging.Formatter(
|
||||||
'%(asctime)s - %(levelname)s - %(message)s',
|
'%(asctime)s - %(levelname)s - %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
@ -48,22 +47,21 @@ class Logger:
|
|||||||
self.logger.addHandler(console_handler)
|
self.logger.addHandler(console_handler)
|
||||||
|
|
||||||
# 文件处理器 - 总是保存所有级别的日志
|
# 文件处理器 - 总是保存所有级别的日志
|
||||||
if self.config.get('log', {}).get('store', False):
|
if self.config.get('log.store', False):
|
||||||
log_dir = self.config.get('log', {}).get('path', 'logs')
|
log_dir = self.config.get('log.path', 'logs')
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
|
||||||
# 使用当天日期作为文件名
|
# 使用当天日期作为文件名
|
||||||
today = datetime.now().strftime('%Y_%m_%d')
|
today = datetime.now().strftime('%Y_%m_%d')
|
||||||
log_file = os.path.join(log_dir, f'{today}.log')
|
log_file = os.path.join(log_dir, f'{today}.log')
|
||||||
|
|
||||||
file_handler = logging.FileHandler(
|
file_handler = logging.FileHandler(
|
||||||
log_file,
|
log_file,
|
||||||
encoding='utf-8',
|
encoding='utf-8',
|
||||||
mode='a' # 使用追加模式
|
mode='a' # 使用追加模式
|
||||||
)
|
)
|
||||||
|
|
||||||
# 文件处理器始终设置为 DEBUG 级别,保存所有日志
|
# 文件处理器始终设置为 DEBUG 级别,保存所有日志
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
file_handler.setFormatter(logging.Formatter(
|
file_handler.setFormatter(logging.Formatter(
|
||||||
'%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
|
'%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
32
utils.py
32
utils.py
@ -1,32 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
# 模块级单例
|
|
||||||
_config = None
|
|
||||||
|
|
||||||
|
|
||||||
def load_config():
|
|
||||||
global _config
|
|
||||||
if _config is not None:
|
|
||||||
return _config
|
|
||||||
|
|
||||||
config_path = 'config.yaml'
|
|
||||||
template_path = 'config_template.yaml'
|
|
||||||
|
|
||||||
if not os.path.exists(config_path):
|
|
||||||
if os.path.exists(template_path):
|
|
||||||
# 如果配置文件不存在但模板存在,则复制模板
|
|
||||||
import shutil
|
|
||||||
shutil.copy(template_path, config_path)
|
|
||||||
print(f"已从 {template_path} 创建新的配置文件 {config_path},请完善其中的配置信息后再运行。")
|
|
||||||
exit(1)
|
|
||||||
else:
|
|
||||||
print(f"错误:配置文件 {config_path} 和模板 {template_path} 均不存在。")
|
|
||||||
print(f"请创建 {template_path} 文件后再运行。")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
_config = yaml.safe_load(f)
|
|
||||||
return _config
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user