🔄 refactor(data_downloader): 优化数据库引擎创建逻辑,使用单例模式减少重复连接创建
🔄 refactor(utils): 实现配置加载的缓存,避免重复读取配置文件
This commit is contained in:
parent
e1c47616dd
commit
19718bd59f
@ -10,7 +10,7 @@ from functools import partial
|
||||
from sqlalchemy import create_engine, text
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from utils import load_config, create_engine_from_config
|
||||
from utils import load_config, get_engine
|
||||
|
||||
|
||||
def create_metadata_table(engine):
|
||||
@ -351,7 +351,7 @@ def process_stock_batch(batch, config, update_mode, start_date, end_date):
|
||||
dict: 包含处理结果统计信息的字典
|
||||
"""
|
||||
# 为当前进程创建新的数据库连接和API连接
|
||||
engine = create_engine_from_config(config)
|
||||
engine = get_engine()
|
||||
ts.set_token(config['tushare_token'])
|
||||
pro = ts.pro_api()
|
||||
|
||||
@ -567,7 +567,7 @@ def perform_full_update(start_year=2020, processes=4, resume=False):
|
||||
config = load_config()
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine_from_config(config)
|
||||
engine = get_engine()
|
||||
|
||||
# 设置Tushare API
|
||||
ts.set_token(config['tushare_token'])
|
||||
@ -767,7 +767,7 @@ def perform_incremental_update(processes=4):
|
||||
config = load_config()
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine_from_config(config)
|
||||
engine = get_engine()
|
||||
|
||||
# 设置Tushare API
|
||||
ts.set_token(config['tushare_token'])
|
||||
@ -787,7 +787,7 @@ def perform_incremental_update(processes=4):
|
||||
FROM
|
||||
stock_metadata
|
||||
WHERE
|
||||
list_status = 'L'
|
||||
list_status = 'L' AND status = 1
|
||||
"""
|
||||
result = conn.execute(text(query))
|
||||
for row in result:
|
||||
@ -807,12 +807,8 @@ def perform_incremental_update(processes=4):
|
||||
if not latest_update or last_full_update > latest_update:
|
||||
latest_update = last_full_update
|
||||
|
||||
# 如果有更新日期,使用其后一天作为起始日期
|
||||
if latest_update:
|
||||
start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
|
||||
else:
|
||||
# 如果没有任何更新记录,默认从30天前开始
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||||
# 使用全量/增量最晚的一个时间
|
||||
start_date = latest_update.strftime('%Y%m%d')
|
||||
|
||||
# 如果起始日期大于等于结束日期,则跳过此股票
|
||||
if start_date >= end_date:
|
||||
@ -829,9 +825,6 @@ def perform_incremental_update(processes=4):
|
||||
|
||||
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
|
||||
|
||||
# 当前更新时间
|
||||
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# 将股票列表分成多个批次
|
||||
batches = split_list_into_chunks(stocks_to_update, processes)
|
||||
|
||||
@ -843,8 +836,7 @@ def perform_incremental_update(processes=4):
|
||||
config=config,
|
||||
update_mode='incremental',
|
||||
start_date=None, # 增量模式下使用每个股票自己的起始日期
|
||||
end_date=end_date,
|
||||
update_time=update_time
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
# 启动多进程处理并收集结果
|
||||
@ -1019,7 +1011,7 @@ def main():
|
||||
# 加载配置
|
||||
config = load_config()
|
||||
# 创建数据库引擎
|
||||
engine = create_engine_from_config(config)
|
||||
engine = get_engine()
|
||||
# 设置Tushare API
|
||||
ts.set_token(config['tushare_token'])
|
||||
pro = ts.pro_api()
|
||||
|
25
utils.py
25
utils.py
@ -6,8 +6,16 @@ from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
# 模块级单例
|
||||
_config = None
|
||||
_engine = None
|
||||
|
||||
|
||||
def load_config():
|
||||
global _config
|
||||
if _config is not None:
|
||||
return _config
|
||||
|
||||
config_path = 'config.yaml'
|
||||
if not os.path.exists(config_path):
|
||||
config = {
|
||||
@ -27,11 +35,16 @@ def load_config():
|
||||
exit(1)
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
return config
|
||||
_config = yaml.safe_load(f)
|
||||
return _config
|
||||
|
||||
|
||||
def create_engine_from_config(config):
    """Build a new SQLAlchemy engine from the ``mysql`` section of *config*.

    Args:
        config: Mapping with a ``mysql`` entry providing ``user``,
            ``password``, ``host``, ``port``, ``database`` and ``charset``.

    Returns:
        The engine returned by ``create_engine`` for the assembled
        ``mysql+pymysql`` URL.

    Raises:
        KeyError: If ``mysql`` or one of its fields is missing.
    """
    from urllib.parse import quote_plus

    mysql = config['mysql']
    # Percent-encode the credentials: a raw '@', ':' or '/' in the user name
    # or password would otherwise be parsed as part of the URL structure.
    # str() keeps non-string YAML values (e.g. an all-digit password) working.
    user = quote_plus(str(mysql['user']))
    password = quote_plus(str(mysql['password']))
    connection_string = (
        f"mysql+pymysql://{user}:{password}"
        f"@{mysql['host']}:{mysql['port']}/{mysql['database']}"
        f"?charset={mysql['charset']}&use_unicode=1"
    )
    return create_engine(connection_string)
|
||||
# PID of the process that built the cached engine (None until first use).
_engine_pid = None


def get_engine():
    """Return the per-process cached database engine, creating it on demand.

    The engine is built from the ``mysql`` section of ``load_config()`` and
    cached in the module-level ``_engine``. The cache is keyed on the
    current process ID: when this function is called from a process other
    than the one that created the cached engine (for example a forked
    multiprocessing worker), a fresh engine is created for that process
    instead of reusing pooled connections inherited from the parent.

    Returns:
        The engine returned by ``create_engine`` for the ``mysql+pymysql``
        URL assembled from the configuration.

    Raises:
        KeyError: If ``mysql`` or one of its fields is missing from the
            configuration.
    """
    global _engine, _engine_pid
    from urllib.parse import quote_plus

    current_pid = os.getpid()
    if _engine is None or _engine_pid != current_pid:
        # NOTE: an engine inherited across fork is replaced rather than
        # dispose()d, because disposing would close connections the parent
        # process may still be using.
        mysql = load_config()['mysql']
        # Percent-encode the credentials so characters such as '@', ':' or
        # '/' are not mistaken for URL delimiters. str() keeps non-string
        # YAML values (e.g. an all-digit password) working.
        user = quote_plus(str(mysql['user']))
        password = quote_plus(str(mysql['password']))
        connection_string = (
            f"mysql+pymysql://{user}:{password}"
            f"@{mysql['host']}:{mysql['port']}/{mysql['database']}"
            f"?charset={mysql['charset']}&use_unicode=1"
        )
        _engine = create_engine(connection_string)
        _engine_pid = current_pid
    return _engine
|
Loading…
Reference in New Issue
Block a user