🚀 feat(data_downloader): 增加多进程支持,优化全量与增量更新功能
This commit is contained in:
parent
1f8c553078
commit
741a7c2d34
@ -2,6 +2,9 @@ import pandas as pd
|
||||
import tushare as ts
|
||||
import numpy as np
|
||||
import os
|
||||
import multiprocessing
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
from sqlalchemy import create_engine, text
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@ -314,21 +317,346 @@ def save_stock_data(df, engine, ts_code, if_exists='replace'):
|
||||
return 0
|
||||
|
||||
|
||||
def split_list_into_chunks(data_list, num_chunks):
    """Split *data_list* into *num_chunks* contiguous sub-lists.

    Uses exact integer arithmetic (divmod) instead of floating-point
    stepping, so chunk boundaries never drift: the first
    ``len(data_list) % num_chunks`` chunks receive one extra element.
    Element order is preserved and every element lands in exactly one chunk.

    Args:
        data_list: The list to split.
        num_chunks: Number of chunks to produce; must be >= 1.

    Returns:
        list: Exactly ``num_chunks`` sub-lists (trailing ones may be empty
        when ``len(data_list) < num_chunks``).

    Raises:
        ValueError: If ``num_chunks`` is less than 1 (the original float
        version raised ZeroDivisionError for 0, which hid the real problem).
    """
    if num_chunks < 1:
        raise ValueError("num_chunks must be >= 1")

    base_size, remainder = divmod(len(data_list), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `remainder` chunks absorb one extra element each.
        size = base_size + (1 if i < remainder else 0)
        chunks.append(data_list[start:start + size])
        start += size
    return chunks
|
||||
|
||||
|
||||
def _set_metadata_status(engine, ts_code, status, timestamp_column=None, update_time=None):
    """Set ``stock_metadata.status`` (and optionally one timestamp column) for a stock.

    Uses bound parameters instead of f-string SQL. ``timestamp_column`` is
    always one of the two hard-coded column names supplied by the caller,
    never external input, so interpolating the column name is safe.
    """
    set_clauses = ["status = :status"]
    params = {"status": status, "ts_code": ts_code}
    if timestamp_column is not None:
        set_clauses.append(f"{timestamp_column} = :update_time")
        params["update_time"] = update_time
    sql = f"UPDATE stock_metadata SET {', '.join(set_clauses)} WHERE ts_code = :ts_code"
    with engine.connect() as conn:
        conn.execute(text(sql), params)
        conn.commit()


def _stock_table_exists(engine, database, symbol):
    """Return True when the per-stock data table already exists in *database*."""
    query = text(
        "SELECT COUNT(*) FROM information_schema.tables "
        "WHERE table_schema = :schema AND table_name = :table"
    )
    with engine.connect() as conn:
        return conn.execute(query, {"schema": database, "table": symbol}).scalar() > 0


def process_stock_batch(batch, config, update_mode, start_date, end_date, update_time):
    """Worker that processes one batch of stocks (runs in a separate process).

    Each worker builds its own database engine and Tushare API handle,
    because neither object can be shared safely across process boundaries.

    Args:
        batch: Stocks to process. In ``'full'`` mode each item is a ts_code
            string; in ``'incremental'`` mode each item is a dict carrying
            ``'ts_code'`` and a per-stock ``'start_date'``.
        config: Loaded configuration (database credentials, tushare token).
        update_mode: ``'full'`` or ``'incremental'``.
        start_date: Download start date for 'full' mode; ignored in
            'incremental' mode, where each item carries its own start date.
        end_date: Download end date (YYYYMMDD).
        update_time: Timestamp string written to stock_metadata on success.

    Returns:
        dict: Counters: ``success_count``, ``failed_count``,
        ``no_new_data_count``, ``skipped_count``, ``processed_count``.
    """
    import time  # hoisted once per call instead of re-imported inside the loop

    engine = create_engine_from_config(config)
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    results = {
        'success_count': 0,
        'failed_count': 0,
        'no_new_data_count': 0,
        'skipped_count': 0,
        'processed_count': 0,
    }

    process_id = multiprocessing.current_process().name
    total_stocks = len(batch)
    database = config['mysql']['database']

    for i, stock_info in enumerate(batch):
        # Count every item as processed up front: the original incremented at
        # the loop's end, so every `continue` silently skipped the counter.
        results['processed_count'] += 1

        if update_mode == 'full':
            ts_code = stock_info
            try:
                # Keep logging terse: just the progress position.
                print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 全量更新 {ts_code}")

                # Create the per-stock table when it does not exist yet.
                symbol = ts_code.split('.')[0]
                if not _stock_table_exists(engine, database, symbol):
                    create_stock_table(engine, ts_code)

                stock_data = download_stock_data(pro, ts_code, start_date, end_date)

                if not stock_data.empty:
                    # Full mode replaces the table contents wholesale.
                    save_stock_data(stock_data, engine, ts_code)
                    update_metadata(engine, ts_code)
                    # Status 1 = healthy; also stamp the full-update time.
                    _set_metadata_status(engine, ts_code, 1,
                                         'last_full_update', update_time)
                    results['success_count'] += 1
                else:
                    # Status 2 = full update failed (no data returned).
                    _set_metadata_status(engine, ts_code, 2)
                    results['failed_count'] += 1

                time.sleep(0.5)  # throttle to stay under the API rate limit

            except Exception as e:
                # Best effort: mark the stock as failed (status 2); never let
                # the bookkeeping update kill the batch.
                try:
                    _set_metadata_status(engine, ts_code, 2)
                except Exception:
                    pass
                results['failed_count'] += 1
                print(f"[{process_id}] 股票 {ts_code} 更新失败: {str(e)}")

        elif update_mode == 'incremental':
            ts_code = stock_info['ts_code']
            # Per-stock start date; deliberately NOT assigned to `start_date`,
            # which the original clobbered (shadowing the parameter).
            stock_start_date = stock_info['start_date']

            try:
                print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 增量更新 {ts_code}")

                symbol = ts_code.split('.')[0]
                if not _stock_table_exists(engine, database, symbol):
                    # No table to merge into: skip rather than create one here.
                    results['skipped_count'] += 1
                    continue

                new_data = download_stock_data(pro, ts_code, stock_start_date, end_date)

                if not new_data.empty:
                    try:
                        with engine.connect() as conn:
                            existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
                    except Exception:
                        results['skipped_count'] += 1
                        continue

                    if not existing_data.empty:
                        # Normalize both sides to datetime so the overlap test
                        # and the final sort compare like with like.
                        existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
                        new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])

                        # Drop rows whose trade_date is already stored.
                        existing_dates = set(existing_data['trade_date'])
                        new_data = new_data[~new_data['trade_date'].isin(existing_dates)]

                        if new_data.empty:
                            results['no_new_data_count'] += 1
                            continue

                        combined_data = pd.concat([existing_data, new_data],
                                                  ignore_index=True)
                        combined_data = combined_data.sort_values('trade_date')
                        save_stock_data(combined_data, engine, ts_code)
                    else:
                        # Table exists but is empty: store the download as-is.
                        save_stock_data(new_data, engine, ts_code)

                    update_metadata(engine, ts_code)
                    # Status 1 = healthy; stamp the incremental-update time.
                    _set_metadata_status(engine, ts_code, 1,
                                         'last_incremental_update', update_time)
                    results['success_count'] += 1
                else:
                    results['no_new_data_count'] += 1

                time.sleep(0.5)  # throttle to stay under the API rate limit

            except Exception as e:
                # Status 3 = incremental update failed; best effort only.
                try:
                    _set_metadata_status(engine, ts_code, 3)
                except Exception:
                    pass
                results['failed_count'] += 1
                print(f"[{process_id}] 股票 {ts_code} 增量更新失败: {str(e)}")

    return results
|
||||
|
||||
|
||||
def perform_full_update(start_year=2020, processes=4, resume=False):
    """Run a multiprocess full update: all data from *start_year* to today.

    Splits the listed stocks into one batch per worker process and fans the
    batches out via ``process_stock_batch`` in 'full' mode. Workers create
    their own database/API connections, so no Tushare handle is built here
    (the original created an unused ``pro`` object in the parent process).

    Args:
        start_year: First year of data to download (from January 1st).
        processes: Number of worker processes; values below 1 are clamped to 1.
        resume: When True, only retry stocks whose previous full update
            failed (``status = 2``); otherwise update every listed stock.
    """
    update_type = "续传全量更新" if resume else "全量更新"
    processes = max(1, processes)  # guard against a nonsensical pool size
    print(f"开始执行{update_type} (从{start_year}年至今),使用 {processes} 个进程...")
    start_time = datetime.now()

    config = load_config()
    engine = create_engine_from_config(config)

    # Date range: January 1st of start_year through today.
    end_date = datetime.now().strftime('%Y%m%d')
    start_date = f"{start_year}0101"
    print(f"数据范围: {start_date} 至 {end_date}")

    # Work list from the metadata table. Static SQL — no external values.
    if resume:
        # status = 2 marks a previously failed full update.
        query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L' AND status = 2
        """
    else:
        query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L'
        """
    with engine.connect() as conn:
        result = conn.execute(text(query))
        stock_codes = [row[0] for row in result]

    if not stock_codes:
        if resume:
            print("没有找到需要续传的股票")
        else:
            print("没有找到需要更新的股票,请先确保元数据表已初始化")
        return

    print(f"共找到 {len(stock_codes)} 只股票需要{update_type}")

    # One shared timestamp so every stock in this run gets the same stamp.
    update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # One batch per worker process.
    batches = split_list_into_chunks(stock_codes, processes)

    with Pool(processes=processes) as pool:
        # Freeze everything except the batch itself.
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='full',
            start_date=start_date,
            end_date=end_date,
            update_time=update_time,
        )
        results = pool.map(process_func, batches)

    # Aggregate per-worker counters.
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)

    duration = (datetime.now() - start_time).total_seconds() / 60  # minutes

    print(f"\n{update_type}完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {len(stock_codes)} 只股票")
|
||||
|
||||
|
||||
def update_metadata(engine, ts_code):
|
||||
"""
|
||||
更新股票元数据中的统计信息,从股票表中直接获取数据
|
||||
|
||||
Args:
|
||||
engine: 数据库连接引擎
|
||||
ts_code: 股票代码(带后缀,如000001.SZ)
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 提取股票代码(不含后缀)
|
||||
symbol = ts_code.split('.')[0]
|
||||
|
||||
# 查询股票数据统计信息
|
||||
stats_query = f"""
|
||||
SELECT
|
||||
@ -337,7 +665,6 @@ def update_metadata(engine, ts_code):
|
||||
COUNT(*) as record_count
|
||||
FROM `{symbol}`
|
||||
"""
|
||||
|
||||
# 查询最新交易日数据
|
||||
latest_data_query = f"""
|
||||
SELECT
|
||||
@ -349,14 +676,12 @@ def update_metadata(engine, ts_code):
|
||||
WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# 执行查询
|
||||
with engine.connect() as conn:
|
||||
# 获取统计信息
|
||||
stats_result = conn.execute(text(stats_query)).fetchone()
|
||||
if not stats_result:
|
||||
print(f"警告:{ts_code}没有数据,将更新元数据状态为异常")
|
||||
# 更新状态为异常
|
||||
# 更新状态为异常(2)
|
||||
update_empty_sql = f"""
|
||||
UPDATE stock_metadata
|
||||
SET last_full_update = NOW(),
|
||||
@ -400,7 +725,6 @@ def update_metadata(engine, ts_code):
|
||||
update_sql = f"""
|
||||
UPDATE stock_metadata
|
||||
SET
|
||||
last_full_update = NOW(),
|
||||
data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
|
||||
data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
|
||||
record_count = {record_count},
|
||||
@ -412,175 +736,37 @@ def update_metadata(engine, ts_code):
|
||||
latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
|
||||
WHERE ts_code = '{ts_code}'
|
||||
"""
|
||||
|
||||
# 执行更新
|
||||
conn.execute(text(update_sql))
|
||||
conn.commit()
|
||||
|
||||
print(f"已更新({ts_code})的元数据信息")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"更新({ts_code})元数据时出错: {str(e)}")
|
||||
|
||||
# 尝试将状态更新为异常
|
||||
# 尝试将状态更新为元数据更新失败(-1)
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
error_update_sql = f"""
|
||||
UPDATE stock_metadata
|
||||
SET status = 2,
|
||||
last_full_update = NOW()
|
||||
SET status = -1
|
||||
WHERE ts_code = '{ts_code}'
|
||||
"""
|
||||
conn.execute(text(error_update_sql))
|
||||
conn.commit()
|
||||
except Exception as inner_e:
|
||||
print(f"更新({ts_code})状态为异常时出错: {str(inner_e)}")
|
||||
except Exception:
|
||||
pass # 更新状态失败时,忽略继续执行
|
||||
|
||||
print(f"更新({ts_code})元数据时出错: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def process_stock(engine, pro, ts_code, start_date, end_date):
    """Run the complete pipeline for a single stock.

    Ensures the stock's table exists, downloads its rows for the given date
    range, and — when anything came back — saves them and refreshes the
    metadata. Prints a warning and leaves the database untouched otherwise.
    """
    # Make sure the per-stock table exists before any write happens.
    create_stock_table(engine, ts_code)

    stock_data = download_stock_data(pro, ts_code, start_date, end_date)

    # Guard clause: nothing downloaded, nothing to persist.
    if stock_data.empty:
        print(f"警告:({ts_code})没有获取到数据")
        return

    save_stock_data(stock_data, engine, ts_code)
    update_metadata(engine, ts_code)
|
||||
|
||||
|
||||
def perform_full_update(start_year=2020):
    """Run a single-process full update: all data from *start_year* to today.

    Reads the listed stocks from ``stock_metadata``, then for each one makes
    sure its table exists, downloads the full date range, replaces the stored
    rows, refreshes the metadata, and stamps ``last_full_update``.

    Args:
        start_year: First year of data to download, defaults to 2020
            (downloads start at January 1st of that year).
    """
    import time  # hoisted: the original re-imported time on every loop pass

    print(f"开始执行全量更新 (从{start_year}年至今)...")
    start_time = datetime.now()

    config = load_config()
    engine = create_engine_from_config(config)

    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # Date range: January 1st of start_year through today.
    end_date = datetime.now().strftime('%Y%m%d')
    start_date = f"{start_year}0101"

    print(f"数据范围: {start_date} 至 {end_date}")

    # All listed stocks from the metadata table (static SQL, no parameters).
    with engine.connect() as conn:
        query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L' -- 只选择已上市的股票
        """
        result = conn.execute(text(query))
        stock_codes = [row[0] for row in result]

    if not stock_codes:
        print("没有找到需要更新的股票,请先确保元数据表已初始化")
        return

    print(f"共找到 {len(stock_codes)} 只股票需要更新")

    success_count = 0
    failed_count = 0

    # One shared timestamp so every stock in this run gets the same stamp.
    update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    total_stocks = len(stock_codes)
    for index, ts_code in enumerate(stock_codes):
        try:
            print(f"[{index + 1}/{total_stocks}] 正在全量更新股票: {ts_code}")

            # Create the per-stock table when it does not exist yet.
            # Bound parameters instead of interpolating values into SQL.
            symbol = ts_code.split('.')[0]
            with engine.connect() as conn:
                table_exists = conn.execute(
                    text(
                        "SELECT COUNT(*) FROM information_schema.tables "
                        "WHERE table_schema = :schema AND table_name = :table"
                    ),
                    {"schema": config['mysql']['database'], "table": symbol},
                ).scalar() > 0

            if not table_exists:
                create_stock_table(engine, ts_code)

            stock_data = download_stock_data(pro, ts_code, start_date, end_date)

            if not stock_data.empty:
                # Full update replaces the table contents wholesale.
                records_saved = save_stock_data(stock_data, engine, ts_code)

                update_metadata(engine, ts_code)

                # Stamp only the full-update timestamp (parameterized SQL).
                with engine.connect() as conn:
                    conn.execute(
                        text(
                            "UPDATE stock_metadata "
                            "SET last_full_update = :update_time "
                            "WHERE ts_code = :ts_code"
                        ),
                        {"update_time": update_time, "ts_code": ts_code},
                    )
                    conn.commit()

                success_count += 1
                print(f"成功更新 {ts_code},保存了 {records_saved} 条记录")
            else:
                failed_count += 1
                print(f"警告:{ts_code} 没有获取到数据")

            time.sleep(0.5)  # throttle to stay under the API rate limit

        except Exception as e:
            failed_count += 1
            print(f"处理股票 {ts_code} 时出错: {str(e)}")

    duration = (datetime.now() - start_time).total_seconds() / 60  # minutes

    print("\n全量更新完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {total_stocks} 只股票")
|
||||
|
||||
|
||||
def perform_incremental_update():
|
||||
def perform_incremental_update(processes=4):
|
||||
"""
|
||||
执行增量更新:从上次增量或全量更新日期到今天的数据
|
||||
- 使用最近的更新日期(增量或全量)作为起点
|
||||
- 如果表不存在,跳过该股票
|
||||
- 只更新增量更新时间字段
|
||||
|
||||
Args:
|
||||
processes: 进程数量,默认为4
|
||||
"""
|
||||
print("开始执行增量更新...")
|
||||
print(f"开始执行增量更新,使用 {processes} 个进程...")
|
||||
start_time = datetime.now()
|
||||
|
||||
# 加载配置
|
||||
@ -600,18 +786,16 @@ def perform_incremental_update():
|
||||
stocks_to_update = []
|
||||
with engine.connect() as conn:
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
ts_code,
|
||||
last_incremental_update,
|
||||
last_full_update
|
||||
FROM
|
||||
stock_metadata
|
||||
WHERE
|
||||
list_status = 'L' AND
|
||||
status != 2 -- 排除状态异常的股票
|
||||
FROM
|
||||
stock_metadata
|
||||
WHERE
|
||||
list_status = 'L'
|
||||
"""
|
||||
result = conn.execute(text(query))
|
||||
|
||||
for row in result:
|
||||
ts_code = row[0]
|
||||
last_incr_update = row[1] # 上次增量更新日期
|
||||
@ -638,7 +822,6 @@ def perform_incremental_update():
|
||||
|
||||
# 如果起始日期大于等于结束日期,则跳过此股票
|
||||
if start_date >= end_date:
|
||||
print(f"股票 {ts_code} 数据已是最新,无需更新")
|
||||
continue
|
||||
|
||||
stocks_to_update.append({
|
||||
@ -652,105 +835,32 @@ def perform_incremental_update():
|
||||
|
||||
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
|
||||
|
||||
# 记录成功和失败的股票数量
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
no_new_data_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
# 当前更新时间
|
||||
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# 逐一处理每只股票
|
||||
total_stocks = len(stocks_to_update)
|
||||
for index, stock in enumerate(stocks_to_update):
|
||||
ts_code = stock['ts_code']
|
||||
start_date = stock['start_date']
|
||||
# 将股票列表分成多个批次
|
||||
batches = split_list_into_chunks(stocks_to_update, processes)
|
||||
|
||||
try:
|
||||
print(f"[{index + 1}/{total_stocks}] 正在增量更新股票: {ts_code} (从 {start_date} 到 {end_date})")
|
||||
# 创建进程池
|
||||
with Pool(processes=processes) as pool:
|
||||
# 使用partial函数固定部分参数
|
||||
process_func = partial(
|
||||
process_stock_batch,
|
||||
config=config,
|
||||
update_mode='incremental',
|
||||
start_date=None, # 增量模式下使用每个股票自己的起始日期
|
||||
end_date=end_date,
|
||||
update_time=update_time
|
||||
)
|
||||
|
||||
# 确保股票表存在
|
||||
symbol = ts_code.split('.')[0]
|
||||
with engine.connect() as conn:
|
||||
table_exists_query = f"""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = '{config['mysql']['database']}'
|
||||
AND table_name = '{symbol}'
|
||||
"""
|
||||
table_exists = conn.execute(text(table_exists_query)).scalar() > 0
|
||||
# 启动多进程处理并收集结果
|
||||
results = pool.map(process_func, batches)
|
||||
|
||||
if not table_exists:
|
||||
print(f"股票 {ts_code} 数据表不存在,跳过此股票")
|
||||
skipped_count += 1
|
||||
continue # 表不存在则跳过
|
||||
|
||||
# 下载增量数据
|
||||
new_data = download_stock_data(pro, ts_code, start_date, end_date)
|
||||
|
||||
if not new_data.empty:
|
||||
# 获取现有数据
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
|
||||
except Exception as e:
|
||||
print(f"读取现有数据失败: {str(e)},跳过此股票")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
if not existing_data.empty:
|
||||
# 转换日期列为相同的格式以便合并
|
||||
existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
|
||||
new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])
|
||||
|
||||
# 删除可能重复的日期记录
|
||||
existing_dates = set(existing_data['trade_date'])
|
||||
new_data = new_data[~new_data['trade_date'].isin(existing_dates)]
|
||||
|
||||
if new_data.empty:
|
||||
print(f"股票 {ts_code} 没有新数据需要更新")
|
||||
no_new_data_count += 1
|
||||
continue
|
||||
|
||||
# 合并数据
|
||||
combined_data = pd.concat([existing_data, new_data], ignore_index=True)
|
||||
|
||||
# 按日期排序
|
||||
combined_data = combined_data.sort_values('trade_date')
|
||||
|
||||
# 保存合并后的数据
|
||||
records_saved = save_stock_data(combined_data, engine, ts_code)
|
||||
else:
|
||||
# 如果表存在但为空,直接保存新数据
|
||||
records_saved = save_stock_data(new_data, engine, ts_code)
|
||||
|
||||
# 更新元数据统计信息
|
||||
update_metadata(engine, ts_code)
|
||||
|
||||
# 只更新增量更新时间戳
|
||||
with engine.connect() as conn:
|
||||
update_status_sql = f"""
|
||||
UPDATE stock_metadata
|
||||
SET last_incremental_update = '{update_time}'
|
||||
WHERE ts_code = '{ts_code}'
|
||||
"""
|
||||
conn.execute(text(update_status_sql))
|
||||
conn.commit()
|
||||
|
||||
success_count += 1
|
||||
print(f"成功增量更新 {ts_code},新增 {len(new_data)} 条记录")
|
||||
else:
|
||||
print(f"股票 {ts_code} 在指定时间范围内没有新数据")
|
||||
no_new_data_count += 1
|
||||
|
||||
# 防止API限流,每次请求后短暂休息
|
||||
import time
|
||||
time.sleep(0.5)
|
||||
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"增量更新股票 {ts_code} 时出错: {str(e)}")
|
||||
# 汇总所有进程的结果
|
||||
success_count = sum(r['success_count'] for r in results)
|
||||
failed_count = sum(r['failed_count'] for r in results)
|
||||
no_new_data_count = sum(r['no_new_data_count'] for r in results)
|
||||
skipped_count = sum(r['skipped_count'] for r in results)
|
||||
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟
|
||||
@ -761,7 +871,7 @@ def perform_incremental_update():
|
||||
print(f"无新数据: {no_new_data_count} 只股票")
|
||||
print(f"跳过股票: {skipped_count} 只股票")
|
||||
print(f"更新失败: {failed_count} 只股票")
|
||||
print(f"总计: {total_stocks} 只股票")
|
||||
print(f"总计: {len(stocks_to_update)} 只股票")
|
||||
|
||||
|
||||
def main():
|
||||
@ -774,6 +884,9 @@ def main():
|
||||
parser.add_argument('--mode', choices=['full', 'incremental', 'both'],
|
||||
default='full', help='更新模式: full=全量更新, incremental=增量更新, both=两者都执行')
|
||||
parser.add_argument('--year', type=int, default=2020, help='全量更新的起始年份')
|
||||
parser.add_argument('--processes', type=int, default=8, help='使用的进程数量')
|
||||
parser.add_argument('--resume', action='store_true', help='仅更新全量更新失败的股票,即status=2部分')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 如果需要初始化元数据
|
||||
@ -795,11 +908,11 @@ def main():
|
||||
|
||||
# 根据选择的模式执行更新
|
||||
if args.mode == 'full' or args.mode == 'both':
|
||||
perform_full_update(args.year)
|
||||
perform_full_update(args.year, args.processes, args.resume)
|
||||
|
||||
if args.mode == 'incremental' or args.mode == 'both':
|
||||
perform_incremental_update()
|
||||
perform_incremental_update(args.processes)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user