diff --git a/config_manager.py b/config_manager.py new file mode 100644 index 0000000..e16036e --- /dev/null +++ b/config_manager.py @@ -0,0 +1,89 @@ +import os +import yaml +from typing import Dict, Any + + +class ConfigManager: + """配置管理器,负责加载和提供配置信息""" + + _instance = None # 单例实例 + + def __new__(cls): + """实现单例模式""" + if cls._instance is None: + cls._instance = super(ConfigManager, cls).__new__(cls) + cls._instance._config = None + return cls._instance + + def load_config(self) -> Dict[str, Any]: + """加载配置,如果配置文件不存在则创建默认配置""" + if self._config is not None: + return self._config + + config_path = 'config.yaml' + + if not os.path.exists(config_path): + # 如果配置文件不存在,创建默认配置 + ConfigManager._create_default_config(config_path) + print(f"已创建新的配置文件 {config_path},请完善其中的配置信息后再运行。") + exit(1) + + # 加载配置文件 + with open(config_path, 'r', encoding='utf-8') as f: + self._config = yaml.safe_load(f) + + return self._config + + @staticmethod + def _create_default_config(config_path: str) -> None: + """创建默认配置文件""" + default_config = { + 'tushare_token': 'xxxxxxxxxxx', + 'sqlite': { + 'path': './data/tushare_data.db', + 'database_name': 'tushare_data', + 'table_name': { + 'moneyflow_ind_dc': 'moneyflow_ind_dc' + } + }, + 'log': { + 'level': 'INFO', + 'store': True, + 'path': 'logs' + } + } + + # 确保目录存在 + os.makedirs(os.path.dirname(config_path) or '.', exist_ok=True) + + # 写入默认配置 + with open(config_path, 'w', encoding='utf-8') as f: + yaml.dump(default_config, f, default_flow_style=False, allow_unicode=True) + + def get_config(self) -> Dict[str, Any]: + """获取配置,如果尚未加载则先加载""" + if self._config is None: + return self.load_config() + return self._config + + def get(self, key: str, default=None) -> Any: + """获取指定的配置项,支持使用点号访问嵌套配置""" + config = self.get_config() + + # 处理嵌套键访问,如 "sqlite.path" + if '.' in key: + parts = key.split('.') + current = config + for part in parts: + if part not in current: + return default + current = current[part] + return current + + return config.get(key, default) + + +# 提供简单的访问函数 +def get_config_manager() -> ConfigManager: + """获取配置管理器实例""" + return ConfigManager() \ No newline at end of file diff --git a/config_template.yaml b/config_template.yaml deleted file mode 100644 index 677081a..0000000 --- a/config_template.yaml +++ /dev/null @@ -1,12 +0,0 @@ -tushare_token: xxxxxxxxxxx - -sqlite: - path: ./data/tushare_data.db - database_name: tushare_data - table_name: - moneyflow_ind_dc: moneyflow_ind_dc - -log: - level: INFO - store: true - path: logs \ No newline at end of file diff --git a/data_fetcher.py b/data_fetcher.py index 9cbc6c5..f140d17 100644 --- a/data_fetcher.py +++ b/data_fetcher.py @@ -5,9 +5,9 @@ import pandas as pd import tushare as ts from tqdm import tqdm +from config_manager import get_config_manager from database_manager import DatabaseManager from logger import get_logger -from utils import load_config class DataFetcher: @@ -17,11 +17,13 @@ class DataFetcher: def __init__(self): # 加载配置并初始化tushare - self.config = load_config() - ts.set_token(self.config['tushare_token']) + self.config = get_config_manager() + ts.set_token(self.config.get('tushare_token')) self.pro = ts.pro_api() + # 初始化数据库管理器 self.db_manager = DatabaseManager() + # 获取日志器 self.logger = get_logger() diff --git a/database_manager.py b/database_manager.py index 8f5f4d2..8d359f4 100644 --- a/database_manager.py +++ b/database_manager.py @@ -4,8 +4,8 @@ import traceback import pandas as pd from sqlalchemy import create_engine, text, inspect +from config_manager import get_config_manager from logger import get_logger -from utils import load_config class DatabaseManager: @@ -14,7 +14,7 @@ class DatabaseManager: """ def __init__(self): - self.config = load_config() + self.config = get_config_manager() self._engine = None self.logger = get_logger() @@ -22,16 +22,19 @@ class DatabaseManager: """获取SQLite数据库引擎,如果不存在则创建""" if self._engine is not None: return self._engine - db_path = self.config['sqlite']['path'] + + db_path = self.config.get('sqlite.path') + # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) + # 创建SQLite数据库引擎 self._engine = create_engine(f'sqlite:///{db_path}', echo=False) return self._engine def get_table_name(self, key): """根据表键名获取实际表名""" - return self.config['sqlite']['table_name'].get(key, key) + return self.config.get(f'sqlite.table_name.{key}', key) def table_exists(self, table_key): """ diff --git a/logger.py b/logger.py index 5068aed..50c58d5 100644 --- a/logger.py +++ b/logger.py @@ -1,12 +1,11 @@ import logging import os from datetime import datetime -from utils import load_config +from config_manager import get_config_manager class Logger: """日志管理类,负责创建和配置日志器""" - _instance = None @classmethod @@ -18,7 +17,7 @@ class Logger: def __init__(self): """初始化日志器配置""" - self.config = load_config() + self.config = get_config_manager() self.logger = logging.getLogger('finance_data') # 设置根日志级别为 DEBUG (最低级别),确保捕获所有级别的日志 @@ -30,8 +29,9 @@ class Logger: # 控制台处理器 - 使用配置文件中的级别 console_handler = logging.StreamHandler() + # 设置日志级别 - log_level = self.config.get('log', {}).get('level', 'INFO').upper() + log_level = self.config.get('log.level', 'INFO').upper() level_mapping = { 'DEBUG': logging.DEBUG, 'INFO': logging.INFO, @@ -40,7 +40,6 @@ class Logger: 'CRITICAL': logging.CRITICAL } console_handler.setLevel(level_mapping.get(log_level, logging.INFO)) - console_handler.setFormatter(logging.Formatter( '%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' @@ -48,22 +47,21 @@ class Logger: self.logger.addHandler(console_handler) # 文件处理器 - 总是保存所有级别的日志 - if self.config.get('log', {}).get('store', False): - log_dir = self.config.get('log', {}).get('path', 'logs') + if self.config.get('log.store', False): + log_dir = self.config.get('log.path', 'logs') os.makedirs(log_dir, exist_ok=True) # 使用当天日期作为文件名 today = datetime.now().strftime('%Y_%m_%d') log_file = os.path.join(log_dir, f'{today}.log') - file_handler = logging.FileHandler( log_file, encoding='utf-8', mode='a' # 使用追加模式 ) + # 文件处理器始终设置为 DEBUG 级别,保存所有日志 file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter(logging.Formatter( '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' diff --git a/utils.py b/utils.py deleted file mode 100644 index 6859e74..0000000 --- a/utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import os - -import yaml - -# 模块级单例 -_config = None - - -def load_config(): - global _config - if _config is not None: - return _config - - config_path = 'config.yaml' - template_path = 'config_template.yaml' - - if not os.path.exists(config_path): - if os.path.exists(template_path): - # 如果配置文件不存在但模板存在,则复制模板 - import shutil - shutil.copy(template_path, config_path) - print(f"已从 {template_path} 创建新的配置文件 {config_path},请完善其中的配置信息后再运行。") - exit(1) - else: - print(f"错误:配置文件 {config_path} 和模板 {template_path} 均不存在。") - print(f"请创建 {template_path} 文件后再运行。") - exit(1) - - with open(config_path, 'r') as f: - _config = yaml.safe_load(f) - return _config -