🔄 refactor(data_downloader): 优化数据库引擎创建逻辑,使用单例模式减少重复连接创建
🔄 refactor(utils): 实现配置加载的缓存,避免重复读取配置文件
This commit is contained in:
parent
e1c47616dd
commit
19718bd59f
@ -10,7 +10,7 @@ from functools import partial
|
|||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine, text
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from utils import load_config, create_engine_from_config
|
from utils import load_config, get_engine
|
||||||
|
|
||||||
|
|
||||||
def create_metadata_table(engine):
|
def create_metadata_table(engine):
|
||||||
@ -351,7 +351,7 @@ def process_stock_batch(batch, config, update_mode, start_date, end_date):
|
|||||||
dict: 包含处理结果统计信息的字典
|
dict: 包含处理结果统计信息的字典
|
||||||
"""
|
"""
|
||||||
# 为当前进程创建新的数据库连接和API连接
|
# 为当前进程创建新的数据库连接和API连接
|
||||||
engine = create_engine_from_config(config)
|
engine = get_engine()
|
||||||
ts.set_token(config['tushare_token'])
|
ts.set_token(config['tushare_token'])
|
||||||
pro = ts.pro_api()
|
pro = ts.pro_api()
|
||||||
|
|
||||||
@ -567,7 +567,7 @@ def perform_full_update(start_year=2020, processes=4, resume=False):
|
|||||||
config = load_config()
|
config = load_config()
|
||||||
|
|
||||||
# 创建数据库引擎
|
# 创建数据库引擎
|
||||||
engine = create_engine_from_config(config)
|
engine = get_engine()
|
||||||
|
|
||||||
# 设置Tushare API
|
# 设置Tushare API
|
||||||
ts.set_token(config['tushare_token'])
|
ts.set_token(config['tushare_token'])
|
||||||
@ -767,7 +767,7 @@ def perform_incremental_update(processes=4):
|
|||||||
config = load_config()
|
config = load_config()
|
||||||
|
|
||||||
# 创建数据库引擎
|
# 创建数据库引擎
|
||||||
engine = create_engine_from_config(config)
|
engine = get_engine()
|
||||||
|
|
||||||
# 设置Tushare API
|
# 设置Tushare API
|
||||||
ts.set_token(config['tushare_token'])
|
ts.set_token(config['tushare_token'])
|
||||||
@ -787,7 +787,7 @@ def perform_incremental_update(processes=4):
|
|||||||
FROM
|
FROM
|
||||||
stock_metadata
|
stock_metadata
|
||||||
WHERE
|
WHERE
|
||||||
list_status = 'L'
|
list_status = 'L' AND status = 1
|
||||||
"""
|
"""
|
||||||
result = conn.execute(text(query))
|
result = conn.execute(text(query))
|
||||||
for row in result:
|
for row in result:
|
||||||
@ -807,12 +807,8 @@ def perform_incremental_update(processes=4):
|
|||||||
if not latest_update or last_full_update > latest_update:
|
if not latest_update or last_full_update > latest_update:
|
||||||
latest_update = last_full_update
|
latest_update = last_full_update
|
||||||
|
|
||||||
# 如果有更新日期,使用其后一天作为起始日期
|
# 使用全量/增量最晚的一个时间
|
||||||
if latest_update:
|
start_date = latest_update.strftime('%Y%m%d')
|
||||||
start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
|
|
||||||
else:
|
|
||||||
# 如果没有任何更新记录,默认从30天前开始
|
|
||||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
|
||||||
|
|
||||||
# 如果起始日期大于等于结束日期,则跳过此股票
|
# 如果起始日期大于等于结束日期,则跳过此股票
|
||||||
if start_date >= end_date:
|
if start_date >= end_date:
|
||||||
@ -829,9 +825,6 @@ def perform_incremental_update(processes=4):
|
|||||||
|
|
||||||
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
|
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
|
||||||
|
|
||||||
# 当前更新时间
|
|
||||||
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
|
|
||||||
# 将股票列表分成多个批次
|
# 将股票列表分成多个批次
|
||||||
batches = split_list_into_chunks(stocks_to_update, processes)
|
batches = split_list_into_chunks(stocks_to_update, processes)
|
||||||
|
|
||||||
@ -843,8 +836,7 @@ def perform_incremental_update(processes=4):
|
|||||||
config=config,
|
config=config,
|
||||||
update_mode='incremental',
|
update_mode='incremental',
|
||||||
start_date=None, # 增量模式下使用每个股票自己的起始日期
|
start_date=None, # 增量模式下使用每个股票自己的起始日期
|
||||||
end_date=end_date,
|
end_date=end_date
|
||||||
update_time=update_time
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 启动多进程处理并收集结果
|
# 启动多进程处理并收集结果
|
||||||
@ -1019,7 +1011,7 @@ def main():
|
|||||||
# 加载配置
|
# 加载配置
|
||||||
config = load_config()
|
config = load_config()
|
||||||
# 创建数据库引擎
|
# 创建数据库引擎
|
||||||
engine = create_engine_from_config(config)
|
engine = get_engine()
|
||||||
# 设置Tushare API
|
# 设置Tushare API
|
||||||
ts.set_token(config['tushare_token'])
|
ts.set_token(config['tushare_token'])
|
||||||
pro = ts.pro_api()
|
pro = ts.pro_api()
|
||||||
|
25
utils.py
25
utils.py
@ -6,8 +6,16 @@ from datetime import datetime, timedelta
|
|||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
# 模块级单例
|
||||||
|
_config = None
|
||||||
|
_engine = None
|
||||||
|
|
||||||
|
|
||||||
def load_config():
|
def load_config():
|
||||||
|
global _config
|
||||||
|
if _config is not None:
|
||||||
|
return _config
|
||||||
|
|
||||||
config_path = 'config.yaml'
|
config_path = 'config.yaml'
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
config = {
|
config = {
|
||||||
@ -27,11 +35,16 @@ def load_config():
|
|||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
with open(config_path, 'r') as f:
|
with open(config_path, 'r') as f:
|
||||||
config = yaml.safe_load(f)
|
_config = yaml.safe_load(f)
|
||||||
return config
|
return _config
|
||||||
|
|
||||||
|
|
||||||
def create_engine_from_config(config):
|
def get_engine():
|
||||||
mysql = config['mysql']
|
"""获取单例数据库引擎"""
|
||||||
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
|
global _engine
|
||||||
return create_engine(connection_string)
|
if _engine is None:
|
||||||
|
config = load_config()
|
||||||
|
mysql = config['mysql']
|
||||||
|
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
|
||||||
|
_engine = create_engine(connection_string)
|
||||||
|
return _engine
|
Loading…
Reference in New Issue
Block a user