496 lines
21 KiB
Python
496 lines
21 KiB
Python
import traceback
|
||
from datetime import datetime, timedelta
|
||
from time import sleep
|
||
|
||
import pandas as pd
|
||
import tushare as ts
|
||
from tqdm import tqdm
|
||
|
||
from config_manager import get_config_manager
|
||
from database_manager import DatabaseManager
|
||
from logger_manager import get_logger
|
||
|
||
# 模块级别的配置
|
||
config = get_config_manager()
|
||
ts.set_token(config.get('tushare_token'))
|
||
pro = ts.pro_api()
|
||
|
||
# 初始化数据库管理器
|
||
db_manager = DatabaseManager()
|
||
|
||
# 获取日志器
|
||
logger = get_logger()
|
||
|
||
|
||
class DataFetcher:
|
||
"""
|
||
数据获取器类,负责从Tushare获取各类数据并管理本地缓存
|
||
"""
|
||
|
||
@staticmethod
|
||
def get_basic(api_name):
|
||
"""
|
||
获取基础数据,如股票列表等
|
||
参数:
|
||
api_name (str): Tushare的API名称,例如'stock_basic'
|
||
返回:
|
||
pandas.DataFrame: 请求的数据
|
||
"""
|
||
# 确保pro中存在对应的API方法
|
||
if not hasattr(pro, api_name):
|
||
logger.error(f"Tushare API '{api_name}'不存在")
|
||
return False
|
||
|
||
try:
|
||
df = getattr(pro, api_name)()
|
||
# 将数据保存到数据库
|
||
db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace')
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"获取基础数据时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return False
|
||
|
||
@staticmethod
|
||
def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
|
||
"""
|
||
获取指定时间段内的Tushare数据,使用数据库缓存,并分批处理大量数据
|
||
|
||
参数:
|
||
api_name (str): Tushare的API名称,例如'moneyflow'或'moneyflow_ind_dc'
|
||
start_date (str): 开始日期,格式'YYYYMMDD'
|
||
end_date (str): 结束日期,格式'YYYYMMDD'
|
||
force_update (bool): 是否强制更新所选区域数据,默认为False
|
||
batch_size (int): 每批处理的最大行数,默认为100000
|
||
|
||
返回:
|
||
pandas.DataFrame: 请求的数据
|
||
"""
|
||
# 使用api_name作为表key查询表名
|
||
table_name = api_name
|
||
|
||
# 确保pro中存在对应的API方法
|
||
if not hasattr(pro, api_name):
|
||
logger.error(f"Tushare API '{api_name}'不存在")
|
||
return False
|
||
|
||
# 获取目标交易日历
|
||
all_trade_dates = DataReader.get_trade_cal(start_date, end_date)
|
||
|
||
# 确定需要获取的日期
|
||
if not force_update:
|
||
# 从数据库获取已有的交易日期
|
||
existing_dates = DataReader.get_existing_table_trade_dates(table_name=table_name)
|
||
# 筛选出需要新获取的日期
|
||
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
|
||
else:
|
||
# 强制更新,获取所有日期数据
|
||
dates_to_fetch = all_trade_dates
|
||
logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")
|
||
|
||
if not dates_to_fetch:
|
||
logger.info("所有数据已在数据库中,无需更新")
|
||
return True
|
||
logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
|
||
|
||
# 分批处理数据
|
||
temp_batch = []
|
||
total_rows = 0
|
||
|
||
# 获取数据
|
||
for trade_date in tqdm(dates_to_fetch):
|
||
try:
|
||
# 动态调用Tushare API
|
||
api_method = getattr(pro, api_name)
|
||
df = api_method(trade_date=trade_date)
|
||
|
||
if not df.empty:
|
||
temp_batch.append(df)
|
||
total_rows += len(df)
|
||
|
||
# 当累积的数据量达到batch_size时,进行一次批量写入
|
||
if total_rows >= batch_size:
|
||
DataFetcher.process_batch(temp_batch, table_name, force_update)
|
||
logger.info(f"已处理 {total_rows} 行数据")
|
||
# 重置临时批次
|
||
temp_batch = []
|
||
total_rows = 0
|
||
else:
|
||
logger.info(f"日期 {trade_date} 无数据")
|
||
except Exception as e:
|
||
logger.error(f"获取 {trade_date} 的数据时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
sleep(5)
|
||
|
||
# 处理剩余的数据
|
||
if temp_batch:
|
||
DataFetcher.process_batch(temp_batch, table_name, force_update)
|
||
logger.info(f"已处理剩余 {total_rows} 行数据")
|
||
|
||
logger.info("数据获取与处理完成")
|
||
return True
|
||
|
||
@staticmethod
|
||
def process_batch(batch_dfs, table_name, force_update):
|
||
"""
|
||
处理一批数据
|
||
参数:
|
||
batch_dfs (list): DataFrame列表
|
||
table_name (str): 数据库表名称
|
||
force_update (bool): 是否强制更新
|
||
"""
|
||
if not batch_dfs:
|
||
return
|
||
# 合并批次中的所有DataFrame
|
||
batch_df = pd.concat(batch_dfs, ignore_index=True)
|
||
# 保存数据,传入force_update参数
|
||
db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update)
|
||
|
||
@staticmethod
|
||
def update_all(trade_date_api=None, basic_api=None, start_date=None, end_date=None, force_update=False):
|
||
"""
|
||
更新所有数据
|
||
"""
|
||
# 获取所有API名称
|
||
if trade_date_api is None:
|
||
trade_date_api = [
|
||
'moneyflow', # 个股资金流向
|
||
'moneyflow_ind_dc', # 东财概念及行业板块资金流向(DC)
|
||
'daily', # A股日线行情
|
||
'daily_basic', # 每日指标,获取全部股票每日重要的基本面指标
|
||
'stk_limit', # 每日涨跌停价格
|
||
'cyq_perf', # 每日筹码及胜率
|
||
'moneyflow_ths', # 同花顺资金流向 20241219以前就没了
|
||
'moneyflow_dc', # 东方财富资金流向
|
||
'moneyflow_cnt_ths',# 同花顺概念板块资金流向(THS) 20240120以前就没了
|
||
'moneyflow_ind_ths',# 同花顺行业板块资金流向(THS) 20250120以前就没了
|
||
'kpl_concept', # 开盘啦题材库,获取开盘啦概念题材列表 20241020以前就没了
|
||
'kpl_concept_cons', # 开盘啦题材成分,获取开盘啦概念题材的成分股 20241014以前就没了
|
||
'kpl_list', # 获取开盘啦涨停、跌停、炸板等榜单数据
|
||
'top_list', # 龙虎榜每日明细
|
||
'top_inst', # 龙虎榜机构席位明细
|
||
'limit_list_d', # 涨跌停列表(新),获取A股每日涨跌停、炸板数据情况,数据从2020年开始(不提供ST股票的统计)
|
||
'ths_daily', # 同花顺板块指数行情
|
||
'dc_index', # 东方财富概念板块,获取东方财富每个交易日的概念板块数据,支持按日期查询
|
||
'stk_auction', # 当日集合竞价,获取当日个股和ETF的集合竞价成交情况,每天9点25后可以获取当日的集合竞价成交数据
|
||
'ths_hot', # 获取同花顺App热榜数据,包括热股、概念板块、ETF、可转债、港美股等等,每日盘中提取4次,收盘后4次,最晚22点提取一次。
|
||
|
||
]
|
||
if basic_api is None:
|
||
basic_api = [
|
||
'stock_basic', # 股票基本信息
|
||
'trade_cal', # 交易日历
|
||
'namechange', # 股票曾用名
|
||
'ths_index', # 同花顺概念和行业指数
|
||
'hm_list', # 游资名录
|
||
'index_basic', # 指数基本信息
|
||
]
|
||
|
||
# 使用get_trade_date更新trade_date_api列表中的所有API
|
||
for api in trade_date_api:
|
||
logger.info(f"更新API: {api}")
|
||
DataFetcher.get_trade_date(api_name=api, start_date=start_date, end_date=end_date, force_update=force_update)
|
||
|
||
# 使用get_basic更新basic_api列表中的所有API
|
||
for api in basic_api:
|
||
logger.info(f"更新API: {api}")
|
||
DataFetcher.get_basic(api_name=api)
|
||
|
||
# 更新新闻数据
|
||
logger.info("更新新闻数据")
|
||
DataFetcher.get_news()
|
||
|
||
@staticmethod
|
||
def get_news(src='sina', fields=None, start_date=None, end_date=None):
|
||
"""
|
||
智能获取新闻数据,可获取指定时间范围内的数据
|
||
参数:
|
||
src (str): 新闻来源,如'sina'、'wallstreetcn'等
|
||
fields (str): 需要获取的字段,默认为'datetime,title,channels,content'
|
||
start_date (str): 可选的开始日期,格式'YYYY-MM-DD HH:MM:SS'
|
||
end_date (str): 可选的结束日期,格式'YYYY-MM-DD HH:MM:SS',默认为当前时间
|
||
返回:
|
||
bool: 是否成功获取数据
|
||
"""
|
||
table_name = 'news'
|
||
# 固定的历史起始点
|
||
HISTORY_START = '2025-04-01 00:00:00'
|
||
|
||
# 如果未指定字段,设置默认字段
|
||
if fields is None:
|
||
fields = 'datetime,title,channels,content'
|
||
|
||
# 确定结束时间
|
||
if end_date is None:
|
||
end_date = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
|
||
# 确定开始时间
|
||
if start_date:
|
||
logger.info(f"使用用户指定的开始日期: {start_date}")
|
||
# 如果未指定开始日期,使用智能逻辑确定
|
||
elif not db_manager.table_exists(table_name):
|
||
logger.info(f"表 {table_name} 不存在,将获取从 {HISTORY_START} 至 {end_date} 的新闻数据")
|
||
start_date = HISTORY_START
|
||
else:
|
||
# 表存在,查询符合条件的最新的新闻时间
|
||
try:
|
||
# 查询小于等于end_date的最新新闻时间
|
||
newest_query = f"SELECT MAX(datetime) as max_time FROM {table_name} WHERE datetime <= '{end_date}'"
|
||
newest_result = db_manager.query(newest_query)
|
||
# 如果有数据并且值不为None
|
||
if not newest_result.empty and newest_result['max_time'].iloc[0] is not None:
|
||
newest_time = newest_result['max_time'].iloc[0]
|
||
logger.info(f"数据库中小于等于 {end_date} 的最新新闻时间为: {newest_time}")
|
||
# 从最新时间开始获取新数据
|
||
start_date = newest_time
|
||
else:
|
||
# 数据库表存在但没有符合条件的数据
|
||
logger.info(
|
||
f"数据库表 {table_name} 存在但没有小于等于 {end_date} 的数据,将获取从 {HISTORY_START} 至 {end_date} 的新闻数据")
|
||
start_date = HISTORY_START
|
||
except Exception as e:
|
||
logger.error(f"查询最新新闻时间时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
# 出错时,从固定起始点获取
|
||
start_date = HISTORY_START
|
||
|
||
# 检查时间范围有效性
|
||
start_datetime = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S')
|
||
end_datetime = datetime.strptime(end_date, '%Y-%m-%d %H:%M:%S')
|
||
|
||
if start_datetime >= end_datetime:
|
||
logger.info(f"开始时间 {start_date} 不早于结束时间 {end_date},无需获取数据")
|
||
return True
|
||
|
||
logger.info(f"开始获取 {start_date} 至 {end_date} 的新闻数据")
|
||
return DataFetcher._fetch_news_recursive(src, start_date, end_date, fields, table_name)
|
||
|
||
@staticmethod
|
||
def _fetch_news_recursive(src, start_date, end_date, fields, table_name, batch_count=0):
|
||
"""
|
||
递归获取新闻数据,处理API返回数据限制
|
||
|
||
参数:
|
||
src (str): 新闻来源
|
||
start_date (str): 开始日期,格式'YYYY-MM-DD HH:MM:SS'
|
||
end_date (str): 结束日期,格式'YYYY-MM-DD HH:MM:SS'
|
||
fields (str): 需要获取的字段
|
||
table_name (str): 数据库表名
|
||
batch_count (int): 批次计数,用于日志
|
||
|
||
返回:
|
||
bool: 是否成功获取数据
|
||
"""
|
||
try:
|
||
logger.info(f"获取新闻数据批次 #{batch_count + 1}: {start_date} 至 {end_date}")
|
||
|
||
# 调用Tushare API获取新闻数据
|
||
df = pro.news(src=src, start_date=start_date, end_date=end_date, fields=fields)
|
||
|
||
if df.empty:
|
||
logger.info(f"当前时间范围内无新闻数据")
|
||
return True
|
||
|
||
# 保存数据到数据库
|
||
db_manager.save_df_to_db(df, table_name=table_name, if_exists='append')
|
||
logger.info(f"成功保存 {len(df)} 条新闻数据")
|
||
|
||
# 如果返回的数据接近限制数量(1500条),可能还有更多数据
|
||
if len(df) >= 1400: # 接近限制,设置一个略小的值作为阈值
|
||
# 找到当前批次中最早的新闻时间
|
||
earliest_date = df['datetime'].min()
|
||
|
||
# 检查是否已经达到了期望的起始日期
|
||
original_start = datetime.strptime(start_date, '%Y-%m-%d %H:%M:%S')
|
||
current_earliest = datetime.strptime(earliest_date, '%Y-%m-%d %H:%M:%S')
|
||
|
||
if current_earliest <= original_start:
|
||
logger.info(f"已获取到指定开始日期的数据,获取完成")
|
||
return True
|
||
|
||
# 将最早日期作为新的结束日期,继续获取更早的数据
|
||
new_end_date = earliest_date
|
||
|
||
# 递归调用,获取更早的数据
|
||
logger.info(f"当前批次数据量接近API限制,继续获取更早数据")
|
||
sleep(1) # 避免频繁调用API
|
||
return DataFetcher._fetch_news_recursive(src, start_date, new_end_date, fields,
|
||
table_name, batch_count + 1)
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取新闻数据时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return False
|
||
|
||
|
||
class DataReader:
|
||
"""
|
||
数据读取器类,负责从数据库读取数据
|
||
"""
|
||
@staticmethod
|
||
def get_trade_cal(start_date=None, end_date=None, update=False):
|
||
"""
|
||
从数据库获取交易日历,如果表不存在则自动拉取数据,并截取指定日期范围内的交易日历(工作日)
|
||
参数:
|
||
start_date (str): 开始日期,格式'YYYYMMDD'
|
||
end_date (str): 结束日期,格式'YYYYMMDD'
|
||
update (bool): 是否更新交易日历,默认为False
|
||
返回:
|
||
list: 交易日期列表
|
||
"""
|
||
# 先检查表是否存在
|
||
if not db_manager.table_exists('trade_cal') or update:
|
||
logger.debug(f"表 trade_cal 不存在")
|
||
# 自动拉取交易日历
|
||
try:
|
||
DataFetcher.get_basic('trade_cal')
|
||
except Exception as e:
|
||
logger.error(f"自动拉取交易日历时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return []
|
||
|
||
|
||
if start_date is None:
|
||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||
if end_date is None:
|
||
end_date = datetime.now().strftime('%Y%m%d')
|
||
try:
|
||
df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'")
|
||
return df[df['is_open'] == 1]['cal_date'].tolist()
|
||
except Exception as e:
|
||
logger.error(f"获取交易日历时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return []
|
||
|
||
@staticmethod
|
||
def get_existing_table_trade_dates(table_name):
|
||
"""
|
||
查询指定表中已存在的交易日期
|
||
参数:
|
||
table_name (str): 数据表名称
|
||
返回:
|
||
set: 已存在于数据库中的交易日期集合
|
||
"""
|
||
# 先检查表是否存在
|
||
if not db_manager.table_exists(table_name):
|
||
logger.debug(f"表 '{table_name}' 不存在")
|
||
return []
|
||
|
||
try:
|
||
# 使用query方法获取不重复的交易日期
|
||
query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}")
|
||
# 将查询结果转换为集合并返回
|
||
return list(set(query_result['trade_date'].values))
|
||
except Exception as e:
|
||
logger.error(f"获取已存在交易日期时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return []
|
||
|
||
@staticmethod
|
||
def get_main_board_stocks():
|
||
"""
|
||
筛选主板且名称不包含ST的股票代码列表
|
||
|
||
返回:
|
||
list: 符合条件的股票代码(ts_code)列表
|
||
"""
|
||
# 先检查表是否存在
|
||
if not db_manager.table_exists('stock_basic'):
|
||
logger.debug("表 stock_basic 不存在")
|
||
# 可以在这里添加自动拉取股票基本信息的代码,类似于get_trade_cal方法
|
||
return []
|
||
|
||
try:
|
||
# 查询条件:主板(market='主板')、名称不含ST
|
||
query = "SELECT ts_code FROM stock_basic WHERE market='主板' AND name NOT LIKE '%ST%'"
|
||
result = db_manager.query(query)
|
||
return result['ts_code'].tolist() if not result.empty else []
|
||
except Exception as e:
|
||
logger.error(f"获取主板上市非ST股票列表时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return []
|
||
|
||
@staticmethod
|
||
def get_table_data_by_date(table_name, start_date=None, end_date=None, filter_main_board=False):
|
||
"""
|
||
从数据库获取指定表的数据,并截取指定日期范围内的数据
|
||
|
||
参数:
|
||
table_name (str): 数据库表名
|
||
start_date (str): 开始日期,格式'YYYYMMDD',默认为30天前
|
||
end_date (str): 结束日期,格式'YYYYMMDD',默认为今天
|
||
filter_main_board (bool): 是否只返回主板上市股票的数据,默认为False
|
||
error_prefix (str): 错误日志的前缀
|
||
|
||
返回:
|
||
pandas.DataFrame: 查询的数据
|
||
"""
|
||
# 先检查表是否存在
|
||
if not db_manager.table_exists(table_name):
|
||
logger.debug(f"表 {table_name} 不存在")
|
||
return pd.DataFrame()
|
||
|
||
# 设置默认日期范围
|
||
if start_date is None:
|
||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||
if end_date is None:
|
||
end_date = datetime.now().strftime('%Y%m%d')
|
||
|
||
try:
|
||
# 基础条件:日期过滤
|
||
conditions = f"trade_date BETWEEN '{start_date}' AND '{end_date}'"
|
||
|
||
# 如果需要过滤主板股票
|
||
if filter_main_board:
|
||
# 获取主板上市非ST股票列表
|
||
main_board_stocks = DataReader.get_main_board_stocks()
|
||
if main_board_stocks:
|
||
# 将股票代码列表转换为SQL安全的格式
|
||
stock_codes = ", ".join([f"'{code}'" for code in main_board_stocks])
|
||
# 添加股票代码过滤条件
|
||
conditions += f" AND ts_code IN ({stock_codes})"
|
||
else:
|
||
logger.warning("未找到主板上市股票,返回空DataFrame")
|
||
return pd.DataFrame()
|
||
|
||
# 查询数据
|
||
df = db_manager.load_df_from_db(table_name, conditions=conditions)
|
||
return df
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取{table_name}数据时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return pd.DataFrame()
|
||
|
||
@staticmethod
|
||
def get_news(start_date=None, end_date=None, update=False):
|
||
# 先检查表是否存在
|
||
if not db_manager.table_exists('news') or update:
|
||
logger.debug(f"表 trade_cal 不存在")
|
||
# 自动拉取交易日历
|
||
try:
|
||
DataFetcher.get_news()
|
||
except Exception as e:
|
||
logger.error(f"自动拉去news时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return []
|
||
|
||
# 设置默认日期范围
|
||
if start_date is None:
|
||
start_date = '2025-04-15 00:00:00'
|
||
if end_date is None:
|
||
end_date = datetime.now().strftime('YYYY-MM-DD HH:MM:SS')
|
||
|
||
|
||
try:
|
||
# 查询数据
|
||
df = db_manager.load_df_from_db('news', conditions=f"datetime BETWEEN '{start_date}' AND '{end_date}'")
|
||
return df
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取news数据时出错: {e}")
|
||
logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return pd.DataFrame()
|