🛠️ **refactor**(data_downloader.py): 改进save_stock_dataupdate_metadata函数以增强灵活性与健壮性

 **feat**(data_downloader.py): 新增`perform_full_update`与`perform_incremental_update`函数以支持全量与增量更新
♻️ **style**(data_downloader.py): 添加完整函数文档与空值处理,提升代码可读性与容错性
This commit is contained in:
Qihang Zhang 2025-04-05 20:59:27 +08:00
parent 66116caf94
commit 1f8c553078
2 changed files with 470 additions and 86 deletions

View File

@ -281,8 +281,19 @@ def download_stock_data(pro, ts_code, start_date, end_date):
return result_df
def save_stock_data(df, engine, ts_code):
"""保存股票数据到对应的表中"""
def save_stock_data(df, engine, ts_code, if_exists='replace'):
"""
保存股票数据到对应的表中
Args:
df: 股票数据DataFrame
engine: 数据库引擎
ts_code: 股票代码
if_exists: 如果表已存在处理方式'replace': 替换现有表'append': 附加到现有表
Returns:
int: 保存的记录数量
"""
if df.empty:
print(f"警告:{ts_code}没有数据可保存")
return 0
@ -291,12 +302,12 @@ def save_stock_data(df, engine, ts_code):
if 'ts_code' in df.columns:
df = df.drop(columns=['ts_code'])
# 使用replace模式确保表中只包含最新数据
# 使用指定的模式保存数据
try:
symbol = ts_code.split('.')[0]
result = df.to_sql(f'{symbol}', engine, index=False,
if_exists='replace', chunksize=1000)
print(f"成功保存{result}{ts_code}股票数据")
if_exists=if_exists, chunksize=1000)
print(f"成功保存{result}{ts_code}股票数据 (模式: {if_exists})")
return result
except Exception as e:
print(f"保存{ts_code}数据时出错: {str(e)}")
@ -304,7 +315,17 @@ def save_stock_data(df, engine, ts_code):
def update_metadata(engine, ts_code):
"""更新股票元数据中的统计信息,从股票表中直接获取数据"""
"""
更新股票元数据中的统计信息从股票表中直接获取数据
Args:
engine: 数据库连接引擎
ts_code: 股票代码(带后缀如000001.SZ)
Returns:
bool: 更新成功返回True否则返回False
"""
try:
# 提取股票代码(不含后缀)
symbol = ts_code.split('.')[0]
@ -334,44 +355,89 @@ def update_metadata(engine, ts_code):
# 获取统计信息
stats_result = conn.execute(text(stats_query)).fetchone()
if not stats_result:
print(f"警告:{ts_code}没有数据,无法更新元数据")
return
print(f"警告:{ts_code}没有数据,将更新元数据状态为异常")
# 更新状态为异常
update_empty_sql = f"""
UPDATE stock_metadata
SET last_full_update = NOW(),
record_count = 0,
status = 2
WHERE ts_code = '{ts_code}'
"""
conn.execute(text(update_empty_sql))
conn.commit()
return False
# 处理日期字段
try:
data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
record_count = stats_result[2]
except AttributeError:
# 处理日期可能是字符串而不是datetime对象的情况
data_start_date = stats_result[0] if stats_result[0] else 'NULL'
data_end_date = stats_result[1] if stats_result[1] else 'NULL'
record_count = stats_result[2] or 0
# 获取最新交易日数据
latest_data_result = conn.execute(text(latest_data_query)).fetchone()
latest_price = f"{latest_data_result[0]}" if latest_data_result and latest_data_result[
0] is not None else 'NULL'
latest_pe_ttm = f"{latest_data_result[1]}" if latest_data_result and latest_data_result[
1] is not None else 'NULL'
latest_pb = f"{latest_data_result[2]}" if latest_data_result and latest_data_result[2] is not None else 'NULL'
latest_total_mv = f"{latest_data_result[3]}" if latest_data_result and latest_data_result[
3] is not None else 'NULL'
# 加载更新SQL模板
with open('sql/update_metadata.sql', 'r', encoding='utf-8') as f:
update_sql_template = f.read()
# 设置默认值并处理NULL
default_value = 'NULL'
latest_price = default_value
latest_pe_ttm = default_value
latest_pb = default_value
latest_total_mv = default_value
# 填充模板
update_sql = update_sql_template.format(
data_start_date=data_start_date,
data_end_date=data_end_date,
record_count=record_count,
latest_price=latest_price,
latest_pe_ttm=latest_pe_ttm,
latest_pb=latest_pb,
latest_total_mv=latest_total_mv,
ts_code=ts_code
)
# 如果有最新数据,则更新相应字段
if latest_data_result:
latest_price = str(latest_data_result[0]) if latest_data_result[0] is not None else default_value
latest_pe_ttm = str(latest_data_result[1]) if latest_data_result[1] is not None else default_value
latest_pb = str(latest_data_result[2]) if latest_data_result[2] is not None else default_value
latest_total_mv = str(latest_data_result[3]) if latest_data_result[3] is not None else default_value
# 构建更新SQL
update_sql = f"""
UPDATE stock_metadata
SET
last_full_update = NOW(),
data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
record_count = {record_count},
status = CASE WHEN {record_count} > 0 THEN 1 ELSE 2 END,
latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END,
latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END,
latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END,
latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
WHERE ts_code = '{ts_code}'
"""
# 执行更新
conn.execute(text(update_sql))
conn.commit()
print(f"已更新({ts_code})的元数据信息")
return True
except Exception as e:
print(f"更新({ts_code})元数据时出错: {str(e)}")
# 尝试将状态更新为异常
try:
with engine.connect() as conn:
error_update_sql = f"""
UPDATE stock_metadata
SET status = 2,
last_full_update = NOW()
WHERE ts_code = '{ts_code}'
"""
conn.execute(text(error_update_sql))
conn.commit()
except Exception as inner_e:
print(f"更新({ts_code})状态为异常时出错: {str(inner_e)}")
return False
def process_stock(engine, pro, ts_code, start_date, end_date):
@ -391,7 +457,16 @@ def process_stock(engine, pro, ts_code, start_date, end_date):
print(f"警告:({ts_code})没有获取到数据")
def main():
def perform_full_update(start_year=2020):
"""
执行全量更新从指定年份开始到今天的所有数据
Args:
start_year: 开始年份默认为2020年
"""
print(f"开始执行全量更新 (从{start_year}年至今)...")
start_time = datetime.now()
# 加载配置
config = load_config()
@ -402,22 +477,328 @@ def main():
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
# 设置数据获取的日期范围
end_date = datetime.now().strftime('%Y%m%d')
start_date = f"{start_year}0101" # 从指定年份的1月1日开始
print(f"数据范围: {start_date}{end_date}")
# 从元数据表获取所有股票代码
with engine.connect() as conn:
# 查询所有需要更新的股票代码
query = """
SELECT ts_code
FROM stock_metadata
WHERE list_status = 'L' -- 只选择已上市的股票
"""
result = conn.execute(text(query))
stock_codes = [row[0] for row in result]
if not stock_codes:
print("没有找到需要更新的股票,请先确保元数据表已初始化")
return
print(f"共找到 {len(stock_codes)} 只股票需要更新")
# 记录成功和失败的股票数量
success_count = 0
failed_count = 0
skipped_count = 0
# 当前更新时间
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# 逐一处理每只股票
total_stocks = len(stock_codes)
for index, ts_code in enumerate(stock_codes):
try:
print(f"[{index + 1}/{total_stocks}] 正在全量更新股票: {ts_code}")
# 检查股票表是否存在,不存在则创建
symbol = ts_code.split('.')[0]
with engine.connect() as conn:
table_exists_query = f"""
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = '{config['mysql']['database']}'
AND table_name = '{symbol}'
"""
table_exists = conn.execute(text(table_exists_query)).scalar() > 0
if not table_exists:
create_stock_table(engine, ts_code)
# 下载股票数据
stock_data = download_stock_data(pro, ts_code, start_date, end_date)
# 如果下载成功,保存数据并更新元数据
if not stock_data.empty:
# 使用replace模式保存数据全量替换
records_saved = save_stock_data(stock_data, engine, ts_code)
# 更新元数据统计信息
update_metadata(engine, ts_code)
# 只更新全量更新时间戳
with engine.connect() as conn:
update_status_sql = f"""
UPDATE stock_metadata
SET last_full_update = '{update_time}'
WHERE ts_code = '{ts_code}'
"""
conn.execute(text(update_status_sql))
conn.commit()
success_count += 1
print(f"成功更新 {ts_code},保存了 {records_saved} 条记录")
else:
failed_count += 1
print(f"警告:{ts_code} 没有获取到数据")
# 防止API限流每次请求后短暂休息
import time
time.sleep(0.5)
except Exception as e:
failed_count += 1
print(f"处理股票 {ts_code} 时出错: {str(e)}")
end_time = datetime.now()
duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟
print("\n全量更新完成!")
print(f"总耗时: {duration:.2f} 分钟")
print(f"成功更新: {success_count} 只股票")
print(f"更新失败: {failed_count} 只股票")
print(f"总计: {total_stocks} 只股票")
def perform_incremental_update():
"""
执行增量更新从上次增量或全量更新日期到今天的数据
- 使用最近的更新日期增量或全量作为起点
- 如果表不存在跳过该股票
- 只更新增量更新时间字段
"""
print("开始执行增量更新...")
start_time = datetime.now()
# 加载配置
config = load_config()
# 创建数据库引擎
engine = create_engine_from_config(config)
# 设置Tushare API
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
# 当前日期作为结束日期
end_date = datetime.now().strftime('%Y%m%d')
# 获取需要更新的股票及其上次更新日期
stocks_to_update = []
with engine.connect() as conn:
query = """
SELECT
ts_code,
last_incremental_update,
last_full_update
FROM
stock_metadata
WHERE
list_status = 'L' AND
status != 2 -- 排除状态异常的股票
"""
result = conn.execute(text(query))
for row in result:
ts_code = row[0]
last_incr_update = row[1] # 上次增量更新日期
last_full_update = row[2] # 上次全量更新日期
# 确定起始日期:使用最近的更新日期
latest_update = None
# 检查增量更新日期
if last_incr_update:
latest_update = last_incr_update
# 检查全量更新日期
if last_full_update:
if not latest_update or last_full_update > latest_update:
latest_update = last_full_update
# 如果有更新日期,使用其后一天作为起始日期
if latest_update:
start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
else:
# 如果没有任何更新记录默认从30天前开始
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
# 如果起始日期大于等于结束日期,则跳过此股票
if start_date >= end_date:
print(f"股票 {ts_code} 数据已是最新,无需更新")
continue
stocks_to_update.append({
'ts_code': ts_code,
'start_date': start_date
})
if not stocks_to_update:
print("没有找到需要增量更新的股票")
return
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
# 记录成功和失败的股票数量
success_count = 0
failed_count = 0
no_new_data_count = 0
skipped_count = 0
# 当前更新时间
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# 逐一处理每只股票
total_stocks = len(stocks_to_update)
for index, stock in enumerate(stocks_to_update):
ts_code = stock['ts_code']
start_date = stock['start_date']
try:
print(f"[{index + 1}/{total_stocks}] 正在增量更新股票: {ts_code} (从 {start_date}{end_date})")
# 确保股票表存在
symbol = ts_code.split('.')[0]
with engine.connect() as conn:
table_exists_query = f"""
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = '{config['mysql']['database']}'
AND table_name = '{symbol}'
"""
table_exists = conn.execute(text(table_exists_query)).scalar() > 0
if not table_exists:
print(f"股票 {ts_code} 数据表不存在,跳过此股票")
skipped_count += 1
continue # 表不存在则跳过
# 下载增量数据
new_data = download_stock_data(pro, ts_code, start_date, end_date)
if not new_data.empty:
# 获取现有数据
try:
with engine.connect() as conn:
existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
except Exception as e:
print(f"读取现有数据失败: {str(e)},跳过此股票")
skipped_count += 1
continue
if not existing_data.empty:
# 转换日期列为相同的格式以便合并
existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])
# 删除可能重复的日期记录
existing_dates = set(existing_data['trade_date'])
new_data = new_data[~new_data['trade_date'].isin(existing_dates)]
if new_data.empty:
print(f"股票 {ts_code} 没有新数据需要更新")
no_new_data_count += 1
continue
# 合并数据
combined_data = pd.concat([existing_data, new_data], ignore_index=True)
# 按日期排序
combined_data = combined_data.sort_values('trade_date')
# 保存合并后的数据
records_saved = save_stock_data(combined_data, engine, ts_code)
else:
# 如果表存在但为空,直接保存新数据
records_saved = save_stock_data(new_data, engine, ts_code)
# 更新元数据统计信息
update_metadata(engine, ts_code)
# 只更新增量更新时间戳
with engine.connect() as conn:
update_status_sql = f"""
UPDATE stock_metadata
SET last_incremental_update = '{update_time}'
WHERE ts_code = '{ts_code}'
"""
conn.execute(text(update_status_sql))
conn.commit()
success_count += 1
print(f"成功增量更新 {ts_code},新增 {len(new_data)} 条记录")
else:
print(f"股票 {ts_code} 在指定时间范围内没有新数据")
no_new_data_count += 1
# 防止API限流每次请求后短暂休息
import time
time.sleep(0.5)
except Exception as e:
failed_count += 1
print(f"增量更新股票 {ts_code} 时出错: {str(e)}")
end_time = datetime.now()
duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟
print("\n增量更新完成!")
print(f"总耗时: {duration:.2f} 分钟")
print(f"成功更新: {success_count} 只股票")
print(f"无新数据: {no_new_data_count} 只股票")
print(f"跳过股票: {skipped_count} 只股票")
print(f"更新失败: {failed_count} 只股票")
print(f"总计: {total_stocks} 只股票")
def main():
"""主函数,允许用户选择更新方式"""
import argparse
# 命令行参数解析
parser = argparse.ArgumentParser(description='股票数据更新工具')
parser.add_argument('--init', action='store_true', help='初始化元数据表')
parser.add_argument('--mode', choices=['full', 'incremental', 'both'],
default='full', help='更新模式: full=全量更新, incremental=增量更新, both=两者都执行')
parser.add_argument('--year', type=int, default=2020, help='全量更新的起始年份')
args = parser.parse_args()
# 如果需要初始化元数据
if args.init:
# 加载配置
config = load_config()
# 创建数据库引擎
engine = create_engine_from_config(config)
# 设置Tushare API
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
# 创建元数据表
create_metadata_table(engine)
# 下载并保存股票元数据
metadata_df = download_stock_metadata(pro)
save_metadata(metadata_df, engine)
print("元数据初始化完成")
# 示例处理000001股票
ts_code = '000001.SZ' # 平安银行的ts_code
# 根据选择的模式执行更新
if args.mode == 'full' or args.mode == 'both':
perform_full_update(args.year)
# 设置数据获取的日期范围示例过去5年的数据
end_date = datetime.now().strftime('%Y%m%d')
start_date = (datetime.now() - timedelta(days=15 * 365)).strftime('%Y%m%d')
# 处理该股票
process_stock(engine, pro, ts_code, start_date, end_date)
if args.mode == 'incremental' or args.mode == 'both':
perform_incremental_update()
if __name__ == '__main__':

View File

@ -2,13 +2,16 @@
UPDATE stock_metadata
SET
last_full_update = NOW(),
data_start_date = '{data_start_date}',
data_end_date = '{data_end_date}',
data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
record_count = {record_count},
status = 1,
latest_price = {latest_price},
latest_date = '{data_end_date}',
latest_pe_ttm = {latest_pe_ttm},
latest_pb = {latest_pb},
latest_total_mv = {latest_total_mv}
status = CASE
WHEN {record_count} > 0 THEN 1 -- 有数据,状态为正常
ELSE 2 -- 无数据,状态为异常
END,
latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END,
latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END,
latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END,
latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
WHERE ts_code = '{ts_code}'