112 lines
4.5 KiB
Python
112 lines
4.5 KiB
Python
|
import pandas as pd
|
|||
|
from tqdm import tqdm
|
|||
|
import tushare as ts
|
|||
|
from datetime import datetime, timedelta
|
|||
|
from utils import load_config
|
|||
|
from database_manager import DatabaseManager
|
|||
|
|
|||
|
|
|||
|
class DataFetcher:
    """Fetch market data from Tushare and keep a local database cache.

    The Tushare client is stored in ``self.pro`` and the cache is accessed
    through ``self.db_manager`` (a ``DatabaseManager`` instance).
    """

    def __init__(self):
        # Load the configuration and initialise the Tushare client with the
        # token stored under the 'tushare_token' key.
        self.config = load_config()
        ts.set_token(self.config['tushare_token'])
        self.pro = ts.pro_api()
        # Database manager used as the local cache.
        self.db_manager = DatabaseManager()

    def get_trade_cal(self, start_date=None, end_date=None):
        """Return the open trading days within a period.

        Args:
            start_date (str): Start date formatted as 'YYYYMMDD'. Defaults to
                30 days before the current local time.
            end_date (str): End date formatted as 'YYYYMMDD'. Defaults to the
                current local date.

        Returns:
            list: The ``cal_date`` values of the rows whose ``is_open`` equals
            1, in the order Tushare returned them. An empty list is returned
            when the request (or the processing of its result) fails.
        """
        if start_date is None:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')

        try:
            trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
            return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
        except Exception as e:
            # Deliberate best effort: report the problem and let callers
            # continue with "no trading days" instead of crashing.
            print(f"获取交易日历时出错: {e}")
            return []

    @staticmethod
    def _replace_dates(existing_df, new_df, refreshed_dates):
        """Return the cached rows with ``refreshed_dates`` replaced by ``new_df``.

        Rows of ``existing_df`` whose ``trade_date`` is in ``refreshed_dates``
        are dropped and ``new_df`` is appended. When there is no cached data
        (``None`` or an empty frame) ``new_df`` is returned unchanged.
        """
        if existing_df is None or existing_df.empty:
            return new_df
        kept_rows = existing_df[~existing_df['trade_date'].isin(refreshed_dates)]
        return pd.concat([kept_rows, new_df], ignore_index=True)

    def get_moneyflow_ind_dc(self, start_date=None, end_date=None, force_update=False):
        """Return sector money-flow data, downloading only what is missing.

        Args:
            start_date (str): Start date formatted as 'YYYYMMDD'.
            end_date (str): End date formatted as 'YYYYMMDD'.
            force_update (bool): When True, every trading day in the range is
                downloaded again and replaces the cached rows for that day.
                Defaults to False.

        Returns:
            Whatever ``self.db_manager.load_df_from_db`` returns for the
            'moneyflow_ind_dc' table after the update (documented by the
            original author as a pandas.DataFrame with all cached rows).
        """
        table_key = 'moneyflow_ind_dc'

        # Trading days of the requested period.
        all_trade_dates = self.get_trade_cal(start_date, end_date)
        if not all_trade_dates:
            # Either the range has no trading day or the calendar request
            # failed; do not claim that the cache is complete.
            print("未获取到任何交易日(区间内无交易日或交易日历获取失败), 返回数据库中已有数据")
            return self.db_manager.load_df_from_db(table_key=table_key)

        # Decide which days have to be downloaded.
        if not force_update:
            # A set gives O(1) membership tests.
            # NOTE(review): assumes get_existing_trade_dates returns an
            # iterable of date values comparable to cal_date - confirm.
            existing_dates = set(self.db_manager.get_existing_trade_dates(table_key=table_key))
            dates_to_fetch = [day for day in all_trade_dates if day not in existing_dates]
        else:
            dates_to_fetch = all_trade_dates
            print(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")

        if not dates_to_fetch:
            print("所有数据已在数据库中,无需更新")
            return self.db_manager.load_df_from_db(table_key=table_key)

        print(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")

        # Download day by day; a failing day is reported and skipped.
        all_new_data = []
        fetched_dates = []  # days for which fresh rows were actually obtained
        for trade_date in tqdm(dates_to_fetch):
            try:
                df = self.pro.moneyflow_ind_dc(trade_date=trade_date)
            except Exception as e:
                print(f"获取 {trade_date} 的数据时出错: {e}")
                continue
            if df is None or df.empty:
                print(f"日期 {trade_date} 无数据")
                continue
            all_new_data.append(df)
            fetched_dates.append(trade_date)

        if not all_new_data:
            print("未获取到任何新数据")
            return self.db_manager.load_df_from_db(table_key=table_key)

        new_df = pd.concat(all_new_data, ignore_index=True)

        if force_update:
            # Replace only the days that were really re-downloaded: cached
            # rows of a day whose download failed or came back empty must
            # survive, otherwise a transient API error would erase them.
            existing_df = self.db_manager.load_df_from_db(table_key=table_key)
            final_df = self._replace_dates(existing_df, new_df, fetched_dates)
            # The whole table is rewritten with the merged result.
            self.db_manager.save_df_to_db(final_df, table_key=table_key, if_exists='replace')
            print(f"已强制更新 {len(new_df)} 条记录到数据库")
        else:
            # Normal mode: append the new rows.
            self.db_manager.save_df_to_db(new_df, table_key=table_key, if_exists='append')
            print(f"已将 {len(new_df)} 条新记录追加到数据库")

        return self.db_manager.load_df_from_db(table_key=table_key)