# Standard-library imports introduced by this change.  The third-party names
# used below (``ts`` = tushare, ``pd`` = pandas, ``text`` from sqlalchemy) and
# ``datetime`` come from the module's existing import block.
import multiprocessing
import time
from functools import partial
from multiprocessing import Pool

# Seconds to wait after every Tushare request so a worker does not trip the
# API rate limit.
_API_PAUSE_SECONDS = 0.5

# Per-mode settings: the stock_metadata column stamped on success and the
# status code written on failure (2 = full update failed, 3 = incremental
# update failed), plus the labels used in progress / error messages.
_MODE_SETTINGS = {
    'full': {
        'timestamp_column': 'last_full_update',
        'failure_status': 2,
        'progress_label': '全量更新',
        'failure_label': '更新失败',
    },
    'incremental': {
        'timestamp_column': 'last_incremental_update',
        'failure_status': 3,
        'progress_label': '增量更新',
        'failure_label': '增量更新失败',
    },
}


def split_list_into_chunks(data_list, num_chunks):
    """Split a list into contiguous, near-equal chunks.

    Sizes are computed with integer arithmetic so the number of chunks never
    exceeds ``num_chunks`` (a float accumulator can overshoot by one because
    of rounding) and no empty chunk is ever produced.

    Args:
        data_list: The list to split.  Element order is preserved.
        num_chunks: Maximum number of chunks; must be at least 1.

    Returns:
        list: ``min(num_chunks, len(data_list))`` sub-lists whose
        concatenation equals ``data_list`` and whose lengths differ by at
        most one.  An empty input yields an empty list.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    num_chunks = int(num_chunks)
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be at least 1, got {num_chunks}")

    total = len(data_list)
    base_size, remainder = divmod(total, num_chunks)

    chunks = []
    start = 0
    for index in range(min(num_chunks, total)):
        # The first ``remainder`` chunks absorb one extra element each.
        stop = start + base_size + (1 if index < remainder else 0)
        chunks.append(data_list[start:stop])
        start = stop
    return chunks


def _table_exists(engine, schema, table_name):
    """Return True when information_schema lists ``table_name`` in ``schema``."""
    query = text("""
        SELECT COUNT(*)
        FROM information_schema.tables
        WHERE table_schema = :schema
        AND table_name = :table_name
    """)
    with engine.connect() as conn:
        count = conn.execute(
            query, {'schema': schema, 'table_name': table_name}
        ).scalar()
    return count > 0


def _set_stock_status(engine, ts_code, status, timestamp_column=None,
                      update_time=None):
    """Write ``status`` (and optionally a timestamp column) to stock_metadata."""
    params = {'status': status, 'ts_code': ts_code}
    assignments = "status = :status"
    if timestamp_column is not None:
        # The column name always comes from _MODE_SETTINGS, never from data,
        # so interpolating it is safe; every value is a bound parameter.
        assignments = f"{timestamp_column} = :update_time, {assignments}"
        params['update_time'] = update_time

    statement = text(
        f"UPDATE stock_metadata SET {assignments} WHERE ts_code = :ts_code"
    )
    with engine.connect() as conn:
        conn.execute(statement, params)
        conn.commit()


def _mark_failed(engine, ts_code, status, process_id):
    """Best-effort failure marker: report, but never raise, if the write fails."""
    try:
        _set_stock_status(engine, ts_code, status)
    except Exception as exc:
        # Deliberately non-fatal: the batch must keep going even when the
        # status row cannot be written.
        print(f"[{process_id}] 无法将 {ts_code} 的状态更新为 {status}: {exc}")


def _download_throttled(pro, ts_code, start_date, end_date):
    """Download one stock's data and always pause afterwards for rate limiting."""
    try:
        return download_stock_data(pro, ts_code, start_date, end_date)
    finally:
        # Pause after every request, including failed ones and ones that
        # return nothing new.
        time.sleep(_API_PAUSE_SECONDS)


def _full_update_one(engine, pro, config, ts_code, start_date, end_date,
                     update_time):
    """Replace all stored rows of one stock; return the result-counter key."""
    symbol = ts_code.split('.')[0]
    if not _table_exists(engine, config['mysql']['database'], symbol):
        create_stock_table(engine, ts_code)

    stock_data = _download_throttled(pro, ts_code, start_date, end_date)
    if stock_data.empty:
        _set_stock_status(engine, ts_code, _MODE_SETTINGS['full']['failure_status'])
        return 'failed_count'

    # save_stock_data defaults to if_exists='replace', i.e. a full replacement.
    save_stock_data(stock_data, engine, ts_code)

    # update_metadata returns False on failure and records its own status;
    # do not overwrite that with "normal" (1).
    if not update_metadata(engine, ts_code):
        return 'failed_count'

    _set_stock_status(
        engine, ts_code, 1,
        timestamp_column=_MODE_SETTINGS['full']['timestamp_column'],
        update_time=update_time,
    )
    return 'success_count'


def _incremental_update_one(engine, pro, config, stock_info, end_date,
                            update_time):
    """Append the missing trading days of one stock; return the counter key."""
    ts_code = stock_info['ts_code']
    symbol = ts_code.split('.')[0]

    # Incremental mode never creates tables: a missing table means the stock
    # has not had a full update yet, so it is skipped.
    if not _table_exists(engine, config['mysql']['database'], symbol):
        return 'skipped_count'

    new_data = _download_throttled(
        pro, ts_code, stock_info['start_date'], end_date
    )
    if new_data.empty:
        return 'no_new_data_count'

    try:
        with engine.connect() as conn:
            existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
    except Exception:
        # An unreadable table is treated as "skip", not as a failure.
        return 'skipped_count'

    if existing_data.empty:
        # Table exists but holds no rows: store the download as-is.
        data_to_save = new_data
    else:
        # Normalise both date columns so they compare and sort consistently.
        existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
        new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])

        # Drop downloaded rows whose trading day is already stored.
        already_stored = new_data['trade_date'].isin(existing_data['trade_date'])
        new_data = new_data[~already_stored]
        if new_data.empty:
            return 'no_new_data_count'

        data_to_save = pd.concat(
            [existing_data, new_data], ignore_index=True
        ).sort_values('trade_date')

    save_stock_data(data_to_save, engine, ts_code)

    # update_metadata returns False on failure and records its own status;
    # do not overwrite that with "normal" (1).
    if not update_metadata(engine, ts_code):
        return 'failed_count'

    _set_stock_status(
        engine, ts_code, 1,
        timestamp_column=_MODE_SETTINGS['incremental']['timestamp_column'],
        update_time=update_time,
    )
    return 'success_count'


def process_stock_batch(batch, config, update_mode, start_date, end_date, update_time):
    """Process one batch of stocks (worker function, run in its own process).

    Each call creates its own database engine and Tushare client, because
    neither may be shared across processes, and disposes the engine when the
    batch is finished.

    Args:
        batch: Stocks to process.  In 'full' mode each item is a ts_code
            string; in 'incremental' mode each item is a dict with the keys
            'ts_code' and 'start_date'.
        config: Configuration dict; reads ``config['tushare_token']`` and
            ``config['mysql']['database']``.
        update_mode: 'full' or 'incremental'.
        start_date: First date to download in 'full' mode.  Ignored in
            'incremental' mode, where every stock carries its own start date.
        end_date: Last date to download.
        update_time: Timestamp string written to the metadata row on success.

    Returns:
        dict: Counters 'success_count', 'failed_count', 'no_new_data_count',
        'skipped_count' and 'processed_count'.  Every stock in the batch is
        counted in 'processed_count' and in exactly one of the other four.

    Raises:
        ValueError: If ``update_mode`` is neither 'full' nor 'incremental'.
    """
    if update_mode not in _MODE_SETTINGS:
        raise ValueError(
            f"update_mode must be 'full' or 'incremental', got {update_mode!r}"
        )
    settings = _MODE_SETTINGS[update_mode]
    is_full = update_mode == 'full'

    results = {
        'success_count': 0,
        'failed_count': 0,
        'no_new_data_count': 0,
        'skipped_count': 0,
        'processed_count': 0
    }
    if not batch:
        # Nothing to do: avoid opening database / API connections at all.
        return results

    # Connections must be created inside the worker process.
    engine = create_engine_from_config(config)
    try:
        ts.set_token(config['tushare_token'])
        pro = ts.pro_api()

        process_id = multiprocessing.current_process().name
        total_stocks = len(batch)

        for position, stock_info in enumerate(batch, start=1):
            ts_code = stock_info if is_full else stock_info['ts_code']
            try:
                print(f"[{process_id}] 进度 {position}/{total_stocks}: "
                      f"{settings['progress_label']} {ts_code}")
                if is_full:
                    outcome = _full_update_one(
                        engine, pro, config, ts_code,
                        start_date, end_date, update_time,
                    )
                else:
                    outcome = _incremental_update_one(
                        engine, pro, config, stock_info, end_date, update_time,
                    )
            except Exception as exc:
                # One bad stock must not abort the rest of the batch.
                _mark_failed(engine, ts_code, settings['failure_status'], process_id)
                outcome = 'failed_count'
                print(f"[{process_id}] 股票 {ts_code} "
                      f"{settings['failure_label']}: {str(exc)}")

            results[outcome] += 1
            results['processed_count'] += 1
    finally:
        # Release this worker's pooled connections before the process exits.
        engine.dispose()

    return results


def perform_full_update(start_year=2020, processes=4, resume=False):
    """Run a full update: reload every stock from ``start_year`` until today.

    The stock list is read from ``stock_metadata`` in the parent process and
    then split across a pool of worker processes running
    ``process_stock_batch`` in 'full' mode.

    Args:
        start_year: First year to download; data starts on January 1st of it.
        processes: Requested number of worker processes.  The effective count
            is clamped to the range 1..number of stocks.
        resume: When True, only stocks whose previous full update failed
            (``status = 2``) are processed.

    Returns:
        None.  A summary is printed to stdout.
    """
    update_type = "续传全量更新" if resume else "全量更新"
    print(f"开始执行{update_type} (从{start_year}年至今),使用 {processes} 个进程...")
    start_time = datetime.now()

    config = load_config()

    end_date = datetime.now().strftime('%Y%m%d')
    start_date = f"{start_year}0101"  # January 1st of the requested year
    print(f"数据范围: {start_date} 至 {end_date}")

    # Only listed stocks; in resume mode only those whose full update failed.
    query = """
        SELECT ts_code
        FROM stock_metadata
        WHERE list_status = 'L'
    """
    if resume:
        query += " AND status = 2"

    engine = create_engine_from_config(config)
    try:
        with engine.connect() as conn:
            stock_codes = [row[0] for row in conn.execute(text(query))]
    finally:
        # Dispose before the pool starts so no pooled connection is inherited
        # by the worker processes; each worker opens its own engine.
        engine.dispose()

    if not stock_codes:
        if resume:
            print("没有找到需要续传的股票")
        else:
            print("没有找到需要更新的股票,请先确保元数据表已初始化")
        return

    print(f"共找到 {len(stock_codes)} 只股票需要{update_type}")

    update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # Never start more workers than there are stocks, and always at least one.
    worker_count = max(1, min(int(processes), len(stock_codes)))
    batches = split_list_into_chunks(stock_codes, worker_count)

    # NOTE(review): every worker throttles its own requests independently, so
    # the aggregate request rate grows with the worker count - confirm this
    # stays within the Tushare quota.
    process_func = partial(
        process_stock_batch,
        config=config,
        update_mode='full',
        start_date=start_date,
        end_date=end_date,
        update_time=update_time
    )
    with Pool(processes=worker_count) as pool:
        results = pool.map(process_func, batches)

    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)

    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # minutes

    print(f"\n{update_type}完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {len(stock_codes)} 只股票")
update_metadata(engine, ts_code): """ 更新股票元数据中的统计信息,从股票表中直接获取数据 - Args: engine: 数据库连接引擎 ts_code: 股票代码(带后缀,如000001.SZ) - Returns: bool: 更新成功返回True,否则返回False """ try: # 提取股票代码(不含后缀) symbol = ts_code.split('.')[0] - # 查询股票数据统计信息 stats_query = f""" SELECT @@ -337,7 +665,6 @@ def update_metadata(engine, ts_code): COUNT(*) as record_count FROM `{symbol}` """ - # 查询最新交易日数据 latest_data_query = f""" SELECT @@ -349,14 +676,12 @@ def update_metadata(engine, ts_code): WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`) LIMIT 1 """ - # 执行查询 with engine.connect() as conn: # 获取统计信息 stats_result = conn.execute(text(stats_query)).fetchone() if not stats_result: - print(f"警告:{ts_code}没有数据,将更新元数据状态为异常") - # 更新状态为异常 + # 更新状态为异常(2) update_empty_sql = f""" UPDATE stock_metadata SET last_full_update = NOW(), @@ -400,7 +725,6 @@ def update_metadata(engine, ts_code): update_sql = f""" UPDATE stock_metadata SET - last_full_update = NOW(), data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END, data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END, record_count = {record_count}, @@ -412,175 +736,37 @@ def update_metadata(engine, ts_code): latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END WHERE ts_code = '{ts_code}' """ - # 执行更新 conn.execute(text(update_sql)) conn.commit() - print(f"已更新({ts_code})的元数据信息") return True - except Exception as e: - print(f"更新({ts_code})元数据时出错: {str(e)}") - - # 尝试将状态更新为异常 + # 尝试将状态更新为元数据更新失败(-1) try: with engine.connect() as conn: error_update_sql = f""" UPDATE stock_metadata - SET status = 2, - last_full_update = NOW() + SET status = -1 WHERE ts_code = '{ts_code}' """ conn.execute(text(error_update_sql)) conn.commit() - except Exception as inner_e: - print(f"更新({ts_code})状态为异常时出错: {str(inner_e)}") + except Exception: + pass # 更新状态失败时,忽略继续执行 + print(f"更新({ts_code})元数据时出错: {str(e)}") return False -def process_stock(engine, pro, ts_code, 
start_date, end_date): - """处理单个股票的完整流程""" - # 创建该股票的表 - create_stock_table(engine, ts_code) - - # 下载该股票的数据 - stock_data = download_stock_data(pro, ts_code, start_date, end_date) - - # 保存股票数据 - if not stock_data.empty: - save_stock_data(stock_data, engine, ts_code) - - update_metadata(engine, ts_code) - else: - print(f"警告:({ts_code})没有获取到数据") - - -def perform_full_update(start_year=2020): - """ - 执行全量更新:从指定年份开始到今天的所有数据 - - Args: - start_year: 开始年份,默认为2020年 - """ - print(f"开始执行全量更新 (从{start_year}年至今)...") - start_time = datetime.now() - - # 加载配置 - config = load_config() - - # 创建数据库引擎 - engine = create_engine_from_config(config) - - # 设置Tushare API - ts.set_token(config['tushare_token']) - pro = ts.pro_api() - - # 设置数据获取的日期范围 - end_date = datetime.now().strftime('%Y%m%d') - start_date = f"{start_year}0101" # 从指定年份的1月1日开始 - - print(f"数据范围: {start_date} 至 {end_date}") - - # 从元数据表获取所有股票代码 - with engine.connect() as conn: - # 查询所有需要更新的股票代码 - query = """ - SELECT ts_code - FROM stock_metadata - WHERE list_status = 'L' -- 只选择已上市的股票 - """ - result = conn.execute(text(query)) - stock_codes = [row[0] for row in result] - - if not stock_codes: - print("没有找到需要更新的股票,请先确保元数据表已初始化") - return - - print(f"共找到 {len(stock_codes)} 只股票需要更新") - - # 记录成功和失败的股票数量 - success_count = 0 - failed_count = 0 - skipped_count = 0 - - # 当前更新时间 - update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - - # 逐一处理每只股票 - total_stocks = len(stock_codes) - for index, ts_code in enumerate(stock_codes): - try: - print(f"[{index + 1}/{total_stocks}] 正在全量更新股票: {ts_code}") - - # 检查股票表是否存在,不存在则创建 - symbol = ts_code.split('.')[0] - with engine.connect() as conn: - table_exists_query = f""" - SELECT COUNT(*) - FROM information_schema.tables - WHERE table_schema = '{config['mysql']['database']}' - AND table_name = '{symbol}' - """ - table_exists = conn.execute(text(table_exists_query)).scalar() > 0 - - if not table_exists: - create_stock_table(engine, ts_code) - - # 下载股票数据 - stock_data = 
download_stock_data(pro, ts_code, start_date, end_date) - - # 如果下载成功,保存数据并更新元数据 - if not stock_data.empty: - # 使用replace模式保存数据(全量替换) - records_saved = save_stock_data(stock_data, engine, ts_code) - - # 更新元数据统计信息 - update_metadata(engine, ts_code) - - # 只更新全量更新时间戳 - with engine.connect() as conn: - update_status_sql = f""" - UPDATE stock_metadata - SET last_full_update = '{update_time}' - WHERE ts_code = '{ts_code}' - """ - conn.execute(text(update_status_sql)) - conn.commit() - - success_count += 1 - print(f"成功更新 {ts_code},保存了 {records_saved} 条记录") - else: - failed_count += 1 - print(f"警告:{ts_code} 没有获取到数据") - - # 防止API限流,每次请求后短暂休息 - import time - time.sleep(0.5) - - except Exception as e: - failed_count += 1 - print(f"处理股票 {ts_code} 时出错: {str(e)}") - - end_time = datetime.now() - duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟 - - print("\n全量更新完成!") - print(f"总耗时: {duration:.2f} 分钟") - print(f"成功更新: {success_count} 只股票") - print(f"更新失败: {failed_count} 只股票") - print(f"总计: {total_stocks} 只股票") - - -def perform_incremental_update(): +def perform_incremental_update(processes=4): """ 执行增量更新:从上次增量或全量更新日期到今天的数据 - - 使用最近的更新日期(增量或全量)作为起点 - - 如果表不存在,跳过该股票 - - 只更新增量更新时间字段 + + Args: + processes: 进程数量,默认为4 """ - print("开始执行增量更新...") + print(f"开始执行增量更新,使用 {processes} 个进程...") start_time = datetime.now() # 加载配置 @@ -600,18 +786,16 @@ def perform_incremental_update(): stocks_to_update = [] with engine.connect() as conn: query = """ - SELECT + SELECT ts_code, last_incremental_update, last_full_update - FROM - stock_metadata - WHERE - list_status = 'L' AND - status != 2 -- 排除状态异常的股票 + FROM + stock_metadata + WHERE + list_status = 'L' """ result = conn.execute(text(query)) - for row in result: ts_code = row[0] last_incr_update = row[1] # 上次增量更新日期 @@ -638,7 +822,6 @@ def perform_incremental_update(): # 如果起始日期大于等于结束日期,则跳过此股票 if start_date >= end_date: - print(f"股票 {ts_code} 数据已是最新,无需更新") continue stocks_to_update.append({ @@ -652,105 +835,32 @@ def 
perform_incremental_update(): print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新") - # 记录成功和失败的股票数量 - success_count = 0 - failed_count = 0 - no_new_data_count = 0 - skipped_count = 0 - # 当前更新时间 update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - # 逐一处理每只股票 - total_stocks = len(stocks_to_update) - for index, stock in enumerate(stocks_to_update): - ts_code = stock['ts_code'] - start_date = stock['start_date'] + # 将股票列表分成多个批次 + batches = split_list_into_chunks(stocks_to_update, processes) - try: - print(f"[{index + 1}/{total_stocks}] 正在增量更新股票: {ts_code} (从 {start_date} 到 {end_date})") + # 创建进程池 + with Pool(processes=processes) as pool: + # 使用partial函数固定部分参数 + process_func = partial( + process_stock_batch, + config=config, + update_mode='incremental', + start_date=None, # 增量模式下使用每个股票自己的起始日期 + end_date=end_date, + update_time=update_time + ) - # 确保股票表存在 - symbol = ts_code.split('.')[0] - with engine.connect() as conn: - table_exists_query = f""" - SELECT COUNT(*) - FROM information_schema.tables - WHERE table_schema = '{config['mysql']['database']}' - AND table_name = '{symbol}' - """ - table_exists = conn.execute(text(table_exists_query)).scalar() > 0 + # 启动多进程处理并收集结果 + results = pool.map(process_func, batches) - if not table_exists: - print(f"股票 {ts_code} 数据表不存在,跳过此股票") - skipped_count += 1 - continue # 表不存在则跳过 - - # 下载增量数据 - new_data = download_stock_data(pro, ts_code, start_date, end_date) - - if not new_data.empty: - # 获取现有数据 - try: - with engine.connect() as conn: - existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn) - except Exception as e: - print(f"读取现有数据失败: {str(e)},跳过此股票") - skipped_count += 1 - continue - - if not existing_data.empty: - # 转换日期列为相同的格式以便合并 - existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date']) - new_data['trade_date'] = pd.to_datetime(new_data['trade_date']) - - # 删除可能重复的日期记录 - existing_dates = set(existing_data['trade_date']) - new_data = new_data[~new_data['trade_date'].isin(existing_dates)] - - if 
new_data.empty: - print(f"股票 {ts_code} 没有新数据需要更新") - no_new_data_count += 1 - continue - - # 合并数据 - combined_data = pd.concat([existing_data, new_data], ignore_index=True) - - # 按日期排序 - combined_data = combined_data.sort_values('trade_date') - - # 保存合并后的数据 - records_saved = save_stock_data(combined_data, engine, ts_code) - else: - # 如果表存在但为空,直接保存新数据 - records_saved = save_stock_data(new_data, engine, ts_code) - - # 更新元数据统计信息 - update_metadata(engine, ts_code) - - # 只更新增量更新时间戳 - with engine.connect() as conn: - update_status_sql = f""" - UPDATE stock_metadata - SET last_incremental_update = '{update_time}' - WHERE ts_code = '{ts_code}' - """ - conn.execute(text(update_status_sql)) - conn.commit() - - success_count += 1 - print(f"成功增量更新 {ts_code},新增 {len(new_data)} 条记录") - else: - print(f"股票 {ts_code} 在指定时间范围内没有新数据") - no_new_data_count += 1 - - # 防止API限流,每次请求后短暂休息 - import time - time.sleep(0.5) - - except Exception as e: - failed_count += 1 - print(f"增量更新股票 {ts_code} 时出错: {str(e)}") + # 汇总所有进程的结果 + success_count = sum(r['success_count'] for r in results) + failed_count = sum(r['failed_count'] for r in results) + no_new_data_count = sum(r['no_new_data_count'] for r in results) + skipped_count = sum(r['skipped_count'] for r in results) end_time = datetime.now() duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟 @@ -761,7 +871,7 @@ def perform_incremental_update(): print(f"无新数据: {no_new_data_count} 只股票") print(f"跳过股票: {skipped_count} 只股票") print(f"更新失败: {failed_count} 只股票") - print(f"总计: {total_stocks} 只股票") + print(f"总计: {len(stocks_to_update)} 只股票") def main(): @@ -774,6 +884,9 @@ def main(): parser.add_argument('--mode', choices=['full', 'incremental', 'both'], default='full', help='更新模式: full=全量更新, incremental=增量更新, both=两者都执行') parser.add_argument('--year', type=int, default=2020, help='全量更新的起始年份') + parser.add_argument('--processes', type=int, default=8, help='使用的进程数量') + parser.add_argument('--resume', action='store_true', 
help='仅更新全量更新失败的股票,即status=2部分') + args = parser.parse_args() # 如果需要初始化元数据 @@ -795,11 +908,11 @@ def main(): # 根据选择的模式执行更新 if args.mode == 'full' or args.mode == 'both': - perform_full_update(args.year) + perform_full_update(args.year, args.processes, args.resume) if args.mode == 'incremental' or args.mode == 'both': - perform_incremental_update() + perform_incremental_update(args.processes) if __name__ == '__main__': - main() \ No newline at end of file + main()