diff --git a/config_template.yaml b/config_template.yaml index 1a25c15..677081a 100644 --- a/config_template.yaml +++ b/config_template.yaml @@ -1,7 +1,12 @@ -tushare_token: +tushare_token: xxxxxxxxxxx sqlite: - path: - database_name: + path: ./data/tushare_data.db + database_name: tushare_data table_name: - moneyflow_ind_dc: \ No newline at end of file + moneyflow_ind_dc: moneyflow_ind_dc + +log: + level: INFO + store: true + path: logs \ No newline at end of file diff --git a/data_fetcher.py b/data_fetcher.py index 3792612..9cbc6c5 100644 --- a/data_fetcher.py +++ b/data_fetcher.py @@ -1,9 +1,13 @@ -import pandas as pd -from tqdm import tqdm -import tushare as ts +import traceback from datetime import datetime, timedelta -from utils import load_config + +import pandas as pd +import tushare as ts +from tqdm import tqdm + from database_manager import DatabaseManager +from logger import get_logger +from utils import load_config class DataFetcher: @@ -18,15 +22,15 @@ class DataFetcher: self.pro = ts.pro_api() # 初始化数据库管理器 self.db_manager = DatabaseManager() + # 获取日志器 + self.logger = get_logger() def get_trade_cal(self, start_date=None, end_date=None): """ 获取指定时间段内的交易日历 - 参数: start_date (str): 开始日期,格式'YYYYMMDD' end_date (str): 结束日期,格式'YYYYMMDD' - 返回: list: 交易日期列表 """ @@ -34,23 +38,21 @@ class DataFetcher: start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') if end_date is None: end_date = datetime.now().strftime('%Y%m%d') - try: trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date) return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist() except Exception as e: - print(f"获取交易日历时出错: {e}") + self.logger.error(f"获取交易日历时出错: {e}") + self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return [] def get_moneyflow_ind_dc(self, start_date=None, end_date=None, force_update=False): """ 获取指定时间段内的板块资金流向数据,使用数据库缓存 - 参数: start_date (str): 开始日期,格式'YYYYMMDD' end_date (str): 结束日期,格式'YYYYMMDD' force_update (bool): 是否强制更新所选区域数据,默认为False - 返回: pandas.DataFrame: 所有板块资金流向数据 """ @@ -66,13 +68,13 @@ class DataFetcher: else: # 强制更新,获取所有日期数据 dates_to_fetch = all_trade_dates - print(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据") + self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据") if not dates_to_fetch: - print("所有数据已在数据库中,无需更新") + self.logger.info("所有数据已在数据库中,无需更新") return self.db_manager.load_df_from_db(table_key='moneyflow_ind_dc') - print(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") + self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") # 获取数据 all_new_data = [] @@ -83,15 +85,15 @@ class DataFetcher: if not df.empty: all_new_data.append(df) else: - print(f"日期 {trade_date} 无数据") + self.logger.info(f"日期 {trade_date} 无数据") except Exception as e: - print(f"获取 {trade_date} 的数据时出错: {e}") + self.logger.error(f"获取 {trade_date} 的数据时出错: {e}") + self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") # 处理新数据 if all_new_data: # 将所有新数据合并为一个DataFrame new_df = pd.concat(all_new_data, ignore_index=True) - if force_update: # 强制更新模式:需要删除已有的日期数据,然后重新插入 existing_df = self.db_manager.load_df_from_db(table_key='moneyflow_ind_dc') @@ -101,12 +103,12 @@ class DataFetcher: final_df = pd.concat([filtered_df, new_df], ignore_index=True) # 替换整个表 self.db_manager.save_df_to_db(final_df, table_key='moneyflow_ind_dc', if_exists='replace') - print(f"已强制更新 {len(new_df)} 条记录到数据库") + self.logger.info(f"已强制更新 {len(new_df)} 条记录到数据库") else: # 普通追加模式 self.db_manager.save_df_to_db(new_df, table_key='moneyflow_ind_dc', if_exists='append') - print(f"已将 {len(new_df)} 条新记录追加到数据库") + self.logger.info(f"已将 {len(new_df)} 条新记录追加到数据库") else: - print("未获取到任何新数据") + self.logger.info("未获取到任何新数据") return self.db_manager.load_df_from_db(table_key='moneyflow_ind_dc') \ No newline at end of file diff --git a/database_manager.py b/database_manager.py index aabe815..8f5f4d2 100644 --- a/database_manager.py +++ b/database_manager.py @@ -1,8 +1,10 @@ import os +import traceback import pandas as pd from sqlalchemy import create_engine, text, inspect +from logger import get_logger from utils import load_config @@ -14,12 +16,12 @@ class DatabaseManager: def __init__(self): self.config = load_config() self._engine = None + self.logger = get_logger() def get_engine(self): """获取SQLite数据库引擎,如果不存在则创建""" if self._engine is not None: return self._engine - db_path = self.config['sqlite']['path'] # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) @@ -34,10 +36,8 @@ class DatabaseManager: def table_exists(self, table_key): """ 检查表是否存在 - 参数: table_key (str): 数据表键名 - 返回: bool: 表是否存在 """ @@ -49,16 +49,14 @@ class DatabaseManager: def get_existing_trade_dates(self, table_key): """ 从数据库中获取已有的交易日期 - 参数: table_key (str): 数据表键名 - 返回: set: 已存在于数据库中的交易日期集合 """ # 先检查表是否存在 if not self.table_exists(table_key): - print(f"表 '{self.get_table_name(table_key)}' 不存在,返回空集合") + self.logger.debug(f"表 '{self.get_table_name(table_key)}' 不存在") return set() table_name = self.get_table_name(table_key) @@ -69,17 +67,16 @@ class DatabaseManager: result = connection.execute(text(query)) return {row[0] for row in result} except Exception as e: - print(f"获取已存在交易日期时出错: {e}") + self.logger.error(f"获取已存在交易日期时出错: {e}") + self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return set() def load_df_from_db(self, table_key, conditions=None): """ 从数据库中加载数据 - 参数: table_key (str): 表键名 conditions (str): 过滤条件,如 "trade_date > '20230101'" - 返回: pandas.DataFrame: 查询结果 """ @@ -88,35 +85,33 @@ class DatabaseManager: query = f"SELECT * FROM {table_name}" if conditions: query += f" WHERE {conditions}" - try: return pd.read_sql(query, engine) except Exception as e: - print(f"从数据库加载数据时出错: {e}") + self.logger.error(f"从数据库加载数据时出错: {e}") + self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return pd.DataFrame() def save_df_to_db(self, df, table_key, if_exists='append'): """ 保存DataFrame到数据库 - 参数: df (pandas.DataFrame): 要保存的数据 table_key (str): 表键名 if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append' - 返回: bool: 操作是否成功 """ if df.empty: - print("警告: 尝试保存空的DataFrame到数据库") + self.logger.warning("警告: 尝试保存空的DataFrame到数据库") return False table_name = self.get_table_name(table_key) engine = self.get_engine() - try: df.to_sql(table_name, engine, if_exists=if_exists, index=False) return True except Exception as e: - print(f"保存数据到数据库时出错: {e}") + self.logger.error(f"保存数据到数据库时出错: {e}") + self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return False \ No newline at end of file diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..5068aed --- /dev/null +++ b/logger.py @@ -0,0 +1,77 @@ +import logging +import os +from datetime import datetime +from utils import load_config + + +class Logger: + """日志管理类,负责创建和配置日志器""" + + _instance = None + + @classmethod + def get_logger(cls): + """获取日志器单例""" + if cls._instance is None: + cls._instance = Logger() + return cls._instance.logger + + def __init__(self): + """初始化日志器配置""" + self.config = load_config() + self.logger = logging.getLogger('finance_data') + + # 设置根日志级别为 DEBUG (最低级别),确保捕获所有级别的日志 + self.logger.setLevel(logging.DEBUG) + + # 清除已有的处理器 + if self.logger.handlers: + self.logger.handlers.clear() + + # 控制台处理器 - 使用配置文件中的级别 + console_handler = logging.StreamHandler() + # 设置日志级别 + log_level = self.config.get('log', {}).get('level', 'INFO').upper() + level_mapping = { + 'DEBUG': logging.DEBUG, + 'INFO': logging.INFO, + 'WARNING': logging.WARNING, + 'ERROR': logging.ERROR, + 'CRITICAL': logging.CRITICAL + } + console_handler.setLevel(level_mapping.get(log_level, logging.INFO)) + + console_handler.setFormatter(logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + )) + self.logger.addHandler(console_handler) + + # 文件处理器 - 总是保存所有级别的日志 + if self.config.get('log', {}).get('store', False): + log_dir = self.config.get('log', {}).get('path', 'logs') + os.makedirs(log_dir, exist_ok=True) + + # 使用当天日期作为文件名 + today = datetime.now().strftime('%Y_%m_%d') + log_file = os.path.join(log_dir, f'{today}.log') + + file_handler = logging.FileHandler( + log_file, + encoding='utf-8', + mode='a' # 使用追加模式 + ) + # 文件处理器始终设置为 DEBUG 级别,保存所有日志 + file_handler.setLevel(logging.DEBUG) + + file_handler.setFormatter(logging.Formatter( + '%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + )) + self.logger.addHandler(file_handler) + + +# 提供全局访问点 +def get_logger(): + """获取日志器实例""" + return Logger.get_logger() \ No newline at end of file diff --git a/utils.py b/utils.py index 900b271..6859e74 100644 --- a/utils.py +++ b/utils.py @@ -4,7 +4,6 @@ import yaml # 模块级单例 _config = None -_engine = None def load_config():