backtrader/data_manager.py
Qihang Zhang 5c22360b58 🔨 refactor(data): 重构数据管理模块,提取独立的DataReader类并优化API接口
为了提升代码结构和可维护性,将data_fetcher.py重命名为data_manager.py,并进行以下重构:

1. 将实例变量移至模块级别配置
2. 将实例方法转换为静态方法
3. 提取新的DataReader类用于数据读取操作
4. 在DatabaseManager中添加通用查询方法
5. 优化数据获取与缓存逻辑
2025-04-19 19:24:24 +08:00

209 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import traceback
from datetime import datetime, timedelta
import pandas as pd
import tushare as ts
from tqdm import tqdm
from config_manager import get_config_manager
from database_manager import DatabaseManager
from logger import get_logger
# 模块级别的配置
config = get_config_manager()
ts.set_token(config.get('tushare_token'))
pro = ts.pro_api()
# 初始化数据库管理器
db_manager = DatabaseManager()
# 获取日志器
logger = get_logger()
class DataFetcher:
"""
数据获取器类负责从Tushare获取各类数据并管理本地缓存
"""
@staticmethod
def get_basic(api_name):
"""
获取基础数据,如股票列表等
参数:
api_name (str): Tushare的API名称例如'stock_basic'
返回:
pandas.DataFrame: 请求的数据
"""
# 确保pro中存在对应的API方法
if not hasattr(pro, api_name):
logger.error(f"Tushare API '{api_name}'不存在")
return pd.DataFrame()
try:
df = getattr(pro, api_name)()
# 将数据保存到数据库
db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace')
return df
except Exception as e:
logger.error(f"获取基础数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()
@staticmethod
def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
"""
获取指定时间段内的Tushare数据使用数据库缓存并分批处理大量数据
参数:
api_name (str): Tushare的API名称例如'moneyflow''moneyflow_ind_dc'
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
force_update (bool): 是否强制更新所选区域数据默认为False
batch_size (int): 每批处理的最大行数默认为100000
返回:
pandas.DataFrame: 请求的数据
"""
# 使用api_name作为表key查询表名
table_name = api_name
# 确保pro中存在对应的API方法
if not hasattr(pro, api_name):
logger.error(f"Tushare API '{api_name}'不存在")
return pd.DataFrame()
# 获取目标交易日历
all_trade_dates = DataReader.get_trade_cal(start_date, end_date)
# 确定需要获取的日期
if not force_update:
# 从数据库获取已有的交易日期
existing_dates = DataReader.get_existing_trade_dates(table_name=table_name)
# 筛选出需要新获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
else:
# 强制更新,获取所有日期数据
dates_to_fetch = all_trade_dates
logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")
if not dates_to_fetch:
logger.info("所有数据已在数据库中,无需更新")
return db_manager.load_df_from_db(table_name=table_name)
logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
# 分批处理数据
temp_batch = []
total_rows = 0
# 获取数据
for trade_date in tqdm(dates_to_fetch):
try:
# 动态调用Tushare API
api_method = getattr(pro, api_name)
df = api_method(trade_date=trade_date)
if not df.empty:
temp_batch.append(df)
total_rows += len(df)
# 当累积的数据量达到batch_size时进行一次批量写入
if total_rows >= batch_size:
DataFetcher.process_batch(temp_batch, table_name, force_update)
logger.info(f"已处理 {total_rows} 行数据")
# 重置临时批次
temp_batch = []
total_rows = 0
else:
logger.info(f"日期 {trade_date} 无数据")
except Exception as e:
logger.error(f"获取 {trade_date} 的数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
# 处理剩余的数据
if temp_batch:
DataFetcher.process_batch(temp_batch, table_name, force_update)
logger.info(f"已处理剩余 {total_rows} 行数据")
logger.info("数据获取与处理完成")
return db_manager.load_df_from_db(table_name=table_name)
@staticmethod
def process_batch(batch_dfs, table_name, force_update):
"""
处理一批数据
参数:
batch_dfs (list): DataFrame列表
table_name (str): 数据库表名称
force_update (bool): 是否强制更新
"""
if not batch_dfs:
return
# 合并批次中的所有DataFrame
batch_df = pd.concat(batch_dfs, ignore_index=True)
# 保存数据传入force_update参数
db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update)
class DataReader:
"""
数据读取器类,负责从数据库读取数据
"""
@staticmethod
def get_trade_cal(start_date=None, end_date=None, update=False):
"""
获取指定时间段内的交易日历
参数:
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
update (bool): 是否更新交易日历默认为False
返回:
list: 交易日期列表
"""
# 先检查表是否存在
if not db_manager.table_exists('trade_cal') or update:
logger.debug(f"表 trade_cal 不存在")
# 自动拉取交易日历
try:
DataFetcher.get_basic('trade_cal')
except Exception as e:
logger.error(f"自动拉取交易日历时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
if start_date is None:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'")
return df[df['is_open'] == 1]['cal_date'].tolist()
except Exception as e:
logger.error(f"获取交易日历时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
@staticmethod
def get_existing_trade_dates(table_name):
"""
从数据库中获取已有的交易日期
参数:
table_name (str): 数据表名称
返回:
set: 已存在于数据库中的交易日期集合
"""
# 先检查表是否存在
if not db_manager.table_exists(table_name):
logger.debug(f"'{table_name}' 不存在")
return set()
try:
# 使用query方法获取不重复的交易日期
query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}")
# 将查询结果转换为集合并返回
return set(query_result['trade_date'].values)
except Exception as e:
logger.error(f"获取已存在交易日期时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return set()