import argparse
import os
import time
import multiprocessing
from multiprocessing import Pool
from functools import partial
from datetime import datetime, timedelta

import pandas as pd
import tushare as ts
from sqlalchemy import text

from utils import load_config, create_engine_from_config


def create_metadata_table(engine):
    """Create the stock metadata table."""
    # Read the CREATE TABLE statement from the SQL file
    with open('sql/meta.sql', 'r', encoding='utf-8') as f:
        create_table_sql = f.read()

    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print("Stock metadata table created successfully")


def create_stock_table(engine, ts_code):
    """Create the data table for the given stock."""
    # Make sure the sql directory exists
    os.makedirs('sql', exist_ok=True)

    # Read the template SQL
    with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as f:
        template_sql = f.read()

    # Substitute the symbol into the template
    create_table_sql = template_sql.format(symbol=ts_code.split('.')[0])

    # Execute the CREATE TABLE statement
    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print(f"Data table for stock ({ts_code}) created successfully")


def download_stock_metadata(pro):
    """Download basic stock information."""
    df = pro.stock_basic(
        exchange='',
        list_status='L',
        fields='ts_code,symbol,name,area,industry,fullname,enname,cnspell,'
               'market,exchange,curr_type,list_status,list_date,delist_date,'
               'is_hs,act_name,act_ent_type'
    )
    # Add bookkeeping fields
    df['status'] = 0  # 0 = not initialized
    df['remark'] = 'auto import'
    df['last_full_update'] = None
    df['last_incremental_update'] = None
    df['data_start_date'] = None
    df['data_end_date'] = None
    df['record_count'] = 0
    df['latest_price'] = None
    df['latest_date'] = None
    df['latest_pe_ttm'] = None
    df['latest_pb'] = None
    df['latest_total_mv'] = None
    return df
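# Status codes used in stock_metadata.status throughout this module:
#   -1 = metadata refresh failed
#    0 = not initialized
#    1 = data present and up to date
#    2 = full update failed (or the stock table is empty)
#    3 = incremental update failed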
def save_metadata(df, engine):
    """Save metadata to the database, updating only records whose monitored
    fields changed, and normalizing date formats before comparison."""
    # Base metadata fields to monitor for changes
    metadata_fields = [
        'ts_code', 'symbol', 'name', 'area', 'industry',
        'fullname', 'enname', 'cnspell', 'market', 'exchange', 'curr_type',
        'list_status', 'list_date', 'delist_date', 'is_hs',
        'act_name', 'act_ent_type'
    ]

    # Date fields
    date_fields = ['list_date', 'delist_date']

    # Read the existing metadata from the database
    try:
        with engine.connect() as conn:
            # Only select the fields we care about
            query = f"SELECT {', '.join(metadata_fields)} FROM stock_metadata"
            existing_df = pd.read_sql(query, conn)
    except Exception as e:
        print(f"Error reading existing metadata: {str(e)}")
        existing_df = pd.DataFrame()  # Empty DataFrame when the table does not exist

    if existing_df.empty:
        # Table is empty: insert everything
        result = df.to_sql('stock_metadata', engine, index=False,
                           if_exists='append', chunksize=1000)
        print(f"Saved {result} stock metadata records (new table)")
        return result
    else:
        # Preprocess: normalize date formats before comparing
        df_processed = df.copy()
        existing_processed = existing_df.copy()

        def normalize_date(date_str):
            """Reduce a date value to its digits (YYYYMMDD), or '' if unusable."""
            if pd.isna(date_str) or date_str == '' or date_str == 'None' or date_str is None:
                return ''
            # Strip all non-digit characters
            date_str = str(date_str).strip()
            digits_only = ''.join(c for c in date_str if c.isdigit())
            # An 8-digit string is a well-formed YYYYMMDD date
            if len(digits_only) == 8:
                return digits_only
            return ''

        # Apply date normalization to both DataFrames
        for field in date_fields:
            if field in df_processed.columns and field in existing_processed.columns:
                df_processed[field] = df_processed[field].apply(normalize_date)
                existing_processed[field] = existing_processed[field].apply(normalize_date)

        # Normalize the remaining (non-date, non-key) fields
        for col in [f for f in metadata_fields if f not in date_fields and f != 'ts_code']:
            if col in df_processed.columns and col in existing_processed.columns:
                # Cast both sides to strings so comparisons are type-stable
                df_processed[col] = df_processed[col].astype(str)
                existing_processed[col] = existing_processed[col].astype(str)
                # Handle NaN values
                df_processed[col] = df_processed[col].fillna('').str.strip()
                existing_processed[col] = existing_processed[col].fillna('').str.strip()

        # Find brand-new records
        existing_ts_codes = set(existing_processed['ts_code'])
        new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()

        # Find records that need updating
        changed_records = []
        unchanged_count = 0

        # For every ts_code already in the table, check whether anything changed
        for ts_code in df_processed[df_processed['ts_code'].isin(existing_ts_codes)]['ts_code']:
            # Normalized new and old records
            new_record = df_processed[df_processed['ts_code'] == ts_code].iloc[0]
            old_record = existing_processed[existing_processed['ts_code'] == ts_code].iloc[0]
            # Original (un-normalized) record, used for the actual UPDATE
            original_record = df[df['ts_code'] == ts_code].iloc[0]

            has_change = False
            changed_fields = []
            for field in metadata_fields:
                if field == 'ts_code':  # Skip the primary key
                    continue
                new_val = new_record[field]
                old_val = old_record[field]
                if new_val != old_val:
                    has_change = True
                    changed_fields.append(field)
                    print(f"Change detected - {ts_code} {field}: '{old_val}' -> '{new_val}'")

            if has_change:
                changed_records.append({
                    'record': original_record,  # Keep the original formatting
                    'changed_fields': changed_fields
                })
            else:
                unchanged_count += 1

        # Insert new records
        new_count = 0
        if not new_records.empty:
            new_count = new_records.to_sql('stock_metadata', engine, index=False,
                                           if_exists='append', chunksize=1000)

        # Update changed records
        updated_count = 0
        for change_info in changed_records:
            record = change_info['record']
            fields = change_info['changed_fields']

            # Build an UPDATE statement covering only the changed fields
            fields_to_update = []
            params = {'ts_code': record['ts_code']}
            for field in fields:
                fields_to_update.append(f"{field} = :{field}")
                params[field] = record[field]

            if fields_to_update:
                update_stmt = text(f"""
                    UPDATE stock_metadata
                    SET {', '.join(fields_to_update)}
                    WHERE ts_code = :ts_code
                """)
                with engine.connect() as conn:
                    conn.execute(update_stmt, params)
                    conn.commit()
                updated_count += 1

        print("Metadata update summary:")
        print(f"  - new records: {new_count}")
        print(f"  - updated records: {updated_count}")
        print(f"  - unchanged records: {unchanged_count}")
        print(f"  - total processed: {new_count + updated_count + unchanged_count}")
        return new_count + updated_count


def download_stock_data(pro, ts_code, start_date, end_date):
    """Download all data types for one stock and merge them into one frame."""
    # Daily price data
    daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
    # Daily indicator data
    daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
    # Money flow data
    moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)

    if daily_df.empty or daily_basic_df.empty or moneyflow_df.empty:
        print(f"Warning: incomplete daily data for {ts_code}, cannot merge")
        return pd.DataFrame()

    # Every frame needs a trade_date column to merge on
    if 'trade_date' not in daily_df.columns:
        print(f"Error: daily data for {ts_code} is missing the trade_date column")
        return pd.DataFrame()

    # Keep all date columns as strings for the merge
    daily_df['trade_date'] = daily_df['trade_date'].astype(str)

    # Use merge (rather than join) so the columns stay under explicit control
    result_df = daily_df

    # Merge in the daily_basic data
    if not daily_basic_df.empty:
        daily_basic_df['trade_date'] = daily_basic_df['trade_date'].astype(str)
        # Identify overlapping columns (other than the merge keys)
        overlap_cols = list(set(result_df.columns) & set(daily_basic_df.columns)
                            - {'ts_code', 'trade_date'})
        # Drop the overlapping columns from daily_basic
        daily_basic_df_filtered = daily_basic_df.drop(columns=overlap_cols)
        result_df = pd.merge(
            result_df,
            daily_basic_df_filtered,
            on=['ts_code', 'trade_date'],
            how='left'
        )

    # Merge in the moneyflow data
    if not moneyflow_df.empty:
        moneyflow_df['trade_date'] = moneyflow_df['trade_date'].astype(str)
        # Identify overlapping columns (other than the merge keys)
        overlap_cols = list(set(result_df.columns) & set(moneyflow_df.columns)
                            - {'ts_code', 'trade_date'})
        # Drop the overlapping columns from moneyflow
        moneyflow_df_filtered = moneyflow_df.drop(columns=overlap_cols)
        result_df = pd.merge(
            result_df,
            moneyflow_df_filtered,
            on=['ts_code', 'trade_date'],
            how='left'
        )

    # Convert trade_date to datetime
    result_df['trade_date'] = pd.to_datetime(result_df['trade_date'])
    return result_df
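# A minimal sketch (illustrative only, with made-up column values) of the
# overlap-drop-then-merge step above: any column that appears in both frames,
# other than the join keys, is dropped from the right-hand frame, so the
# left-hand values always win. Not called anywhere in this module.
def _merge_demo():
    left = pd.DataFrame({'ts_code': ['000001.SZ'], 'trade_date': ['20240102'],
                         'close': [10.5]})
    right = pd.DataFrame({'ts_code': ['000001.SZ'], 'trade_date': ['20240102'],
                          'close': [99.9], 'turnover_rate': [1.2]})
    overlap = list(set(left.columns) & set(right.columns) - {'ts_code', 'trade_date'})
    merged = pd.merge(left, right.drop(columns=overlap),
                      on=['ts_code', 'trade_date'], how='left')
    # merged keeps close=10.5 from `left` and gains turnover_rate from `right`
    return merged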
def save_stock_data(df, engine, ts_code, if_exists='replace'):
    """
    Save stock data into the stock's own table.

    Args:
        df: DataFrame of stock data
        engine: database engine
        ts_code: stock code
        if_exists: what to do when the table exists; 'replace' drops and
            recreates it, 'append' adds to it

    Returns:
        int: number of records saved
    """
    if df.empty:
        print(f"Warning: no data to save for {ts_code}")
        return 0

    # Drop the ts_code column; the per-stock table does not need it
    if 'ts_code' in df.columns:
        df = df.drop(columns=['ts_code'])

    # Save using the requested mode
    try:
        symbol = ts_code.split('.')[0]
        result = df.to_sql(f'{symbol}', engine, index=False,
                           if_exists=if_exists, chunksize=1000)
        print(f"Saved {result} records for {ts_code} (mode: {if_exists})")
        return result
    except Exception as e:
        print(f"Error saving data for {ts_code}: {str(e)}")
        return 0


def split_list_into_chunks(data_list, num_chunks):
    """
    Split a list into the given number of chunks.

    Args:
        data_list: list to split
        num_chunks: number of chunks

    Returns:
        list: a list of sublists
    """
    avg = len(data_list) / float(num_chunks)
    chunks = []
    last = 0.0
    while last < len(data_list):
        chunks.append(data_list[int(last):int(last + avg)])
        last += avg
    return chunks
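# Quick sanity check of the chunking behaviour (illustrative only):
#   split_list_into_chunks(list(range(10)), 3)
#   -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
# Chunk sizes can differ by one because the boundaries are computed with
# floats, but every element lands in exactly one chunk.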
def process_stock_batch(batch, config, update_mode, start_date, end_date):
    """
    Worker that processes one batch of stocks (runs in its own process).

    Args:
        batch: list of stocks to process
        config: configuration dict
        update_mode: 'full' or 'incremental'
        start_date: start date (ignored in incremental mode, where each
            stock carries its own start date)
        end_date: end date

    Returns:
        dict: statistics about the processed batch
    """
    # Each process creates its own database connection and API client
    engine = create_engine_from_config(config)
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # Result counters
    results = {
        'success_count': 0,
        'failed_count': 0,
        'no_new_data_count': 0,
        'skipped_count': 0,
        'processed_count': 0
    }

    # Identify the current process in log output
    process_id = multiprocessing.current_process().name
    total_stocks = len(batch)

    for i, stock_info in enumerate(batch):
        # Timestamp for this update
        update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        # The batch items differ by update mode
        if update_mode == 'full':
            ts_code = stock_info
            try:
                # Keep the log terse: progress only
                print(f"[{process_id}] progress {i + 1}/{total_stocks}: full update {ts_code}")

                # Create the stock table if it does not exist yet
                symbol = ts_code.split('.')[0]
                with engine.connect() as conn:
                    table_exists_query = f"""
                        SELECT COUNT(*)
                        FROM information_schema.tables
                        WHERE table_schema = '{config['mysql']['database']}'
                        AND table_name = '{symbol}'
                    """
                    table_exists = conn.execute(text(table_exists_query)).scalar() > 0

                if not table_exists:
                    create_stock_table(engine, ts_code)

                # Download the stock data
                stock_data = download_stock_data(pro, ts_code, start_date, end_date)

                # On success, save the data and refresh the metadata
                if not stock_data.empty:
                    # Full update: replace the whole table
                    records_saved = save_stock_data(stock_data, engine, ts_code)

                    # Refresh the metadata statistics
                    update_metadata(engine, ts_code)

                    # Record the full-update timestamp and set status to OK (1)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                            UPDATE stock_metadata
                            SET last_full_update = '{update_time}', status = 1
                            WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['success_count'] += 1
                else:
                    # Mark the full update as failed (2)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                            UPDATE stock_metadata
                            SET status = 2
                            WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['failed_count'] += 1

                # Throttle to stay under the API rate limit
                time.sleep(0.5)

            except Exception as e:
                # On error, mark the full update as failed (2)
                try:
                    with engine.connect() as conn:
                        update_status_sql = f"""
                            UPDATE stock_metadata
                            SET status = 2
                            WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                except Exception:
                    pass  # Ignore failures while recording the failure itself
                results['failed_count'] += 1
                print(f"[{process_id}] full update of {ts_code} failed: {str(e)}")

        elif update_mode == 'incremental':
            ts_code = stock_info['ts_code']
            start_date = stock_info['start_date']
            try:
                # Keep the log terse: progress only
                print(f"[{process_id}] progress {i + 1}/{total_stocks}: incremental update {ts_code}")

                # The stock table must already exist
                symbol = ts_code.split('.')[0]
                with engine.connect() as conn:
                    table_exists_query = f"""
                        SELECT COUNT(*)
                        FROM information_schema.tables
                        WHERE table_schema = '{config['mysql']['database']}'
                        AND table_name = '{symbol}'
                    """
                    table_exists = conn.execute(text(table_exists_query)).scalar() > 0

                if not table_exists:
                    results['skipped_count'] += 1
                    continue  # Skip stocks without a table

                # Download the incremental data
                new_data = download_stock_data(pro, ts_code, start_date, end_date)

                if not new_data.empty:
                    # Load the existing data
                    try:
                        with engine.connect() as conn:
                            existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
                    except Exception:
                        results['skipped_count'] += 1
                        continue

                    if not existing_data.empty:
                        # Normalize the date columns so the frames can be merged
                        existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
                        new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])

                        # Drop rows whose dates are already present
                        existing_dates = set(existing_data['trade_date'])
                        new_data = new_data[~new_data['trade_date'].isin(existing_dates)]

                        if new_data.empty:
                            results['no_new_data_count'] += 1
                            continue

                        # Merge old and new data and sort by date
                        combined_data = pd.concat([existing_data, new_data], ignore_index=True)
                        combined_data = combined_data.sort_values('trade_date')

                        # Save the combined data
                        records_saved = save_stock_data(combined_data, engine, ts_code)
                    else:
                        # Table exists but is empty: just save the new data
                        records_saved = save_stock_data(new_data, engine, ts_code)

                    # Refresh the metadata statistics
                    update_metadata(engine, ts_code)

                    # Record the incremental-update timestamp and set status to OK (1)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                            UPDATE stock_metadata
                            SET last_incremental_update = '{update_time}', status = 1
                            WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['success_count'] += 1
                else:
                    # No new data
                    results['no_new_data_count'] += 1

                # Throttle to stay under the API rate limit
                time.sleep(0.5)

            except Exception as e:
                # On error, mark the incremental update as failed (3)
                try:
                    with engine.connect() as conn:
                        update_status_sql = f"""
                            UPDATE stock_metadata
                            SET status = 3
                            WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                except Exception:
                    pass  # Ignore failures while recording the failure itself
                results['failed_count'] += 1
                print(f"[{process_id}] incremental update of {ts_code} failed: {str(e)}")

        results['processed_count'] += 1

    return results
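# Design note: SQLAlchemy engines hold open sockets and are not safe to share
# across process boundaries, which is why process_stock_batch builds its own
# engine and Tushare client instead of receiving them from the parent. Only
# the plain-dict `config` crosses the multiprocessing boundary, and that
# pickles cleanly.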
def perform_full_update(start_year=2020, processes=4, resume=False):
    """
    Run a full update: all data from the given year through today.

    Args:
        start_year: first year to fetch, default 2020
        processes: number of worker processes, default 4
        resume: only update stocks whose previous full update failed
    """
    update_type = "resumed full update" if resume else "full update"
    print(f"Starting {update_type} (from {start_year} to today) with {processes} processes...")
    start_time = datetime.now()

    # Load configuration
    config = load_config()

    # Create the database engine
    engine = create_engine_from_config(config)

    # Set up the Tushare API
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # Date range for the download
    end_date = datetime.now().strftime('%Y%m%d')
    start_date = f"{start_year}0101"  # January 1st of the start year
    print(f"Date range: {start_date} to {end_date}")

    # Fetch the stock codes from the metadata table
    with engine.connect() as conn:
        # The resume flag decides which stocks to select
        if resume:
            # Only stocks whose full update previously failed
            query = """
                SELECT ts_code
                FROM stock_metadata
                WHERE list_status = 'L' AND status != 1
            """
        else:
            # All listed stocks
            query = """
                SELECT ts_code
                FROM stock_metadata
                WHERE list_status = 'L'
            """
        result = conn.execute(text(query))
        stock_codes = [row[0] for row in result]

    if not stock_codes:
        if resume:
            print("No stocks found to resume")
        else:
            print("No stocks found to update; initialize the metadata table first")
        return

    print(f"Found {len(stock_codes)} stocks for {update_type}")

    # Split the stock list into one batch per process
    batches = split_list_into_chunks(stock_codes, processes)

    # Run the batches in a process pool
    with Pool(processes=processes) as pool:
        # Fix the shared arguments with partial
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='full',
            start_date=start_date,
            end_date=end_date
        )
        # Run the workers and collect their results
        results = pool.map(process_func, batches)

    # Aggregate the per-process results
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)

    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # minutes

    print(f"\n{update_type} finished!")
    print(f"Total time: {duration:.2f} minutes")
    print(f"Updated successfully: {success_count} stocks")
    print(f"Failed: {failed_count} stocks")
    print(f"Total: {len(stock_codes)} stocks")


def update_metadata(engine, ts_code):
    """
    Refresh the statistics columns in stock_metadata straight from the
    stock's own data table.

    Args:
        engine: database engine
        ts_code: stock code with exchange suffix, e.g. 000001.SZ

    Returns:
        bool: True on success, False otherwise
    """
    try:
        # Symbol without the exchange suffix
        symbol = ts_code.split('.')[0]

        # Aggregate statistics over the stock's table
        stats_query = f"""
            SELECT
                MIN(trade_date) as min_date,
                MAX(trade_date) as max_date,
                COUNT(*) as record_count
            FROM `{symbol}`
        """

        # Values for the latest trading day
        latest_data_query = f"""
            SELECT close, pe_ttm, pb, total_mv
            FROM `{symbol}`
            WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
            LIMIT 1
        """

        with engine.connect() as conn:
            # Aggregate statistics
            stats_result = conn.execute(text(stats_query)).fetchone()

            if not stats_result:
                # Mark the stock as abnormal (2)
                update_empty_sql = f"""
                    UPDATE stock_metadata
                    SET last_full_update = NOW(),
                        record_count = 0,
                        status = 2
                    WHERE ts_code = '{ts_code}'
                """
                conn.execute(text(update_empty_sql))
                conn.commit()
                return False

            # Date columns
            try:
                data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
                data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
            except AttributeError:
                # The dates may come back as strings rather than datetimes
                data_start_date = stats_result[0] if stats_result[0] else 'NULL'
                data_end_date = stats_result[1] if stats_result[1] else 'NULL'

            record_count = stats_result[2] or 0

            # Latest trading-day values
            latest_data_result = conn.execute(text(latest_data_query)).fetchone()

            # Defaults, treating missing values as NULL
            default_value = 'NULL'
            latest_price = default_value
            latest_pe_ttm = default_value
            latest_pb = default_value
            latest_total_mv = default_value

            # Fill in whatever the latest row provides
            if latest_data_result:
                latest_price = str(latest_data_result[0]) if latest_data_result[0] is not None else default_value
                latest_pe_ttm = str(latest_data_result[1]) if latest_data_result[1] is not None else default_value
                latest_pb = str(latest_data_result[2]) if latest_data_result[2] is not None else default_value
                latest_total_mv = str(latest_data_result[3]) if latest_data_result[3] is not None else default_value

            # Build the UPDATE statement
            update_sql = f"""
                UPDATE stock_metadata
                SET
                    data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
                    data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
                    record_count = {record_count},
                    status = CASE WHEN {record_count} > 0 THEN 1 ELSE 2 END,
                    latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END,
                    latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
                    latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END,
                    latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END,
                    latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
                WHERE ts_code = '{ts_code}'
            """

            # Apply the update
            conn.execute(text(update_sql))
            conn.commit()
            return True

    except Exception as e:
        # Try to mark the metadata refresh as failed (-1)
        try:
            with engine.connect() as conn:
                error_update_sql = f"""
                    UPDATE stock_metadata
                    SET status = -1
                    WHERE ts_code = '{ts_code}'
                """
                conn.execute(text(error_update_sql))
                conn.commit()
        except Exception:
            pass  # Ignore failures while recording the failure itself
        print(f"Error updating metadata for {ts_code}: {str(e)}")
        return False
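# A minimal sketch of the start-date rule used below (hypothetical helper,
# not called anywhere in this module): the incremental window opens the day
# after the most recent of the two update timestamps, or 30 days back when
# neither exists.
def _example_next_start_date(last_incremental, last_full, today=None):
    today = today or datetime.now()
    candidates = [d for d in (last_incremental, last_full) if d]
    if candidates:
        return (max(candidates) + timedelta(days=1)).strftime('%Y%m%d')
    return (today - timedelta(days=30)).strftime('%Y%m%d')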
def perform_incremental_update(processes=4):
    """
    Run an incremental update: data from each stock's last incremental or
    full update through today.

    Args:
        processes: number of worker processes, default 4
    """
    print(f"Starting incremental update with {processes} processes...")
    start_time = datetime.now()

    # Load configuration
    config = load_config()

    # Create the database engine
    engine = create_engine_from_config(config)

    # Set up the Tushare API
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # Today is the end date
    end_date = datetime.now().strftime('%Y%m%d')

    # Collect the stocks to update, each with its own start date
    stocks_to_update = []
    with engine.connect() as conn:
        query = """
            SELECT ts_code, last_incremental_update, last_full_update
            FROM stock_metadata
            WHERE list_status = 'L'
        """
        result = conn.execute(text(query))
        for row in result:
            ts_code = row[0]
            last_incr_update = row[1]  # last incremental update
            last_full_update = row[2]  # last full update

            # Start from the most recent of the two update dates
            latest_update = None
            if last_incr_update:
                latest_update = last_incr_update
            if last_full_update:
                if not latest_update or last_full_update > latest_update:
                    latest_update = last_full_update

            # Start the day after the last update
            if latest_update:
                start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
            else:
                # Without any update record, default to the last 30 days
                start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')

            # Skip the stock when its window is empty
            if start_date >= end_date:
                continue

            stocks_to_update.append({
                'ts_code': ts_code,
                'start_date': start_date
            })

    if not stocks_to_update:
        print("No stocks need an incremental update")
        return

    print(f"Found {len(stocks_to_update)} stocks for incremental update")

    # Split the stock list into one batch per process
    batches = split_list_into_chunks(stocks_to_update, processes)

    # Run the batches in a process pool
    with Pool(processes=processes) as pool:
        # Fix the shared arguments with partial; each worker stamps its own
        # update time, so no timestamp is passed in here
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='incremental',
            start_date=None,  # each stock carries its own start date
            end_date=end_date
        )
        # Run the workers and collect their results
        results = pool.map(process_func, batches)

    # Aggregate the per-process results
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)
    no_new_data_count = sum(r['no_new_data_count'] for r in results)
    skipped_count = sum(r['skipped_count'] for r in results)

    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # minutes

    print("\nIncremental update finished!")
    print(f"Total time: {duration:.2f} minutes")
    print(f"Updated successfully: {success_count} stocks")
    print(f"No new data: {no_new_data_count} stocks")
    print(f"Skipped: {skipped_count} stocks")
    print(f"Failed: {failed_count} stocks")
    print(f"Total: {len(stocks_to_update)} stocks")


def create_calendar_table(engine):
    """Create the trading calendar table."""
    # Read the CREATE TABLE statement from the SQL file
    with open('sql/calendar.sql', 'r', encoding='utf-8') as f:
        create_table_sql = f.read()

    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print("Trading calendar table created successfully")


def download_calendar_data(pro, start_date=None, end_date=None):
    """
    Download trading calendar data.

    Args:
        pro: Tushare API client
        start_date: start date, YYYYMMDD; defaults to the start of 2000
        end_date: end date, YYYYMMDD; defaults to one year from today

    Returns:
        pandas.DataFrame: trading calendar data
    """
    # Default range: from 2000 through one year from now
    if start_date is None:
        start_date = '20000101'
    if end_date is None:
        # Fetch one year ahead so there is always calendar to spare
        current_date = datetime.now()
        next_year = current_date + timedelta(days=365)
        end_date = next_year.strftime('%Y%m%d')

    print(f"Downloading trading calendar from {start_date} to {end_date}...")

    try:
        # Fetch the calendar through Tushare's trade_cal endpoint
        df = pro.trade_cal(exchange='SSE', start_date=start_date, end_date=end_date)

        # Rename columns to match the database schema
        df = df.rename(columns={
            'cal_date': 'trade_date',
            'exchange': 'exchange',
            'is_open': 'is_open',
            'pretrade_date': 'pretrade_date'
        })

        # Convert dates from YYYYMMDD to YYYY-MM-DD
        df['trade_date'] = pd.to_datetime(df['trade_date']).dt.strftime('%Y-%m-%d')
        if 'pretrade_date' in df.columns:
            df['pretrade_date'] = pd.to_datetime(df['pretrade_date']).dt.strftime('%Y-%m-%d')

        # Stamp creation and update times
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        df['create_time'] = current_time
        df['update_time'] = current_time

        print(f"Downloaded {len(df)} trading calendar records")
        return df
    except Exception as e:
        print(f"Error downloading trading calendar: {e}")
        return pd.DataFrame()
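# Note on the trade_cal payload (as commonly returned by Tushare; treat the
# details as an assumption): `is_open` is 1 for trading days and 0 otherwise,
# and `pretrade_date` is the previous trading day. Only the SSE (Shanghai)
# calendar is requested above, which the A-share venues share in practice.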
def save_calendar_data(calendar_df, engine):
    """
    Save trading calendar data to the database.

    Args:
        calendar_df: DataFrame with trading calendar data
        engine: SQLAlchemy engine
    """
    if calendar_df.empty:
        print("No trading calendar data to save")
        return

    try:
        # Write to the database, replacing the table if it exists
        calendar_df.to_sql('calendar', engine, if_exists='replace', index=False)
        print(f"Saved {len(calendar_df)} trading calendar records")
    except Exception as e:
        print(f"Error saving trading calendar: {e}")


def update_calendar(engine, pro):
    """
    Update the trading calendar.

    Args:
        engine: SQLAlchemy engine
        pro: Tushare API client
    """
    try:
        # Does the calendar table exist?
        with engine.connect() as conn:
            result = conn.execute(text("SHOW TABLES LIKE 'calendar'"))
            table_exists = result.fetchone() is not None

        if not table_exists:
            # No table yet: create it and download the full history
            create_calendar_table(engine)
            calendar_df = download_calendar_data(pro)
            save_calendar_data(calendar_df, engine)
        else:
            # Table exists: find the most recent date in it
            with engine.connect() as conn:
                result = conn.execute(text("SELECT MAX(trade_date) FROM calendar"))
                latest_date = result.fetchone()[0]

            if latest_date:
                # Convert the date to YYYYMMDD and start the day after
                start_date = (datetime.strptime(str(latest_date), '%Y-%m-%d')
                              + timedelta(days=1)).strftime('%Y%m%d')
                # Download the new slice
                calendar_df = download_calendar_data(pro, start_date=start_date)
                if not calendar_df.empty:
                    # Append the new rows
                    calendar_df.to_sql('calendar', engine, if_exists='append', index=False)
                    print(f"Added {len(calendar_df)} new trading calendar records")
                else:
                    print("Trading calendar is already up to date")
            else:
                # Table exists but is empty: download the full history
                calendar_df = download_calendar_data(pro)
                save_calendar_data(calendar_df, engine)
    except Exception as e:
        print(f"Error updating trading calendar: {e}")


def main():
    parser = argparse.ArgumentParser(description='Stock data downloader')
    parser.add_argument('--init', action='store_true', help='initialize metadata')
    parser.add_argument('--mode', choices=['full', 'incremental', 'both'], default='full',
                        help='update mode: full, incremental, or both')
    parser.add_argument('--year', type=int, default=2020,
                        help='start year (for full updates)')
    parser.add_argument('--processes', type=int, default=4,
                        help='number of parallel processes')
    parser.add_argument('--resume', action='store_true',
                        help='resume a previously interrupted full update')
    parser.add_argument('--update-calendar', action='store_true',
                        help='update the trading calendar')
    args = parser.parse_args()

    # Load configuration
    config = load_config()

    # Create the database engine
    engine = create_engine_from_config(config)

    # Set up the Tushare API
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # Initialize the metadata if requested
    if args.init:
        # Create the metadata table
        create_metadata_table(engine)
        # Download and save the stock metadata
        metadata_df = download_stock_metadata(pro)
        save_metadata(metadata_df, engine)
        print("Metadata initialization finished")
        # Initialize the trading calendar
        update_calendar(engine, pro)
        print("Trading calendar initialization finished")

    # Update the trading calendar if requested
    if args.update_calendar:
        update_calendar(engine, pro)
        print("Trading calendar update finished")

    # Run the selected update mode(s)
    if args.mode == 'full' or args.mode == 'both':
        perform_full_update(args.year, args.processes, args.resume)
    if args.mode == 'incremental' or args.mode == 'both':
        perform_incremental_update(args.processes)


if __name__ == '__main__':
    main()
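# Typical invocations (assuming this file is saved as, say, downloader.py;
# note that --mode defaults to 'full', so --init also triggers a full update):
#
#   python downloader.py --init                     # create tables, load metadata and calendar
#   python downloader.py --mode full --year 2018    # full history from 2018, 4 processes
#   python downloader.py --mode full --resume       # retry only previously failed stocks
#   python downloader.py --mode incremental --processes 8
#   python downloader.py --update-calendar --mode incremental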