♻️ refactor(database): 简化数据库表命名机制,移除表名映射层

This commit is contained in:
Qihang Zhang 2025-04-19 17:12:01 +08:00
parent c0468c9f71
commit 2c0eebf923
3 changed files with 24 additions and 36 deletions

View File

@ -41,10 +41,7 @@ class ConfigManager:
'tushare_token': 'xxxxxxxxxxx', 'tushare_token': 'xxxxxxxxxxx',
'sqlite': { 'sqlite': {
'path': './data/tushare_data.db', 'path': './data/tushare_data.db',
'database_name': 'tushare_data', 'database_name': 'tushare_data'
'table_name': {
'moneyflow_ind_dc': 'moneyflow_ind_dc'
}
}, },
'log': { 'log': {
'level': 'INFO', 'level': 'INFO',

View File

@ -63,7 +63,7 @@ class DataFetcher:
pandas.DataFrame: 请求的数据 pandas.DataFrame: 请求的数据
""" """
# 使用api_name作为表key查询表名 # 使用api_name作为表key查询表名
table_key = api_name table_name = api_name
# 确保self.pro中存在对应的API方法 # 确保self.pro中存在对应的API方法
if not hasattr(self.pro, api_name): if not hasattr(self.pro, api_name):
@ -76,7 +76,7 @@ class DataFetcher:
# 确定需要获取的日期 # 确定需要获取的日期
if not force_update: if not force_update:
# 从数据库获取已有的交易日期 # 从数据库获取已有的交易日期
existing_dates = self.db_manager.get_existing_trade_dates(table_key=table_key) existing_dates = self.db_manager.get_existing_trade_dates(table_name=table_name)
# 筛选出需要新获取的日期 # 筛选出需要新获取的日期
dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
else: else:
@ -86,7 +86,7 @@ class DataFetcher:
if not dates_to_fetch: if not dates_to_fetch:
self.logger.info("所有数据已在数据库中,无需更新") self.logger.info("所有数据已在数据库中,无需更新")
return self.db_manager.load_df_from_db(table_key=table_key) return self.db_manager.load_df_from_db(table_name=table_name)
self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")
@ -107,7 +107,7 @@ class DataFetcher:
# 当累积的数据量达到batch_size时进行一次批量写入 # 当累积的数据量达到batch_size时进行一次批量写入
if total_rows >= batch_size: if total_rows >= batch_size:
self._process_batch(temp_batch, table_key, force_update) self._process_batch(temp_batch, table_name, force_update)
self.logger.info(f"已处理 {total_rows} 行数据") self.logger.info(f"已处理 {total_rows} 行数据")
# 重置临时批次 # 重置临时批次
temp_batch = [] temp_batch = []
@ -120,18 +120,18 @@ class DataFetcher:
# 处理剩余的数据 # 处理剩余的数据
if temp_batch: if temp_batch:
self._process_batch(temp_batch, table_key, force_update) self._process_batch(temp_batch, table_name, force_update)
self.logger.info(f"已处理剩余 {total_rows} 行数据") self.logger.info(f"已处理剩余 {total_rows} 行数据")
self.logger.info("数据获取与处理完成") self.logger.info("数据获取与处理完成")
return self.db_manager.load_df_from_db(table_key=table_key) return self.db_manager.load_df_from_db(table_name=table_name)
def _process_batch(self, batch_dfs, table_key, force_update): def _process_batch(self, batch_dfs, table_name, force_update):
""" """
处理一批数据 处理一批数据
参数: 参数:
batch_dfs (list): DataFrame列表 batch_dfs (list): DataFrame列表
table_key (str): 数据库表名 table_name (str): 数据库表名
force_update (bool): 是否强制更新 force_update (bool): 是否强制更新
""" """
if not batch_dfs: if not batch_dfs:
@ -145,10 +145,10 @@ class DataFetcher:
current_dates = batch_df['trade_date'].unique().tolist() current_dates = batch_df['trade_date'].unique().tolist()
# 删除这些日期的现有数据 # 删除这些日期的现有数据
self.db_manager.delete_existing_data_by_dates(table_key, current_dates) self.db_manager.delete_existing_data_by_dates(table_name, current_dates)
# 插入新数据 # 插入新数据
self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append') self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')
else: else:
# 普通追加模式 # 普通追加模式
self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append') self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')

View File

@ -32,37 +32,31 @@ class DatabaseManager:
self._engine = create_engine(f'sqlite:///{db_path}', echo=False) self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
return self._engine return self._engine
def get_table_name(self, key): def table_exists(self, table_name):
"""根据表键名获取实际表名"""
return self.config.get(f'sqlite.table_name.{key}', key)
def table_exists(self, table_key):
""" """
检查表是否存在 检查表是否存在
参数: 参数:
table_key (str): 数据表 table_name (str): 数据表名称
返回: 返回:
bool: 表是否存在 bool: 表是否存在
""" """
table_name = self.get_table_name(table_key)
engine = self.get_engine() engine = self.get_engine()
inspector = inspect(engine) inspector = inspect(engine)
return inspector.has_table(table_name) return inspector.has_table(table_name)
def get_existing_trade_dates(self, table_key): def get_existing_trade_dates(self, table_name):
""" """
从数据库中获取已有的交易日期 从数据库中获取已有的交易日期
参数: 参数:
table_key (str): 数据表 table_name (str): 数据表
返回: 返回:
set: 已存在于数据库中的交易日期集合 set: 已存在于数据库中的交易日期集合
""" """
# 先检查表是否存在 # 先检查表是否存在
if not self.table_exists(table_key): if not self.table_exists(table_name):
self.logger.debug(f"'{self.get_table_name(table_key)}' 不存在") self.logger.debug(f"'{table_name}' 不存在")
return set() return set()
table_name = self.get_table_name(table_key)
engine = self.get_engine() engine = self.get_engine()
query = f"SELECT DISTINCT trade_date FROM {table_name}" query = f"SELECT DISTINCT trade_date FROM {table_name}"
try: try:
@ -74,16 +68,15 @@ class DatabaseManager:
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return set() return set()
def load_df_from_db(self, table_key, conditions=None): def load_df_from_db(self, table_name, conditions=None):
""" """
从数据库中加载数据 从数据库中加载数据
参数: 参数:
table_key (str): 表键名 table_name (str): 数据表名称
conditions (str): 过滤条件 "trade_date > '20230101'" conditions (str): 过滤条件 "trade_date > '20230101'"
返回: 返回:
pandas.DataFrame: 查询结果 pandas.DataFrame: 查询结果
""" """
table_name = self.get_table_name(table_key)
engine = self.get_engine() engine = self.get_engine()
query = f"SELECT * FROM {table_name}" query = f"SELECT * FROM {table_name}"
if conditions: if conditions:
@ -95,12 +88,12 @@ class DatabaseManager:
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame() return pd.DataFrame()
def save_df_to_db(self, df, table_key, if_exists='append'): def save_df_to_db(self, df, table_name, if_exists='append'):
""" """
保存DataFrame到数据库 保存DataFrame到数据库
参数: 参数:
df (pandas.DataFrame): 要保存的数据 df (pandas.DataFrame): 要保存的数据
table_key (str): 表键名 table_name (str): 数据表名称
if_exists (str): 如果表存在时的操作: 'fail', 'replace', 'append' if_exists (str): 如果表存在时的操作: 'fail', 'replace', 'append'
返回: 返回:
bool: 操作是否成功 bool: 操作是否成功
@ -109,7 +102,6 @@ class DatabaseManager:
self.logger.warning("警告: 尝试保存空的DataFrame到数据库") self.logger.warning("警告: 尝试保存空的DataFrame到数据库")
return False return False
table_name = self.get_table_name(table_key)
engine = self.get_engine() engine = self.get_engine()
try: try:
df.to_sql(table_name, engine, if_exists=if_exists, index=False) df.to_sql(table_name, engine, if_exists=if_exists, index=False)
@ -119,12 +111,12 @@ class DatabaseManager:
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False return False
def delete_existing_data_by_dates(self, table_key, trade_dates): def delete_existing_data_by_dates(self, table_name, trade_dates):
""" """
从数据库表中删除指定交易日期的数据 从数据库表中删除指定交易日期的数据
参数: 参数:
table_key (str): 数据表 table_name (str): 数据表
trade_dates (list): 需要删除的交易日期列表 trade_dates (list): 需要删除的交易日期列表
返回: 返回:
@ -133,7 +125,6 @@ class DatabaseManager:
if not trade_dates: if not trade_dates:
return True # 如果没有日期需要删除,认为操作成功 return True # 如果没有日期需要删除,认为操作成功
table_name = self.get_table_name(table_key)
engine = self.get_engine() engine = self.get_engine()
try: try: