import traceback from datetime import datetime, timedelta import pandas as pd import tushare as ts from tqdm import tqdm from config_manager import get_config_manager from database_manager import DatabaseManager from logger_manager import get_logger # 模块级别的配置 config = get_config_manager() ts.set_token(config.get('tushare_token')) pro = ts.pro_api() # 初始化数据库管理器 db_manager = DatabaseManager() # 获取日志器 logger = get_logger() class DataFetcher: """ 数据获取器类,负责从Tushare获取各类数据并管理本地缓存 """ @staticmethod def get_basic(api_name): """ 获取基础数据,如股票列表等 参数: api_name (str): Tushare的API名称,例如'stock_basic' 返回: pandas.DataFrame: 请求的数据 """ # 确保pro中存在对应的API方法 if not hasattr(pro, api_name): logger.error(f"Tushare API '{api_name}'不存在") return False try: df = getattr(pro, api_name)() # 将数据保存到数据库 db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace') return True except Exception as e: logger.error(f"获取基础数据时出错: {e}") logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return False @staticmethod def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000): """ 获取指定时间段内的Tushare数据,使用数据库缓存,并分批处理大量数据 参数: api_name (str): Tushare的API名称,例如'moneyflow'或'moneyflow_ind_dc' start_date (str): 开始日期,格式'YYYYMMDD' end_date (str): 结束日期,格式'YYYYMMDD' force_update (bool): 是否强制更新所选区域数据,默认为False batch_size (int): 每批处理的最大行数,默认为100000 返回: pandas.DataFrame: 请求的数据 """ # 使用api_name作为表key查询表名 table_name = api_name # 确保pro中存在对应的API方法 if not hasattr(pro, api_name): logger.error(f"Tushare API '{api_name}'不存在") return False # 获取目标交易日历 all_trade_dates = DataReader.get_trade_cal(start_date, end_date) # 确定需要获取的日期 if not force_update: # 从数据库获取已有的交易日期 existing_dates = DataReader.get_existing_table_trade_dates(table_name=table_name) # 筛选出需要新获取的日期 dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] else: # 强制更新,获取所有日期数据 dates_to_fetch = all_trade_dates logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据") if not dates_to_fetch: logger.info("所有数据已在数据库中,无需更新") return True logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") # 分批处理数据 temp_batch = [] total_rows = 0 # 获取数据 for trade_date in tqdm(dates_to_fetch): try: # 动态调用Tushare API api_method = getattr(pro, api_name) df = api_method(trade_date=trade_date) if not df.empty: temp_batch.append(df) total_rows += len(df) # 当累积的数据量达到batch_size时,进行一次批量写入 if total_rows >= batch_size: DataFetcher.process_batch(temp_batch, table_name, force_update) logger.info(f"已处理 {total_rows} 行数据") # 重置临时批次 temp_batch = [] total_rows = 0 else: logger.info(f"日期 {trade_date} 无数据") except Exception as e: logger.error(f"获取 {trade_date} 的数据时出错: {e}") logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") # 处理剩余的数据 if temp_batch: DataFetcher.process_batch(temp_batch, table_name, force_update) logger.info(f"已处理剩余 {total_rows} 行数据") logger.info("数据获取与处理完成") return True @staticmethod def process_batch(batch_dfs, table_name, force_update): """ 处理一批数据 参数: batch_dfs (list): DataFrame列表 table_name (str): 数据库表名称 force_update (bool): 是否强制更新 """ if not batch_dfs: return # 合并批次中的所有DataFrame batch_df = pd.concat(batch_dfs, ignore_index=True) # 保存数据,传入force_update参数 db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update) @staticmethod def update_all(trade_date_api=None, basic_api=None, start_date=None, end_date=None, force_update=False): """ 更新所有数据 """ # 获取所有API名称 if trade_date_api is None: trade_date_api = [ 'moneyflow', # 个股资金流向 'moneyflow_ind_dc', # 东财概念及行业板块资金流向(DC) 'daily', # A股日线行情 'daily_basic', # 每日指标,获取全部股票每日重要的基本面指标 'stk_limit', # 每日涨跌停价格 'cyq_perf', # 每日筹码及胜率 'moneyflow_ths', # 同花顺资金流向 'moneyflow_dc', # 东方财富资金流向 'moneyflow_cnt_ths',# 同花顺概念板块资金流向(THS) 'moneyflow_ind_ths',# 同花顺行业板块资金流向(THS) 'kpl_concept', # 开盘啦题材库,获取开盘啦概念题材列表 'kpl_concept_cons', # 开盘啦题材成分,获取开盘啦概念题材的成分股 'kpl_list', # 获取开盘啦涨停、跌停、炸板等榜单数据 'top_list', # 龙虎榜每日明细 'top_inst', # 龙虎榜机构席位明细 'limit_list_d', # 涨跌停列表(新),获取A股每日涨跌停、炸板数据情况,数据从2020年开始(不提供ST股票的统计) 'ths_daily', # 同花顺板块指数行情 'dc_index', # 东方财富概念板块,获取东方财富每个交易日的概念板块数据,支持按日期查询 'stk_auction', # 当日集合竞价,获取当日个股和ETF的集合竞价成交情况,每天9点25后可以获取当日的集合竞价成交数据 'ths_hot', # 获取同花顺App热榜数据,包括热股、概念板块、ETF、可转债、港美股等等,每日盘中提取4次,收盘后4次,最晚22点提取一次。 ] if basic_api is None: basic_api = [ 'stock_basic', # 股票基本信息 'trade_cal', # 交易日历 'namechange', # 股票曾用名 'ths_index', # 同花顺概念和行业指数 'hm_list', # 游资名录 'index_basic', # 指数基本信息 ] # 使用get_trade_date更新trade_date_api列表中的所有API for api in trade_date_api: logger.info(f"更新API: {api}") DataFetcher.get_trade_date(api_name=api, start_date=start_date, end_date=end_date, force_update=force_update) # 使用get_basic更新basic_api列表中的所有API for api in basic_api: logger.info(f"更新API: {api}") DataFetcher.get_basic(api_name=api) class DataReader: """ 数据读取器类,负责从数据库读取数据 """ @staticmethod def get_trade_cal(start_date=None, end_date=None, update=False): """ 从数据库获取交易日历,如果表不存在则自动拉取数据,并截取指定日期范围内的交易日历(工作日) 参数: start_date (str): 开始日期,格式'YYYYMMDD' end_date (str): 结束日期,格式'YYYYMMDD' update (bool): 是否更新交易日历,默认为False 返回: list: 交易日期列表 """ # 先检查表是否存在 if not db_manager.table_exists('trade_cal') or update: logger.debug(f"表 trade_cal 不存在") # 自动拉取交易日历 try: DataFetcher.get_basic('trade_cal') except Exception as e: logger.error(f"自动拉取交易日历时出错: {e}") logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return [] if start_date is None: start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') if end_date is None: end_date = datetime.now().strftime('%Y%m%d') try: df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'") return df[df['is_open'] == 1]['cal_date'].tolist() except Exception as e: logger.error(f"获取交易日历时出错: {e}") logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return [] @staticmethod def get_existing_table_trade_dates(table_name): """ 查询指定表中已存在的交易日期 参数: table_name (str): 数据表名称 返回: set: 已存在于数据库中的交易日期集合 """ # 先检查表是否存在 if not db_manager.table_exists(table_name): logger.debug(f"表 '{table_name}' 不存在") return [] try: # 使用query方法获取不重复的交易日期 query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}") # 将查询结果转换为集合并返回 return list(set(query_result['trade_date'].values)) except Exception as e: logger.error(f"获取已存在交易日期时出错: {e}") logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return []