backtrader/data_manager.py

496 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import traceback
from datetime import datetime, timedelta
from time import sleep
import pandas as pd
import tushare as ts
from tqdm import tqdm
from config_manager import get_config_manager
from database_manager import DatabaseManager
from logger_manager import get_logger
# 模块级别的配置
config = get_config_manager()
ts.set_token(config.get('tushare_token'))
pro = ts.pro_api()
# 初始化数据库管理器
db_manager = DatabaseManager()
# 获取日志器
logger = get_logger()
class DataFetcher:
"""
数据获取器类负责从Tushare获取各类数据并管理本地缓存
"""
@staticmethod
def get_basic(api_name):
"""
获取基础数据,如股票列表等
参数:
api_name (str): Tushare的API名称例如'stock_basic'
返回:
pandas.DataFrame: 请求的数据
"""
# 确保pro中存在对应的API方法
if not hasattr(pro, api_name):
logger.error(f"Tushare API '{api_name}'不存在")
return False
try:
df = getattr(pro, api_name)()
# 将数据保存到数据库
db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace')
return True
except Exception as e:
logger.error(f"获取基础数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False
@staticmethod
def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
"""
获取指定时间段内的Tushare数据使用数据库缓存并分批处理大量数据
参数:
api_name (str): Tushare的API名称例如'moneyflow''moneyflow_ind_dc'
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
force_update (bool): 是否强制更新所选区域数据默认为False
batch_size (int): 每批处理的最大行数默认为100000
返回:
pandas.DataFrame: 请求的数据
"""
# 使用api_name作为表key查询表名
table_name = api_name
# 确保pro中存在对应的API方法
if not hasattr(pro, api_name):
logger.error(f"Tushare API '{api_name}'不存在")
return False
# 获取目标交易日历
all_trade_dates = DataReader.get_trade_cal(start_date, end_date)
# 确定需要获取的日期
if not force_update:
# 从数据库获取已有的交易日期
existing_dates = DataReader.get_existing_table_trade_dates(table_name=table_name)
# 筛选出需要新获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
else:
# 强制更新,获取所有日期数据
dates_to_fetch = all_trade_dates
logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")
if not dates_to_fetch:
logger.info("所有数据已在数据库中,无需更新")
return True
logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
# 分批处理数据
temp_batch = []
total_rows = 0
# 获取数据
for trade_date in tqdm(dates_to_fetch):
try:
# 动态调用Tushare API
api_method = getattr(pro, api_name)
df = api_method(trade_date=trade_date)
if not df.empty:
temp_batch.append(df)
total_rows += len(df)
# 当累积的数据量达到batch_size时进行一次批量写入
if total_rows >= batch_size:
DataFetcher.process_batch(temp_batch, table_name, force_update)
logger.info(f"已处理 {total_rows} 行数据")
# 重置临时批次
temp_batch = []
total_rows = 0
else:
logger.info(f"日期 {trade_date} 无数据")
except Exception as e:
logger.error(f"获取 {trade_date} 的数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
sleep(5)
# 处理剩余的数据
if temp_batch:
DataFetcher.process_batch(temp_batch, table_name, force_update)
logger.info(f"已处理剩余 {total_rows} 行数据")
logger.info("数据获取与处理完成")
return True
@staticmethod
def process_batch(batch_dfs, table_name, force_update):
"""
处理一批数据
参数:
batch_dfs (list): DataFrame列表
table_name (str): 数据库表名称
force_update (bool): 是否强制更新
"""
if not batch_dfs:
return
# 合并批次中的所有DataFrame
batch_df = pd.concat(batch_dfs, ignore_index=True)
# 保存数据传入force_update参数
db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update)
@staticmethod
def update_all(trade_date_api=None, basic_api=None, start_date=None, end_date=None, force_update=False):
"""
更新所有数据
"""
# 获取所有API名称
if trade_date_api is None:
trade_date_api = [
'moneyflow', # 个股资金流向
'moneyflow_ind_dc', # 东财概念及行业板块资金流向DC
'daily', # A股日线行情
'daily_basic', # 每日指标,获取全部股票每日重要的基本面指标
'stk_limit', # 每日涨跌停价格
'cyq_perf', # 每日筹码及胜率
'moneyflow_ths', # 同花顺资金流向 20241219以前就没了
'moneyflow_dc', # 东方财富资金流向
'moneyflow_cnt_ths',# 同花顺概念板块资金流向THS 20240120以前就没了
'moneyflow_ind_ths',# 同花顺行业板块资金流向THS 20250120以前就没了
'kpl_concept', # 开盘啦题材库,获取开盘啦概念题材列表 20241020以前就没了
'kpl_concept_cons', # 开盘啦题材成分,获取开盘啦概念题材的成分股 20241014以前就没了
'kpl_list', # 获取开盘啦涨停、跌停、炸板等榜单数据
'top_list', # 龙虎榜每日明细
'top_inst', # 龙虎榜机构席位明细
'limit_list_d', # 涨跌停列表(新),获取A股每日涨跌停、炸板数据情况数据从2020年开始不提供ST股票的统计
'ths_daily', # 同花顺板块指数行情
'dc_index', # 东方财富概念板块,获取东方财富每个交易日的概念板块数据,支持按日期查询
'stk_auction', # 当日集合竞价,获取当日个股和ETF的集合竞价成交情况每天9点25后可以获取当日的集合竞价成交数据
'ths_hot', # 获取同花顺App热榜数据包括热股、概念板块、ETF、可转债、港美股等等每日盘中提取4次收盘后4次最晚22点提取一次。
]
if basic_api is None:
basic_api = [
'stock_basic', # 股票基本信息
'trade_cal', # 交易日历
'namechange', # 股票曾用名
'ths_index', # 同花顺概念和行业指数
'hm_list', # 游资名录
'index_basic', # 指数基本信息
]
# 使用get_trade_date更新trade_date_api列表中的所有API
for api in trade_date_api:
logger.info(f"更新API: {api}")
DataFetcher.get_trade_date(api_name=api, start_date=start_date, end_date=end_date, force_update=force_update)
# 使用get_basic更新basic_api列表中的所有API
for api in basic_api:
logger.info(f"更新API: {api}")
DataFetcher.get_basic(api_name=api)
# 更新新闻数据
logger.info("更新新闻数据")
DataFetcher.get_news()
@staticmethod
def get_news(src='sina', fields=None, start_date=None, end_date=None):
"""
智能获取新闻数据,可获取指定时间范围内的数据
参数:
src (str): 新闻来源,如'sina''wallstreetcn'
fields (str): 需要获取的字段,默认为'datetime,title,channels,content'
start_date (str): 可选的开始日期,格式'YYYY-MM-DD HH:MM:SS'
end_date (str): 可选的结束日期,格式'YYYY-MM-DD HH:MM:SS',默认为当前时间
返回:
bool: 是否成功获取数据
"""
table_name = 'news'
# 固定的历史起始点
HISTORY_START = '2025-04-01 00:00:00'
# 如果未指定字段,设置默认字段
if fields is None:
fields = 'datetime,title,channels,content'
# 确定结束时间
if end_date is None:
end_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# 确定开始时间
if start_date:
logger.info(f"使用用户指定的开始日期: {start_date}")
# 如果未指定开始日期,使用智能逻辑确定
elif not db_manager.table_exists(table_name):
logger.info(f"{table_name} 不存在,将获取从 {HISTORY_START}{end_date} 的新闻数据")
start_date = HISTORY_START
else:
# 表存在,查询符合条件的最新的新闻时间
try:
# 查询小于等于end_date的最新新闻时间
newest_query = f"SELECT MAX(datetime) as max_time FROM {table_name} WHERE datetime <= '{end_date}'"
newest_result = db_manager.query(newest_query)
# 如果有数据并且值不为None
if not newest_result.empty and newest_result['max_time'].iloc[0] is not None:
newest_time = newest_result['max_time'].iloc[0]
logger.info(f"数据库中小于等于 {end_date} 的最新新闻时间为: {newest_time}")
# 从最新时间开始获取新数据
start_date = newest_time
else:
# 数据库表存在但没有符合条件的数据
logger.info(
f"数据库表 {table_name} 存在但没有小于等于 {end_date} 的数据,将获取从 {HISTORY_START}{end_date} 的新闻数据")
start_date = HISTORY_START
except Exception as e:
logger.error(f"查询最新新闻时间时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
# 出错时,从固定起始点获取
start_date = HISTORY_START
# 检查时间范围有效性
start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S')
end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S')
if start_datetime >= end_datetime:
logger.info(f"开始时间 {start_date} 不早于结束时间 {end_date},无需获取数据")
return True
logger.info(f"开始获取 {start_date}{end_date} 的新闻数据")
return DataFetcher._fetch_news_recursive(src, start_date, end_date, fields, table_name)
@staticmethod
def _fetch_news_recursive(src, start_date, end_date, fields, table_name, batch_count=0):
"""
递归获取新闻数据处理API返回数据限制
参数:
src (str): 新闻来源
start_date (str): 开始日期,格式'YYYY-MM-DD HH:MM:SS'
end_date (str): 结束日期,格式'YYYY-MM-DD HH:MM:SS'
fields (str): 需要获取的字段
table_name (str): 数据库表名
batch_count (int): 批次计数,用于日志
返回:
bool: 是否成功获取数据
"""
try:
logger.info(f"获取新闻数据批次 #{batch_count + 1}: {start_date}{end_date}")
# 调用Tushare API获取新闻数据
df = pro.news(src=src, start_date=start_date, end_date=end_date, fields=fields)
if df.empty:
logger.info(f"当前时间范围内无新闻数据")
return True
# 保存数据到数据库
db_manager.save_df_to_db(df, table_name=table_name, if_exists='append')
logger.info(f"成功保存 {len(df)} 条新闻数据")
# 如果返回的数据接近限制数量(1500条),可能还有更多数据
if len(df) >= 1400: # 接近限制,设置一个略小的值作为阈值
# 找到当前批次中最早的新闻时间
earliest_date = df['datetime'].min()
# 检查是否已经达到了期望的起始日期
original_start = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S')
current_earliest = datetime.strptime(earliest_date, '%Y-%m-%d %H:%M:%S')
if current_earliest <= original_start:
logger.info(f"已获取到指定开始日期的数据,获取完成")
return True
# 将最早日期作为新的结束日期,继续获取更早的数据
new_end_date = earliest_date
# 递归调用,获取更早的数据
logger.info(f"当前批次数据量接近API限制继续获取更早数据")
sleep(1) # 避免频繁调用API
return DataFetcher._fetch_news_recursive(src, start_date, new_end_date, fields,
table_name, batch_count + 1)
return True
except Exception as e:
logger.error(f"获取新闻数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False
class DataReader:
"""
数据读取器类,负责从数据库读取数据
"""
@staticmethod
def get_trade_cal(start_date=None, end_date=None, update=False):
"""
从数据库获取交易日历,如果表不存在则自动拉取数据,并截取指定日期范围内的交易日历(工作日)
参数:
start_date (str): 开始日期,格式'YYYYMMDD'
end_date (str): 结束日期,格式'YYYYMMDD'
update (bool): 是否更新交易日历默认为False
返回:
list: 交易日期列表
"""
# 先检查表是否存在
if not db_manager.table_exists('trade_cal') or update:
logger.debug(f"表 trade_cal 不存在")
# 自动拉取交易日历
try:
DataFetcher.get_basic('trade_cal')
except Exception as e:
logger.error(f"自动拉取交易日历时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
if start_date is None:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'")
return df[df['is_open'] == 1]['cal_date'].tolist()
except Exception as e:
logger.error(f"获取交易日历时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
@staticmethod
def get_existing_table_trade_dates(table_name):
"""
查询指定表中已存在的交易日期
参数:
table_name (str): 数据表名称
返回:
set: 已存在于数据库中的交易日期集合
"""
# 先检查表是否存在
if not db_manager.table_exists(table_name):
logger.debug(f"'{table_name}' 不存在")
return []
try:
# 使用query方法获取不重复的交易日期
query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}")
# 将查询结果转换为集合并返回
return list(set(query_result['trade_date'].values))
except Exception as e:
logger.error(f"获取已存在交易日期时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
@staticmethod
def get_main_board_stocks():
"""
筛选主板且名称不包含ST的股票代码列表
返回:
list: 符合条件的股票代码(ts_code)列表
"""
# 先检查表是否存在
if not db_manager.table_exists('stock_basic'):
logger.debug("表 stock_basic 不存在")
# 可以在这里添加自动拉取股票基本信息的代码类似于get_trade_cal方法
return []
try:
# 查询条件:主板(market='主板')、名称不含ST
query = "SELECT ts_code FROM stock_basic WHERE market='主板' AND name NOT LIKE '%ST%'"
result = db_manager.query(query)
return result['ts_code'].tolist() if not result.empty else []
except Exception as e:
logger.error(f"获取主板上市非ST股票列表时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
@staticmethod
def get_table_data_by_date(table_name, start_date=None, end_date=None, filter_main_board=False):
"""
从数据库获取指定表的数据,并截取指定日期范围内的数据
参数:
table_name (str): 数据库表名
start_date (str): 开始日期,格式'YYYYMMDD'默认为30天前
end_date (str): 结束日期,格式'YYYYMMDD',默认为今天
filter_main_board (bool): 是否只返回主板上市股票的数据默认为False
error_prefix (str): 错误日志的前缀
返回:
pandas.DataFrame: 查询的数据
"""
# 先检查表是否存在
if not db_manager.table_exists(table_name):
logger.debug(f"{table_name} 不存在")
return pd.DataFrame()
# 设置默认日期范围
if start_date is None:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
if end_date is None:
end_date = datetime.now().strftime('%Y%m%d')
try:
# 基础条件:日期过滤
conditions = f"trade_date BETWEEN '{start_date}' AND '{end_date}'"
# 如果需要过滤主板股票
if filter_main_board:
# 获取主板上市非ST股票列表
main_board_stocks = DataReader.get_main_board_stocks()
if main_board_stocks:
# 将股票代码列表转换为SQL安全的格式
stock_codes = ", ".join([f"'{code}'" for code in main_board_stocks])
# 添加股票代码过滤条件
conditions += f" AND ts_code IN ({stock_codes})"
else:
logger.warning("未找到主板上市股票返回空DataFrame")
return pd.DataFrame()
# 查询数据
df = db_manager.load_df_from_db(table_name, conditions=conditions)
return df
except Exception as e:
logger.error(f"获取{table_name}数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()
@staticmethod
def get_news(start_date=None, end_date=None, update=False):
# 先检查表是否存在
if not db_manager.table_exists('news') or update:
logger.debug(f"表 trade_cal 不存在")
# 自动拉取交易日历
try:
DataFetcher.get_news()
except Exception as e:
logger.error(f"自动拉去news时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return []
# 设置默认日期范围
if start_date is None:
start_date = '2025-04-15 00:00:00'
if end_date is None:
end_date = datetime.now().strftime('YYYY-MM-DD HH:MM:SS')
try:
# 查询数据
df = db_manager.load_df_from_db('news', conditions=f"datetime BETWEEN '{start_date}' AND '{end_date}'")
return df
except Exception as e:
logger.error(f"获取news数据时出错: {e}")
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()