backtrader/data_fetcher.py

114 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import traceback
from datetime import datetime, timedelta
import pandas as pd
import tushare as ts
from tqdm import tqdm
from database_manager import DatabaseManager
from logger import get_logger
from utils import load_config
class DataFetcher:
"""
数据获取器类负责从Tushare获取各类数据并管理本地缓存
"""
def __init__(self):
# 加载配置并初始化tushare
self.config = load_config()
ts.set_token(self.config['tushare_token'])
self.pro = ts.pro_api()
# 初始化数据库管理器
self.db_manager = DatabaseManager()
# 获取日志器
self.logger = get_logger()
def get_trade_cal(self, start_date=None, end_date=None):
"""
获取指定时间段内的交易日历
参数:
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
返回:
list: 交易日期列表
"""
if start_date is None:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
except Exception as e:
self.logger.error(f"获取交易日历时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
def get_moneyflow_ind_dc(self, start_date=None, end_date=None, force_update=False):
"""
获取指定时间段内的板块资金流向数据,使用数据库缓存
参数:
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
force_update (bool): 是否强制更新所选区域数据默认为False
返回:
pandas.DataFrame: 所有板块资金流向数据
"""
# 获取目标交易日历
all_trade_dates = self.get_trade_cal(start_date, end_date)
# 确定需要获取的日期
if not force_update:
# 从数据库获取已有的交易日期
existing_dates = self.db_manager.get_existing_trade_dates(table_key='moneyflow_ind_dc')
# 筛选出需要新获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
else:
# 强制更新,获取所有日期数据
dates_to_fetch = all_trade_dates
self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")
if not dates_to_fetch:
self.logger.info("所有数据已在数据库中,无需更新")
return self.db_manager.load_df_from_db(table_key='moneyflow_ind_dc')
self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
# 获取数据
all_new_data = []
for trade_date in tqdm(dates_to_fetch):
try:
# 从tushare获取当日板块资金流向数据
df = self.pro.moneyflow_ind_dc(trade_date=trade_date)
if not df.empty:
all_new_data.append(df)
else:
self.logger.info(f"日期 {trade_date} 无数据")
except Exception as e:
self.logger.error(f"获取 {trade_date} 的数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
# 处理新数据
if all_new_data:
# 将所有新数据合并为一个DataFrame
new_df = pd.concat(all_new_data, ignore_index=True)
if force_update:
# 强制更新模式:需要删除已有的日期数据,然后重新插入
existing_df = self.db_manager.load_df_from_db(table_key='moneyflow_ind_dc')
# 过滤掉需要更新的日期范围内的数据
filtered_df = existing_df[~existing_df['trade_date'].isin(dates_to_fetch)]
# 拼接新数据
final_df = pd.concat([filtered_df, new_df], ignore_index=True)
# 替换整个表
self.db_manager.save_df_to_db(final_df, table_key='moneyflow_ind_dc', if_exists='replace')
self.logger.info(f"已强制更新 {len(new_df)} 条记录到数据库")
else:
# 普通追加模式
self.db_manager.save_df_to_db(new_df, table_key='moneyflow_ind_dc', if_exists='append')
self.logger.info(f"已将 {len(new_df)} 条新记录追加到数据库")
else:
self.logger.info("未获取到任何新数据")
return self.db_manager.load_df_from_db(table_key='moneyflow_ind_dc')