backtrader/data_fetcher.py

155 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import traceback
from datetime import datetime, timedelta
import pandas as pd
import tushare as ts
from tqdm import tqdm
from config_manager import get_config_manager
from database_manager import DatabaseManager
from logger import get_logger
class DataFetcher:
"""
数据获取器类负责从Tushare获取各类数据并管理本地缓存
"""
def __init__(self):
# 加载配置并初始化tushare
self.config = get_config_manager()
ts.set_token(self.config.get('tushare_token'))
self.pro = ts.pro_api()
# 初始化数据库管理器
self.db_manager = DatabaseManager()
# 获取日志器
self.logger = get_logger()
def get_trade_cal(self, start_date=None, end_date=None):
"""
获取指定时间段内的交易日历
参数:
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
返回:
list: 交易日期列表
"""
if start_date is None:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
except Exception as e:
self.logger.error(f"获取交易日历时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
def get(self, api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
"""
获取指定时间段内的Tushare数据使用数据库缓存并分批处理大量数据
参数:
api_name (str): Tushare的API名称例如'moneyflow''moneyflow_ind_dc'
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
force_update (bool): 是否强制更新所选区域数据默认为False
batch_size (int): 每批处理的最大行数默认为100000
返回:
pandas.DataFrame: 请求的数据
"""
# 使用api_name作为表key查询表名
table_name = api_name
# 确保self.pro中存在对应的API方法
if not hasattr(self.pro, api_name):
self.logger.error(f"Tushare API '{api_name}'不存在")
return pd.DataFrame()
# 获取目标交易日历
all_trade_dates = self.get_trade_cal(start_date, end_date)
# 确定需要获取的日期
if not force_update:
# 从数据库获取已有的交易日期
existing_dates = self.db_manager.get_existing_trade_dates(table_name=table_name)
# 筛选出需要新获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
else:
# 强制更新,获取所有日期数据
dates_to_fetch = all_trade_dates
self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")
if not dates_to_fetch:
self.logger.info("所有数据已在数据库中,无需更新")
return self.db_manager.load_df_from_db(table_name=table_name)
self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
# 分批处理数据
temp_batch = []
total_rows = 0
# 获取数据
for trade_date in tqdm(dates_to_fetch):
try:
# 动态调用Tushare API
api_method = getattr(self.pro, api_name)
df = api_method(trade_date=trade_date)
if not df.empty:
temp_batch.append(df)
total_rows += len(df)
# 当累积的数据量达到batch_size时进行一次批量写入
if total_rows >= batch_size:
self._process_batch(temp_batch, table_name, force_update)
self.logger.info(f"已处理 {total_rows} 行数据")
# 重置临时批次
temp_batch = []
total_rows = 0
else:
self.logger.info(f"日期 {trade_date} 无数据")
except Exception as e:
self.logger.error(f"获取 {trade_date} 的数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
# 处理剩余的数据
if temp_batch:
self._process_batch(temp_batch, table_name, force_update)
self.logger.info(f"已处理剩余 {total_rows} 行数据")
self.logger.info("数据获取与处理完成")
return self.db_manager.load_df_from_db(table_name=table_name)
def _process_batch(self, batch_dfs, table_name, force_update):
"""
处理一批数据
参数:
batch_dfs (list): DataFrame列表
table_name (str): 数据库表名称
force_update (bool): 是否强制更新
"""
if not batch_dfs:
return
# 合并批次中的所有DataFrame
batch_df = pd.concat(batch_dfs, ignore_index=True)
if force_update:
# 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据
current_dates = batch_df['trade_date'].unique().tolist()
# 删除这些日期的现有数据
self.db_manager.delete_existing_data_by_dates(table_name, current_dates)
# 插入新数据
self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')
else:
# 普通追加模式
self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')