import traceback
from datetime import datetime, timedelta

import pandas as pd
import tushare as ts
from tqdm import tqdm

from config_manager import get_config_manager
from database_manager import DatabaseManager
from logger import get_logger


|
class DataFetcher:
    """Fetch market data from Tushare and maintain a local database cache.

    Each Tushare API name doubles as the cache table key; data is pulled
    per trading day and written to the database in bounded-size batches.
    """

    def __init__(self):
        # Load configuration and initialise the Tushare client.
        self.config = get_config_manager()
        ts.set_token(self.config.get('tushare_token'))
        self.pro = ts.pro_api()

        # Database manager that owns the local cache tables.
        self.db_manager = DatabaseManager()

        # Shared application logger.
        self.logger = get_logger()

    def get_trade_cal(self, start_date=None, end_date=None):
        """Return the open trading dates within [start_date, end_date].

        Args:
            start_date (str): start date 'YYYYMMDD'; defaults to 30 days ago.
            end_date (str): end date 'YYYYMMDD'; defaults to today.

        Returns:
            list: trading dates ('YYYYMMDD' strings) on which the exchange
            was open; an empty list when the Tushare call fails.
        """
        if start_date is None:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')

        try:
            trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
            # is_open == 1 marks an open trading day in Tushare's calendar.
            return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
        except Exception as e:
            self.logger.error(f"获取交易日历时出错: {e}")
            self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
            return []

    def get(self, api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
        """Fetch a date range of data via a Tushare API, using the DB cache.

        Accumulated rows are flushed to the database whenever they reach
        ``batch_size``, bounding memory usage on large pulls.

        Args:
            api_name (str): Tushare API name, e.g. 'moneyflow' or
                'moneyflow_ind_dc'; also used as the cache table key.
            start_date (str): start date 'YYYYMMDD'.
            end_date (str): end date 'YYYYMMDD'.
            force_update (bool): when True, re-fetch every trading day in the
                range and overwrite the cached rows. Default False.
            batch_size (int): maximum accumulated rows before a batch is
                written to the database. Default 100000.

        Returns:
            pandas.DataFrame: the cached table for ``api_name`` (empty
            DataFrame when the API does not exist).
        """
        # api_name doubles as the database table key.
        table_key = api_name

        # Resolve the API method once (loop-invariant); None means the API
        # does not exist on the Tushare client.
        api_method = getattr(self.pro, api_name, None)
        if api_method is None:
            self.logger.error(f"Tushare API '{api_name}'不存在")
            return pd.DataFrame()

        # Trading days in the requested window.
        all_trade_dates = self.get_trade_cal(start_date, end_date)

        # Decide which dates still need fetching.
        if not force_update:
            # Build a set once so each membership test below is O(1).
            existing_dates = set(self.db_manager.get_existing_trade_dates(table_key=table_key))
            dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
        else:
            # Force-update: refresh every trading day in the range.
            dates_to_fetch = all_trade_dates
            self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")

        if not dates_to_fetch:
            self.logger.info("所有数据已在数据库中,无需更新")
            return self.db_manager.load_df_from_db(table_key=table_key)

        self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")

        # Per-day frames accumulated until the next batch flush.
        temp_batch = []
        total_rows = 0

        for trade_date in tqdm(dates_to_fetch):
            try:
                df = api_method(trade_date=trade_date)

                if not df.empty:
                    temp_batch.append(df)
                    total_rows += len(df)

                    # Flush once the accumulated rows reach batch_size.
                    if total_rows >= batch_size:
                        self._process_batch(temp_batch, table_key, force_update)
                        self.logger.info(f"已处理 {total_rows} 行数据")
                        temp_batch = []
                        total_rows = 0
                else:
                    self.logger.info(f"日期 {trade_date} 无数据")
            except Exception as e:
                # Best-effort: a failed day stays absent from the cache and
                # will be retried on the next non-forced run.
                self.logger.error(f"获取 {trade_date} 的数据时出错: {e}")
                self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")

        # Flush the final partial batch, if any.
        if temp_batch:
            self._process_batch(temp_batch, table_key, force_update)
            self.logger.info(f"已处理剩余 {total_rows} 行数据")

        self.logger.info("数据获取与处理完成")
        return self.db_manager.load_df_from_db(table_key=table_key)

    def _process_batch(self, batch_dfs, table_key, force_update):
        """Write one batch of per-day DataFrames to the database.

        Args:
            batch_dfs (list): list of pandas.DataFrame, one per trading day.
            table_key (str): database table key.
            force_update (bool): when True, delete the existing rows for the
                dates covered by this batch before appending.
        """
        if not batch_dfs:
            return

        # Merge the batch into a single frame for one bulk write.
        batch_df = pd.concat(batch_dfs, ignore_index=True)

        if force_update:
            # Delete first so the append below replaces rather than
            # duplicates the cached rows for these dates.
            # NOTE(review): assumes every API used here returns a
            # 'trade_date' column — confirm for new api_name values.
            current_dates = batch_df['trade_date'].unique().tolist()
            self.db_manager.delete_existing_data_by_dates(table_key, current_dates)

        # Single append covers both the plain and force-update paths
        # (the original duplicated this call in each branch).
        self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append')