🔄 refactor(data_downloader): 优化数据库引擎创建逻辑,使用单例模式减少重复连接创建

🔄 refactor(utils): 实现配置加载的缓存,避免重复读取配置文件
This commit is contained in:
Qihang Zhang 2025-04-06 12:44:18 +08:00
parent e1c47616dd
commit 19718bd59f
2 changed files with 28 additions and 23 deletions

View File

@ -10,7 +10,7 @@ from functools import partial
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from datetime import datetime, timedelta from datetime import datetime, timedelta
from utils import load_config, create_engine_from_config from utils import load_config, get_engine
def create_metadata_table(engine): def create_metadata_table(engine):
@ -351,7 +351,7 @@ def process_stock_batch(batch, config, update_mode, start_date, end_date):
dict: 包含处理结果统计信息的字典 dict: 包含处理结果统计信息的字典
""" """
# 为当前进程创建新的数据库连接和API连接 # 为当前进程创建新的数据库连接和API连接
engine = create_engine_from_config(config) engine = get_engine()
ts.set_token(config['tushare_token']) ts.set_token(config['tushare_token'])
pro = ts.pro_api() pro = ts.pro_api()
@ -567,7 +567,7 @@ def perform_full_update(start_year=2020, processes=4, resume=False):
config = load_config() config = load_config()
# 创建数据库引擎 # 创建数据库引擎
engine = create_engine_from_config(config) engine = get_engine()
# 设置Tushare API # 设置Tushare API
ts.set_token(config['tushare_token']) ts.set_token(config['tushare_token'])
@ -767,7 +767,7 @@ def perform_incremental_update(processes=4):
config = load_config() config = load_config()
# 创建数据库引擎 # 创建数据库引擎
engine = create_engine_from_config(config) engine = get_engine()
# 设置Tushare API # 设置Tushare API
ts.set_token(config['tushare_token']) ts.set_token(config['tushare_token'])
@ -787,7 +787,7 @@ def perform_incremental_update(processes=4):
FROM FROM
stock_metadata stock_metadata
WHERE WHERE
list_status = 'L' list_status = 'L' AND status = 1
""" """
result = conn.execute(text(query)) result = conn.execute(text(query))
for row in result: for row in result:
@ -807,12 +807,8 @@ def perform_incremental_update(processes=4):
if not latest_update or last_full_update > latest_update: if not latest_update or last_full_update > latest_update:
latest_update = last_full_update latest_update = last_full_update
# 如果有更新日期,使用其后一天作为起始日期 # 使用全量/增量最晚的一个时间
if latest_update: start_date = latest_update.strftime('%Y%m%d')
start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
else:
# 如果没有任何更新记录默认从30天前开始
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
# 如果起始日期大于等于结束日期,则跳过此股票 # 如果起始日期大于等于结束日期,则跳过此股票
if start_date >= end_date: if start_date >= end_date:
@ -829,9 +825,6 @@ def perform_incremental_update(processes=4):
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新") print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
# 当前更新时间
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# 将股票列表分成多个批次 # 将股票列表分成多个批次
batches = split_list_into_chunks(stocks_to_update, processes) batches = split_list_into_chunks(stocks_to_update, processes)
@ -843,8 +836,7 @@ def perform_incremental_update(processes=4):
config=config, config=config,
update_mode='incremental', update_mode='incremental',
start_date=None, # 增量模式下使用每个股票自己的起始日期 start_date=None, # 增量模式下使用每个股票自己的起始日期
end_date=end_date, end_date=end_date
update_time=update_time
) )
# 启动多进程处理并收集结果 # 启动多进程处理并收集结果
@ -1019,7 +1011,7 @@ def main():
# 加载配置 # 加载配置
config = load_config() config = load_config()
# 创建数据库引擎 # 创建数据库引擎
engine = create_engine_from_config(config) engine = get_engine()
# 设置Tushare API # 设置Tushare API
ts.set_token(config['tushare_token']) ts.set_token(config['tushare_token'])
pro = ts.pro_api() pro = ts.pro_api()

View File

@ -6,8 +6,16 @@ from datetime import datetime, timedelta
from sqlalchemy import create_engine from sqlalchemy import create_engine
# 模块级单例
_config = None
_engine = None
def load_config(): def load_config():
global _config
if _config is not None:
return _config
config_path = 'config.yaml' config_path = 'config.yaml'
if not os.path.exists(config_path): if not os.path.exists(config_path):
config = { config = {
@ -27,11 +35,16 @@ def load_config():
exit(1) exit(1)
with open(config_path, 'r') as f: with open(config_path, 'r') as f:
config = yaml.safe_load(f) _config = yaml.safe_load(f)
return config return _config
def create_engine_from_config(config): def get_engine():
mysql = config['mysql'] """获取单例数据库引擎"""
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1" global _engine
return create_engine(connection_string) if _engine is None:
config = load_config()
mysql = config['mysql']
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
_engine = create_engine(connection_string)
return _engine