diff --git a/data_downloader.py b/data_downloader.py index 08e702c..43c46b6 100644 --- a/data_downloader.py +++ b/data_downloader.py @@ -281,8 +281,19 @@ def download_stock_data(pro, ts_code, start_date, end_date): return result_df -def save_stock_data(df, engine, ts_code): - """保存股票数据到对应的表中""" +def save_stock_data(df, engine, ts_code, if_exists='replace'): + """ + 保存股票数据到对应的表中 + + Args: + df: 股票数据DataFrame + engine: 数据库引擎 + ts_code: 股票代码 + if_exists: 如果表已存在,处理方式。'replace': 替换现有表,'append': 附加到现有表 + + Returns: + int: 保存的记录数量 + """ if df.empty: print(f"警告:{ts_code}没有数据可保存") return 0 @@ -291,12 +302,12 @@ def save_stock_data(df, engine, ts_code): if 'ts_code' in df.columns: df = df.drop(columns=['ts_code']) - # 使用replace模式,确保表中只包含最新数据 + # 使用指定的模式保存数据 try: symbol = ts_code.split('.')[0] result = df.to_sql(f'{symbol}', engine, index=False, - if_exists='replace', chunksize=1000) - print(f"成功保存{result}条{ts_code}股票数据") + if_exists=if_exists, chunksize=1000) + print(f"成功保存{result}条{ts_code}股票数据 (模式: {if_exists})") return result except Exception as e: print(f"保存{ts_code}数据时出错: {str(e)}") @@ -304,74 +315,129 @@ def save_stock_data(df, engine, ts_code): def update_metadata(engine, ts_code): - """更新股票元数据中的统计信息,从股票表中直接获取数据""" - # 提取股票代码(不含后缀) - symbol = ts_code.split('.')[0] - - # 查询股票数据统计信息 - stats_query = f""" - SELECT - MIN(trade_date) as min_date, - MAX(trade_date) as max_date, - COUNT(*) as record_count - FROM `{symbol}` """ + 更新股票元数据中的统计信息,从股票表中直接获取数据 - # 查询最新交易日数据 - latest_data_query = f""" - SELECT - close, - pe_ttm, - pb, - total_mv - FROM `{symbol}` - WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`) - LIMIT 1 + Args: + engine: 数据库连接引擎 + ts_code: 股票代码(带后缀,如000001.SZ) + + Returns: + bool: 更新成功返回True,否则返回False """ + try: + # 提取股票代码(不含后缀) + symbol = ts_code.split('.')[0] - # 执行查询 - with engine.connect() as conn: - # 获取统计信息 - stats_result = conn.execute(text(stats_query)).fetchone() - if not stats_result: - print(f"警告:{ts_code}没有数据,无法更新元数据") - return + # 查询股票数据统计信息 + stats_query = f""" + SELECT + MIN(trade_date) as min_date, + MAX(trade_date) as max_date, + COUNT(*) as record_count + FROM `{symbol}` + """ - data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL' - data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL' - record_count = stats_result[2] + # 查询最新交易日数据 + latest_data_query = f""" + SELECT + close, + pe_ttm, + pb, + total_mv + FROM `{symbol}` + WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`) + LIMIT 1 + """ - # 获取最新交易日数据 - latest_data_result = conn.execute(text(latest_data_query)).fetchone() - latest_price = f"{latest_data_result[0]}" if latest_data_result and latest_data_result[ - 0] is not None else 'NULL' - latest_pe_ttm = f"{latest_data_result[1]}" if latest_data_result and latest_data_result[ - 1] is not None else 'NULL' - latest_pb = f"{latest_data_result[2]}" if latest_data_result and latest_data_result[2] is not None else 'NULL' - latest_total_mv = f"{latest_data_result[3]}" if latest_data_result and latest_data_result[ - 3] is not None else 'NULL' + # 执行查询 + with engine.connect() as conn: + # 获取统计信息 + stats_result = conn.execute(text(stats_query)).fetchone() + if not stats_result: + print(f"警告:{ts_code}没有数据,将更新元数据状态为异常") + # 更新状态为异常 + update_empty_sql = f""" + UPDATE stock_metadata + SET last_full_update = NOW(), + record_count = 0, + status = 2 + WHERE ts_code = '{ts_code}' + """ + conn.execute(text(update_empty_sql)) + conn.commit() + return False - # 加载更新SQL模板 - with open('sql/update_metadata.sql', 'r', encoding='utf-8') as f: - update_sql_template = f.read() + # 处理日期字段 + try: + data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL' + data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL' + except AttributeError: + # 处理日期可能是字符串而不是datetime对象的情况 + data_start_date = stats_result[0] if stats_result[0] else 'NULL' + data_end_date = stats_result[1] if stats_result[1] else 'NULL' - # 填充模板 - update_sql = update_sql_template.format( - data_start_date=data_start_date, - data_end_date=data_end_date, - record_count=record_count, - latest_price=latest_price, - latest_pe_ttm=latest_pe_ttm, - latest_pb=latest_pb, - latest_total_mv=latest_total_mv, - ts_code=ts_code - ) + record_count = stats_result[2] or 0 - # 执行更新 - conn.execute(text(update_sql)) - conn.commit() + # 获取最新交易日数据 + latest_data_result = conn.execute(text(latest_data_query)).fetchone() - print(f"已更新({ts_code})的元数据信息") + # 设置默认值并处理NULL + default_value = 'NULL' + latest_price = default_value + latest_pe_ttm = default_value + latest_pb = default_value + latest_total_mv = default_value + + # 如果有最新数据,则更新相应字段 + if latest_data_result: + latest_price = str(latest_data_result[0]) if latest_data_result[0] is not None else default_value + latest_pe_ttm = str(latest_data_result[1]) if latest_data_result[1] is not None else default_value + latest_pb = str(latest_data_result[2]) if latest_data_result[2] is not None else default_value + latest_total_mv = str(latest_data_result[3]) if latest_data_result[3] is not None else default_value + + # 构建更新SQL + update_sql = f""" + UPDATE stock_metadata + SET + last_full_update = NOW(), + data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END, + data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END, + record_count = {record_count}, + status = CASE WHEN {record_count} > 0 THEN 1 ELSE 2 END, + latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END, + latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END, + latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END, + latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END, + latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END + WHERE ts_code = '{ts_code}' + """ + + # 执行更新 + conn.execute(text(update_sql)) + conn.commit() + + print(f"已更新({ts_code})的元数据信息") + return True + + except Exception as e: + print(f"更新({ts_code})元数据时出错: {str(e)}") + + # 尝试将状态更新为异常 + try: + with engine.connect() as conn: + error_update_sql = f""" + UPDATE stock_metadata + SET status = 2, + last_full_update = NOW() + WHERE ts_code = '{ts_code}' + """ + conn.execute(text(error_update_sql)) + conn.commit() + except Exception as inner_e: + print(f"更新({ts_code})状态为异常时出错: {str(inner_e)}") + + return False def process_stock(engine, pro, ts_code, start_date, end_date): @@ -391,7 +457,16 @@ def process_stock(engine, pro, ts_code, start_date, end_date): print(f"警告:({ts_code})没有获取到数据") -def main(): +def perform_full_update(start_year=2020): + """ + 执行全量更新:从指定年份开始到今天的所有数据 + + Args: + start_year: 开始年份,默认为2020年 + """ + print(f"开始执行全量更新 (从{start_year}年至今)...") + start_time = datetime.now() + # 加载配置 config = load_config() @@ -402,22 +477,328 @@ def main(): ts.set_token(config['tushare_token']) pro = ts.pro_api() - # 创建元数据表 - create_metadata_table(engine) - - # 下载并保存股票元数据 - metadata_df = download_stock_metadata(pro) - save_metadata(metadata_df, engine) - - # 示例:处理000001股票 - ts_code = '000001.SZ' # 平安银行的ts_code - - # 设置数据获取的日期范围(示例:过去5年的数据) + # 设置数据获取的日期范围 end_date = datetime.now().strftime('%Y%m%d') - start_date = (datetime.now() - timedelta(days=15 * 365)).strftime('%Y%m%d') + start_date = f"{start_year}0101" # 从指定年份的1月1日开始 - # 处理该股票 - process_stock(engine, pro, ts_code, start_date, end_date) + print(f"数据范围: {start_date} 至 {end_date}") + + # 从元数据表获取所有股票代码 + with engine.connect() as conn: + # 查询所有需要更新的股票代码 + query = """ + SELECT ts_code + FROM stock_metadata + WHERE list_status = 'L' -- 只选择已上市的股票 + """ + result = conn.execute(text(query)) + stock_codes = [row[0] for row in result] + + if not stock_codes: + print("没有找到需要更新的股票,请先确保元数据表已初始化") + return + + print(f"共找到 {len(stock_codes)} 只股票需要更新") + + # 记录成功和失败的股票数量 + success_count = 0 + failed_count = 0 + skipped_count = 0 + + # 当前更新时间 + update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + # 逐一处理每只股票 + total_stocks = len(stock_codes) + for index, ts_code in enumerate(stock_codes): + try: + print(f"[{index + 1}/{total_stocks}] 正在全量更新股票: {ts_code}") + + # 检查股票表是否存在,不存在则创建 + symbol = ts_code.split('.')[0] + with engine.connect() as conn: + table_exists_query = f""" + SELECT COUNT(*) + FROM information_schema.tables + WHERE table_schema = '{config['mysql']['database']}' + AND table_name = '{symbol}' + """ + table_exists = conn.execute(text(table_exists_query)).scalar() > 0 + + if not table_exists: + create_stock_table(engine, ts_code) + + # 下载股票数据 + stock_data = download_stock_data(pro, ts_code, start_date, end_date) + + # 如果下载成功,保存数据并更新元数据 + if not stock_data.empty: + # 使用replace模式保存数据(全量替换) + records_saved = save_stock_data(stock_data, engine, ts_code) + + # 更新元数据统计信息 + update_metadata(engine, ts_code) + + # 只更新全量更新时间戳 + with engine.connect() as conn: + update_status_sql = f""" + UPDATE stock_metadata + SET last_full_update = '{update_time}' + WHERE ts_code = '{ts_code}' + """ + conn.execute(text(update_status_sql)) + conn.commit() + + success_count += 1 + print(f"成功更新 {ts_code},保存了 {records_saved} 条记录") + else: + failed_count += 1 + print(f"警告:{ts_code} 没有获取到数据") + + # 防止API限流,每次请求后短暂休息 + import time + time.sleep(0.5) + + except Exception as e: + failed_count += 1 + print(f"处理股票 {ts_code} 时出错: {str(e)}") + + end_time = datetime.now() + duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟 + + print("\n全量更新完成!") + print(f"总耗时: {duration:.2f} 分钟") + print(f"成功更新: {success_count} 只股票") + print(f"更新失败: {failed_count} 只股票") + print(f"总计: {total_stocks} 只股票") + + +def perform_incremental_update(): + """ + 执行增量更新:从上次增量或全量更新日期到今天的数据 + - 使用最近的更新日期(增量或全量)作为起点 + - 如果表不存在,跳过该股票 + - 只更新增量更新时间字段 + """ + print("开始执行增量更新...") + start_time = datetime.now() + + # 加载配置 + config = load_config() + + # 创建数据库引擎 + engine = create_engine_from_config(config) + + # 设置Tushare API + ts.set_token(config['tushare_token']) + pro = ts.pro_api() + + # 当前日期作为结束日期 + end_date = datetime.now().strftime('%Y%m%d') + + # 获取需要更新的股票及其上次更新日期 + stocks_to_update = [] + with engine.connect() as conn: + query = """ + SELECT + ts_code, + last_incremental_update, + last_full_update + FROM + stock_metadata + WHERE + list_status = 'L' AND + status != 2 -- 排除状态异常的股票 + """ + result = conn.execute(text(query)) + + for row in result: + ts_code = row[0] + last_incr_update = row[1] # 上次增量更新日期 + last_full_update = row[2] # 上次全量更新日期 + + # 确定起始日期:使用最近的更新日期 + latest_update = None + + # 检查增量更新日期 + if last_incr_update: + latest_update = last_incr_update + + # 检查全量更新日期 + if last_full_update: + if not latest_update or last_full_update > latest_update: + latest_update = last_full_update + + # 如果有更新日期,使用其后一天作为起始日期 + if latest_update: + start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d') + else: + # 如果没有任何更新记录,默认从30天前开始 + start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') + + # 如果起始日期大于等于结束日期,则跳过此股票 + if start_date >= end_date: + print(f"股票 {ts_code} 数据已是最新,无需更新") + continue + + stocks_to_update.append({ + 'ts_code': ts_code, + 'start_date': start_date + }) + + if not stocks_to_update: + print("没有找到需要增量更新的股票") + return + + print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新") + + # 记录成功和失败的股票数量 + success_count = 0 + failed_count = 0 + no_new_data_count = 0 + skipped_count = 0 + + # 当前更新时间 + update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + + # 逐一处理每只股票 + total_stocks = len(stocks_to_update) + for index, stock in enumerate(stocks_to_update): + ts_code = stock['ts_code'] + start_date = stock['start_date'] + + try: + print(f"[{index + 1}/{total_stocks}] 正在增量更新股票: {ts_code} (从 {start_date} 到 {end_date})") + + # 确保股票表存在 + symbol = ts_code.split('.')[0] + with engine.connect() as conn: + table_exists_query = f""" + SELECT COUNT(*) + FROM information_schema.tables + WHERE table_schema = '{config['mysql']['database']}' + AND table_name = '{symbol}' + """ + table_exists = conn.execute(text(table_exists_query)).scalar() > 0 + + if not table_exists: + print(f"股票 {ts_code} 数据表不存在,跳过此股票") + skipped_count += 1 + continue # 表不存在则跳过 + + # 下载增量数据 + new_data = download_stock_data(pro, ts_code, start_date, end_date) + + if not new_data.empty: + # 获取现有数据 + try: + with engine.connect() as conn: + existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn) + except Exception as e: + print(f"读取现有数据失败: {str(e)},跳过此股票") + skipped_count += 1 + continue + + if not existing_data.empty: + # 转换日期列为相同的格式以便合并 + existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date']) + new_data['trade_date'] = pd.to_datetime(new_data['trade_date']) + + # 删除可能重复的日期记录 + existing_dates = set(existing_data['trade_date']) + new_data = new_data[~new_data['trade_date'].isin(existing_dates)] + + if new_data.empty: + print(f"股票 {ts_code} 没有新数据需要更新") + no_new_data_count += 1 + continue + + # 合并数据 + combined_data = pd.concat([existing_data, new_data], ignore_index=True) + + # 按日期排序 + combined_data = combined_data.sort_values('trade_date') + + # 保存合并后的数据 + records_saved = save_stock_data(combined_data, engine, ts_code) + else: + # 如果表存在但为空,直接保存新数据 + records_saved = save_stock_data(new_data, engine, ts_code) + + # 更新元数据统计信息 + update_metadata(engine, ts_code) + + # 只更新增量更新时间戳 + with engine.connect() as conn: + update_status_sql = f""" + UPDATE stock_metadata + SET last_incremental_update = '{update_time}' + WHERE ts_code = '{ts_code}' + """ + conn.execute(text(update_status_sql)) + conn.commit() + + success_count += 1 + print(f"成功增量更新 {ts_code},新增 {len(new_data)} 条记录") + else: + print(f"股票 {ts_code} 在指定时间范围内没有新数据") + no_new_data_count += 1 + + # 防止API限流,每次请求后短暂休息 + import time + time.sleep(0.5) + + except Exception as e: + failed_count += 1 + print(f"增量更新股票 {ts_code} 时出错: {str(e)}") + + end_time = datetime.now() + duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟 + + print("\n增量更新完成!") + print(f"总耗时: {duration:.2f} 分钟") + print(f"成功更新: {success_count} 只股票") + print(f"无新数据: {no_new_data_count} 只股票") + print(f"跳过股票: {skipped_count} 只股票") + print(f"更新失败: {failed_count} 只股票") + print(f"总计: {total_stocks} 只股票") + + +def main(): + """主函数,允许用户选择更新方式""" + import argparse + + # 命令行参数解析 + parser = argparse.ArgumentParser(description='股票数据更新工具') + parser.add_argument('--init', action='store_true', help='初始化元数据表') + parser.add_argument('--mode', choices=['full', 'incremental', 'both'], + default='full', help='更新模式: full=全量更新, incremental=增量更新, both=两者都执行') + parser.add_argument('--year', type=int, default=2020, help='全量更新的起始年份') + args = parser.parse_args() + + # 如果需要初始化元数据 + if args.init: + # 加载配置 + config = load_config() + # 创建数据库引擎 + engine = create_engine_from_config(config) + # 设置Tushare API + ts.set_token(config['tushare_token']) + pro = ts.pro_api() + + # 创建元数据表 + create_metadata_table(engine) + # 下载并保存股票元数据 + metadata_df = download_stock_metadata(pro) + save_metadata(metadata_df, engine) + print("元数据初始化完成") + + # 根据选择的模式执行更新 + if args.mode == 'full' or args.mode == 'both': + perform_full_update(args.year) + + if args.mode == 'incremental' or args.mode == 'both': + perform_incremental_update() if __name__ == '__main__': diff --git a/sql/update_metadata.sql b/sql/update_metadata.sql index 4cde69c..96a453a 100644 --- a/sql/update_metadata.sql +++ b/sql/update_metadata.sql @@ -2,13 +2,16 @@ UPDATE stock_metadata SET last_full_update = NOW(), - data_start_date = '{data_start_date}', - data_end_date = '{data_end_date}', + data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END, + data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END, record_count = {record_count}, - status = 1, - latest_price = {latest_price}, - latest_date = '{data_end_date}', - latest_pe_ttm = {latest_pe_ttm}, - latest_pb = {latest_pb}, - latest_total_mv = {latest_total_mv} + status = CASE + WHEN {record_count} > 0 THEN 1 -- 有数据,状态为正常 + ELSE 2 -- 无数据,状态为异常 + END, + latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END, + latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END, + latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END, + latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END, + latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END WHERE ts_code = '{ts_code}' \ No newline at end of file