backtrader/data_manager.py

209 lines
7.7 KiB
Python
Raw Normal View History

import traceback
from datetime import datetime, timedelta
import pandas as pd
import tushare as ts
from tqdm import tqdm
from config_manager import get_config_manager
from database_manager import DatabaseManager
from logger import get_logger
# 模块级别的配置
config = get_config_manager()
ts.set_token(config.get('tushare_token'))
pro = ts.pro_api()
# 初始化数据库管理器
db_manager = DatabaseManager()
# 获取日志器
logger = get_logger()
class DataFetcher:
"""
数据获取器类负责从Tushare获取各类数据并管理本地缓存
"""
@staticmethod
def get_basic(api_name):
"""
获取基础数据如股票列表等
参数
api_name (str): Tushare的API名称例如'stock_basic'
返回
pandas.DataFrame: 请求的数据
"""
# 确保pro中存在对应的API方法
if not hasattr(pro, api_name):
logger.error(f"Tushare API '{api_name}'不存在")
return pd.DataFrame()
try:
df = getattr(pro, api_name)()
# 将数据保存到数据库
db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace')
return df
except Exception as e:
logger.error(f"获取基础数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()
@staticmethod
def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
"""
获取指定时间段内的Tushare数据使用数据库缓存并分批处理大量数据
参数:
api_name (str): Tushare的API名称例如'moneyflow''moneyflow_ind_dc'
start_date (str): 开始日期格式'YYYYMMDD'
end_date (str): 结束日期格式'YYYYMMDD'
force_update (bool): 是否强制更新所选区域数据默认为False
batch_size (int): 每批处理的最大行数默认为100000
返回:
pandas.DataFrame: 请求的数据
"""
# 使用api_name作为表key查询表名
table_name = api_name
# 确保pro中存在对应的API方法
if not hasattr(pro, api_name):
logger.error(f"Tushare API '{api_name}'不存在")
return pd.DataFrame()
# 获取目标交易日历
all_trade_dates = DataReader.get_trade_cal(start_date, end_date)
# 确定需要获取的日期
if not force_update:
# 从数据库获取已有的交易日期
existing_dates = DataReader.get_existing_trade_dates(table_name=table_name)
# 筛选出需要新获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
else:
# 强制更新,获取所有日期数据
dates_to_fetch = all_trade_dates
logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")
if not dates_to_fetch:
logger.info("所有数据已在数据库中,无需更新")
return db_manager.load_df_from_db(table_name=table_name)
logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
# 分批处理数据
temp_batch = []
total_rows = 0
# 获取数据
for trade_date in tqdm(dates_to_fetch):
try:
# 动态调用Tushare API
api_method = getattr(pro, api_name)
df = api_method(trade_date=trade_date)
if not df.empty:
temp_batch.append(df)
total_rows += len(df)
# 当累积的数据量达到batch_size时进行一次批量写入
if total_rows >= batch_size:
DataFetcher.process_batch(temp_batch, table_name, force_update)
logger.info(f"已处理 {total_rows} 行数据")
# 重置临时批次
temp_batch = []
total_rows = 0
else:
logger.info(f"日期 {trade_date} 无数据")
except Exception as e:
logger.error(f"获取 {trade_date} 的数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
# 处理剩余的数据
if temp_batch:
DataFetcher.process_batch(temp_batch, table_name, force_update)
logger.info(f"已处理剩余 {total_rows} 行数据")
logger.info("数据获取与处理完成")
return db_manager.load_df_from_db(table_name=table_name)
@staticmethod
def process_batch(batch_dfs, table_name, force_update):
"""
处理一批数据
参数:
batch_dfs (list): DataFrame列表
table_name (str): 数据库表名称
force_update (bool): 是否强制更新
"""
if not batch_dfs:
return
# 合并批次中的所有DataFrame
batch_df = pd.concat(batch_dfs, ignore_index=True)
# 保存数据传入force_update参数
db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update)
class DataReader:
"""
数据读取器类负责从数据库读取数据
"""
@staticmethod
def get_trade_cal(start_date=None, end_date=None, update=False):
"""
获取指定时间段内的交易日历
参数
start_date (str): 开始日期格式'YYYYMMDD'
end_date (str): 结束日期格式'YYYYMMDD'
update (bool): 是否更新交易日历默认为False
返回
list: 交易日期列表
"""
# 先检查表是否存在
if not db_manager.table_exists('trade_cal') or update:
logger.debug(f"表 trade_cal 不存在")
# 自动拉取交易日历
try:
DataFetcher.get_basic('trade_cal')
except Exception as e:
logger.error(f"自动拉取交易日历时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
if start_date is None:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'")
return df[df['is_open'] == 1]['cal_date'].tolist()
except Exception as e:
logger.error(f"获取交易日历时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
@staticmethod
def get_existing_trade_dates(table_name):
"""
从数据库中获取已有的交易日期
参数:
table_name (str): 数据表名称
返回:
set: 已存在于数据库中的交易日期集合
"""
# 先检查表是否存在
if not db_manager.table_exists(table_name):
logger.debug(f"'{table_name}' 不存在")
return set()
try:
# 使用query方法获取不重复的交易日期
query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}")
# 将查询结果转换为集合并返回
return set(query_result['trade_date'].values)
except Exception as e:
logger.error(f"获取已存在交易日期时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return set()