From 77d990794bb3b41080ebfc1d4cf16fc934bd447b Mon Sep 17 00:00:00 2001 From: Qihang Zhang Date: Sat, 5 Apr 2025 22:39:24 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20feat(data=5Fdownloader)?= =?UTF-8?q?:=20=E6=B7=BB=E5=8A=A0=E4=BA=A4=E6=98=93=E6=97=A5=E5=8E=86?= =?UTF-8?q?=E8=A1=A8=E5=88=9B=E5=BB=BA=E3=80=81=E6=95=B0=E6=8D=AE=E4=B8=8B?= =?UTF-8?q?=E8=BD=BD=E3=80=81=E4=BF=9D=E5=AD=98=E4=B8=8E=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=20=E2=9E=95=20feat(sql):=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9E=E4=BA=A4=E6=98=93=E6=97=A5=E5=8E=86=E8=A1=A8=E7=9A=84?= =?UTF-8?q?SQL=E5=88=9B=E5=BB=BA=E8=AF=AD=E5=8F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data_downloader.py | 182 ++++++++++++++++++++++++++++++++++++++++----- sql/calendar.sql | 12 +++ 2 files changed, 175 insertions(+), 19 deletions(-) create mode 100644 sql/calendar.sql diff --git a/data_downloader.py b/data_downloader.py index 9f1777c..58bbecf 100644 --- a/data_downloader.py +++ b/data_downloader.py @@ -1,3 +1,5 @@ +import argparse + import pandas as pd import tushare as ts import numpy as np @@ -871,31 +873,164 @@ def perform_incremental_update(processes=4): print(f"总计: {len(stocks_to_update)} 只股票") -def main(): - """主函数,允许用户选择更新方式""" - import argparse +# 添加创建日历表的函数 +def create_calendar_table(engine): + """创建交易日历表""" + # 从sql文件读取创建表的SQL + with open('sql/calendar.sql', 'r', encoding='utf-8') as f: + create_table_sql = f.read() - # 命令行参数解析 - parser = argparse.ArgumentParser(description='股票数据更新工具') - parser.add_argument('--init', action='store_true', help='初始化元数据表') - parser.add_argument('--mode', choices=['full', 'incremental', 'both'], - default='full', help='更新模式: full=全量更新, incremental=增量更新, both=两者都执行') - parser.add_argument('--year', type=int, default=2020, help='全量更新的起始年份') - parser.add_argument('--processes', type=int, default=8, help='使用的进程数量') - parser.add_argument('--resume', action='store_true', help='仅更新全量更新失败的股票,即status=2部分') + with engine.connect() as conn: + conn.execute(text(create_table_sql)) + conn.commit() + print("交易日历表创建成功") + + +# 添加下载日历数据的函数 +def download_calendar_data(pro, start_date=None, end_date=None): + """ + 下载交易日历数据 + + Args: + pro: Tushare API对象 + start_date: 开始日期,格式YYYYMMDD,默认为2000年初 + end_date: 结束日期,格式YYYYMMDD,默认为当前日期之后的一年 + + Returns: + pandas.DataFrame: 交易日历数据 + """ + # 默认获取从2000年至今后一年的日历数据 + if start_date is None: + start_date = '20000101' + + if end_date is None: + # 获取当前日期后一年的数据,以便预先有足够的交易日历 + current_date = datetime.now() + next_year = current_date + timedelta(days=365) + end_date = next_year.strftime('%Y%m%d') + + print(f"正在下载从 {start_date} 到 {end_date} 的交易日历数据...") + + try: + # 调用Tushare的trade_cal接口获取交易日历 + df = pro.trade_cal(exchange='SSE', start_date=start_date, end_date=end_date) + + # 重命名列以符合数据库表结构 + df = df.rename(columns={ + 'cal_date': 'trade_date', + 'exchange': 'exchange', + 'is_open': 'is_open', + 'pretrade_date': 'pretrade_date' + }) + + # 将日期格式从YYYYMMDD转换为YYYY-MM-DD + df['trade_date'] = pd.to_datetime(df['trade_date']).dt.strftime('%Y-%m-%d') + if 'pretrade_date' in df.columns: + df['pretrade_date'] = pd.to_datetime(df['pretrade_date']).dt.strftime('%Y-%m-%d') + + # 添加创建和更新时间 + current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + df['create_time'] = current_time + df['update_time'] = current_time + + print(f"成功下载 {len(df)} 条交易日历记录") + return df + + except Exception as e: + print(f"下载交易日历数据时出错: {e}") + return pd.DataFrame() + + +# 添加保存日历数据的函数 +def save_calendar_data(calendar_df, engine): + """ + 将交易日历数据保存到数据库 + + Args: + calendar_df: 包含交易日历数据的DataFrame + engine: SQLAlchemy引擎 + """ + if calendar_df.empty: + print("没有交易日历数据可保存") + return + + try: + # 使用to_sql将数据写入数据库,如果表存在则替换 + calendar_df.to_sql('calendar', engine, if_exists='replace', index=False) + print(f"成功保存 {len(calendar_df)} 条交易日历数据到数据库") + except Exception as e: + print(f"保存交易日历数据时出错: {e}") + + +# 添加交易日历更新函数 +def update_calendar(engine, pro): + """ + 更新交易日历数据 + + Args: + engine: SQLAlchemy引擎 + pro: Tushare API对象 + """ + try: + # 检查calendar表是否存在 + with engine.connect() as conn: + result = conn.execute(text("SHOW TABLES LIKE 'calendar'")) + table_exists = result.fetchone() is not None + + if not table_exists: + # 如果表不存在,创建表并下载所有历史数据 + create_calendar_table(engine) + calendar_df = download_calendar_data(pro) + save_calendar_data(calendar_df, engine) + else: + # 如果表存在,查询最新的交易日期 + with engine.connect() as conn: + result = conn.execute(text("SELECT MAX(trade_date) FROM calendar")) + latest_date = result.fetchone()[0] + + if latest_date: + # 将日期转换为YYYYMMDD格式 + start_date = (datetime.strptime(str(latest_date), '%Y-%m-%d') + timedelta(days=1)).strftime('%Y%m%d') + # 下载新数据 + calendar_df = download_calendar_data(pro, start_date=start_date) + + if not calendar_df.empty: + # 将新数据追加到表中 + calendar_df.to_sql('calendar', engine, if_exists='append', index=False) + print(f"成功更新 {len(calendar_df)} 条新的交易日历数据") + else: + print("没有新的交易日历数据需要更新") + else: + # 如果表为空,下载所有历史数据 + calendar_df = download_calendar_data(pro) + save_calendar_data(calendar_df, engine) + + except Exception as e: + print(f"更新交易日历时出错: {e}") + + +def main(): + parser = argparse.ArgumentParser(description='股票数据下载器') + parser.add_argument('--init', action='store_true', help='初始化元数据') + parser.add_argument('--mode', choices=['full', 'incremental', 'both'], default='full', + help='更新模式: full(全量), incremental(增量), both(两者都)') + parser.add_argument('--year', type=int, default=2020, help='起始年份(用于全量更新)') + parser.add_argument('--processes', type=int, default=4, help='并行进程数') + parser.add_argument('--resume', action='store_true', help='是否从中断处继续') + parser.add_argument('--update-calendar', action='store_true', help='更新交易日历数据') args = parser.parse_args() + # 加载配置 + config = load_config() + # 创建数据库引擎 + engine = create_engine_from_config(config) + # 设置Tushare API + ts.set_token(config['tushare_token']) + pro = ts.pro_api() + # 如果需要初始化元数据 if args.init: - # 加载配置 - config = load_config() - # 创建数据库引擎 - engine = create_engine_from_config(config) - # 设置Tushare API - ts.set_token(config['tushare_token']) - pro = ts.pro_api() - # 创建元数据表 create_metadata_table(engine) # 下载并保存股票元数据 @@ -903,6 +1038,15 @@ def main(): save_metadata(metadata_df, engine) print("元数据初始化完成") + # 初始化交易日历数据 + update_calendar(engine, pro) + print("交易日历初始化完成") + + # 如果需要更新交易日历 + if args.update_calendar: + update_calendar(engine, pro) + print("交易日历更新完成") + # 根据选择的模式执行更新 if args.mode == 'full' or args.mode == 'both': perform_full_update(args.year, args.processes, args.resume) diff --git a/sql/calendar.sql b/sql/calendar.sql new file mode 100644 index 0000000..0147ba3 --- /dev/null +++ b/sql/calendar.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS `calendar` ( + trade_date DATE PRIMARY KEY COMMENT '交易日期', + exchange VARCHAR(10) COMMENT '交易所代码', + is_open TINYINT COMMENT '是否交易(0:休市 1:交易)', + pretrade_date DATE COMMENT '上一个交易日', + + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', + update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', + + INDEX idx_exchange (exchange), + INDEX idx_is_open (is_open) +) ENGINE=InnoDB COMMENT '交易日历表'; \ No newline at end of file