diff --git a/data_downloader.py b/data_downloader.py
index aa57fbc..1370e34 100644
--- a/data_downloader.py
+++ b/data_downloader.py
@@ -10,7 +10,7 @@ from functools import partial
 
 from sqlalchemy import create_engine, text
 from datetime import datetime, timedelta
-from utils import load_config, create_engine_from_config
+from utils import load_config, get_engine
 
 
 def create_metadata_table(engine):
@@ -351,7 +351,7 @@ def process_stock_batch(batch, config, update_mode, start_date, end_date):
         dict: 包含处理结果统计信息的字典
     """
 
     # 为当前进程创建新的数据库连接和API连接
-    engine = create_engine_from_config(config)
+    engine = get_engine()
     ts.set_token(config['tushare_token'])
     pro = ts.pro_api()
@@ -567,7 +567,7 @@
     config = load_config()
 
     # 创建数据库引擎
-    engine = create_engine_from_config(config)
+    engine = get_engine()
 
     # 设置Tushare API
     ts.set_token(config['tushare_token'])
@@ -767,7 +767,7 @@
     config = load_config()
 
     # 创建数据库引擎
-    engine = create_engine_from_config(config)
+    engine = get_engine()
 
     # 设置Tushare API
     ts.set_token(config['tushare_token'])
@@ -787,7 +787,7 @@ def perform_incremental_update(processes=4):
             FROM
                 stock_metadata
             WHERE
-                list_status = 'L'
+                list_status = 'L' AND status = 1
         """
         result = conn.execute(text(query))
         for row in result:
@@ -807,12 +807,12 @@ def perform_incremental_update(processes=4):
             if not latest_update or last_full_update > latest_update:
                 latest_update = last_full_update
 
-            # 如果有更新日期,使用其后一天作为起始日期
-            if latest_update:
-                start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
-            else:
-                # 如果没有任何更新记录,默认从30天前开始
-                start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
+            # 使用全量/增量最晚的一个时间
+            if latest_update:
+                start_date = latest_update.strftime('%Y%m%d')
+            else:
+                # No full or incremental update recorded yet — fall back to 30 days ago.
+                start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
 
             # 如果起始日期大于等于结束日期,则跳过此股票
             if start_date >= end_date:
@@ -829,7 +829,4 @@ def perform_incremental_update(processes=4):
     print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
 
-    # 当前更新时间
-    update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-
     # 将股票列表分成多个批次
     batches = split_list_into_chunks(stocks_to_update, processes)
@@ -843,8 +840,7 @@ def perform_incremental_update(processes=4):
         config=config,
         update_mode='incremental',
         start_date=None,  # 增量模式下使用每个股票自己的起始日期
-        end_date=end_date,
-        update_time=update_time
+        end_date=end_date
     )
 
     # 启动多进程处理并收集结果
@@ -1019,7 +1015,7 @@ def main():
     # 加载配置
     config = load_config()
     # 创建数据库引擎
-    engine = create_engine_from_config(config)
+    engine = get_engine()
     # 设置Tushare API
     ts.set_token(config['tushare_token'])
     pro = ts.pro_api()
diff --git a/utils.py b/utils.py
index 0c33b9d..2d466ed 100644
--- a/utils.py
+++ b/utils.py
@@ -6,7 +6,17 @@ from datetime import datetime, timedelta
 from sqlalchemy import create_engine
 
 
+# 模块级单例
+_config = None
+_engine = None
+_engine_pid = None  # PID that created _engine; engines must not cross fork boundaries
+
+
 def load_config():
+    global _config
+    if _config is not None:
+        return _config
+
     config_path = 'config.yaml'
     if not os.path.exists(config_path):
         config = {
@@ -27,11 +37,21 @@ def load_config():
         exit(1)
 
     with open(config_path, 'r') as f:
-        config = yaml.safe_load(f)
-    return config
+        _config = yaml.safe_load(f)
+    return _config
 
 
-def create_engine_from_config(config):
-    mysql = config['mysql']
-    connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
-    return create_engine(connection_string)
+def get_engine():
+    """Return a singleton SQLAlchemy engine, created at most once per process.
+
+    Pooled connections cannot be shared across multiprocessing fork
+    boundaries, so a child process discards the inherited engine and
+    builds its own (checked via the creating PID).
+    """
+    global _engine, _engine_pid
+    if _engine is None or _engine_pid != os.getpid():
+        config = load_config()
+        mysql = config['mysql']
+        connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
+        _engine = create_engine(connection_string)
+        _engine_pid = os.getpid()
+    return _engine