♻️ refactor(data_fetcher): 重构数据获取逻辑,添加通用API调用接口和批量处理功能,提高大数据集处理效率

This commit is contained in:
Qihang Zhang 2025-04-19 16:51:53 +08:00
parent 8df625502d
commit c0468c9f71
3 changed files with 106 additions and 34 deletions

View File

@ -48,19 +48,30 @@ class DataFetcher:
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
def get_moneyflow_ind_dc(self, start_date=None, end_date=None, force_update=False):
def get(self, api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
"""
获取指定时间段内的板块资金流向数据使用数据库缓存
获取指定时间段内的Tushare数据使用数据库缓存并分批处理大量数据
参数:
api_name (str): Tushare的API名称例如'moneyflow''moneyflow_ind_dc'
start_date (str): 开始日期格式'YYYYMMDD'
end_date (str): 结束日期格式'YYYYMMDD'
force_update (bool): 是否强制更新所选区域数据默认为False
batch_size (int): 每批处理的最大行数默认为100000
返回:
pandas.DataFrame: 所有板块资金流向数据
pandas.DataFrame: 请求的数据
"""
# 使用api_name作为表key查询表名
table_key = api_name
# 确保self.pro中存在对应的API方法
if not hasattr(self.pro, api_name):
self.logger.error(f"Tushare API '{api_name}'不存在")
return pd.DataFrame()
# 获取目标交易日历
all_trade_dates = self.get_trade_cal(start_date, end_date)
table_key = 'moneyflow_ind_dc'
# 确定需要获取的日期
if not force_update:
@ -79,39 +90,65 @@ class DataFetcher:
self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
# 分批处理数据
temp_batch = []
total_rows = 0
# 获取数据
all_new_data = []
for trade_date in tqdm(dates_to_fetch):
try:
# 从tushare获取当日板块资金流向数据
df = self.pro.moneyflow_ind_dc(trade_date=trade_date)
# 动态调用Tushare API
api_method = getattr(self.pro, api_name)
df = api_method(trade_date=trade_date)
if not df.empty:
all_new_data.append(df)
temp_batch.append(df)
total_rows += len(df)
# 当累积的数据量达到batch_size时进行一次批量写入
if total_rows >= batch_size:
self._process_batch(temp_batch, table_key, force_update)
self.logger.info(f"已处理 {total_rows} 行数据")
# 重置临时批次
temp_batch = []
total_rows = 0
else:
self.logger.info(f"日期 {trade_date} 无数据")
except Exception as e:
self.logger.error(f"获取 {trade_date} 的数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
# 处理新数据
if all_new_data:
# 将所有新数据合并为一个DataFrame
new_df = pd.concat(all_new_data, ignore_index=True)
if force_update:
# 强制更新模式:需要删除已有的日期数据,然后重新插入
existing_df = self.db_manager.load_df_from_db(table_key=table_key)
# 过滤掉需要更新的日期范围内的数据
filtered_df = existing_df[~existing_df['trade_date'].isin(dates_to_fetch)]
# 拼接新数据
final_df = pd.concat([filtered_df, new_df], ignore_index=True)
# 替换整个表
self.db_manager.save_df_to_db(final_df, table_key=table_key, if_exists='replace')
self.logger.info(f"已强制更新 {len(new_df)} 条记录到数据库")
else:
# 普通追加模式
self.db_manager.save_df_to_db(new_df, table_key=table_key, if_exists='append')
self.logger.info(f"已将 {len(new_df)} 条新记录追加到数据库")
else:
self.logger.info("未获取到任何新数据")
# 处理剩余的数据
if temp_batch:
self._process_batch(temp_batch, table_key, force_update)
self.logger.info(f"已处理剩余 {total_rows} 行数据")
return self.db_manager.load_df_from_db(table_key=table_key)
self.logger.info("数据获取与处理完成")
return self.db_manager.load_df_from_db(table_key=table_key)
def _process_batch(self, batch_dfs, table_key, force_update):
"""
处理一批数据
参数:
batch_dfs (list): DataFrame列表
table_key (str): 数据库表名
force_update (bool): 是否强制更新
"""
if not batch_dfs:
return
# 合并批次中的所有DataFrame
batch_df = pd.concat(batch_dfs, ignore_index=True)
if force_update:
# 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据
current_dates = batch_df['trade_date'].unique().tolist()
# 删除这些日期的现有数据
self.db_manager.delete_existing_data_by_dates(table_key, current_dates)
# 插入新数据
self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append')
else:
# 普通追加模式
self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append')

View File

@ -117,4 +117,39 @@ class DatabaseManager:
except Exception as e:
self.logger.error(f"保存数据到数据库时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False
return False
def delete_existing_data_by_dates(self, table_key, trade_dates):
"""
从数据库表中删除指定交易日期的数据
参数:
table_key (str): 数据表键名
trade_dates (list): 需要删除的交易日期列表
返回:
bool: 操作是否成功
"""
if not trade_dates:
return True # 如果没有日期需要删除,认为操作成功
table_name = self.get_table_name(table_key)
engine = self.get_engine()
try:
# 将日期列表转换为SQL安全的格式
date_strings = [f"'{date}'" for date in trade_dates]
dates_clause = ", ".join(date_strings)
delete_query = f"DELETE FROM {table_name} WHERE trade_date IN ({dates_clause})"
with engine.connect() as connection:
connection.execute(text(delete_query))
connection.commit()
self.logger.info(f"成功从表 '{table_name}' 中删除 {len(trade_dates)} 个日期的数据")
return True
except Exception as e:
self.logger.error(f"删除表 '{table_name}' 中的数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False

View File

@ -3,7 +3,7 @@ from money_flow_analyzer import MoneyflowAnalyzer
if __name__ == "__main__":
# 指定日期范围
start_date = '20230912'
start_date = '20250401'
end_date = None
# 创建数据获取器实例
@ -11,7 +11,7 @@ if __name__ == "__main__":
# 获取板块资金流向数据
# 可以通过force_update=True参数强制更新指定日期范围的数据
df = data_fetcher.get_moneyflow_ind_dc(start_date, end_date, force_update=False)
df = data_fetcher.get('moneyflow',start_date, end_date, force_update=True)
analyzer = MoneyflowAnalyzer()
analyzer.main_flow_analyze(days_forward=2,use_consistent_samples=True)
# analyzer = MoneyflowAnalyzer()
# analyzer.main_flow_analyze(days_forward=10,use_consistent_samples=True)