diff --git a/config_manager.py b/config_manager.py index e16036e..4dbc381 100644 --- a/config_manager.py +++ b/config_manager.py @@ -41,10 +41,7 @@ class ConfigManager: 'tushare_token': 'xxxxxxxxxxx', 'sqlite': { 'path': './data/tushare_data.db', - 'database_name': 'tushare_data', - 'table_name': { - 'moneyflow_ind_dc': 'moneyflow_ind_dc' - } + 'database_name': 'tushare_data' }, 'log': { 'level': 'INFO', diff --git a/data_fetcher.py b/data_fetcher.py index 1ae07da..3be4b2c 100644 --- a/data_fetcher.py +++ b/data_fetcher.py @@ -63,7 +63,7 @@ class DataFetcher: pandas.DataFrame: 请求的数据 """ # 使用api_name作为表key查询表名 - table_key = api_name + table_name = api_name # 确保self.pro中存在对应的API方法 if not hasattr(self.pro, api_name): @@ -76,7 +76,7 @@ class DataFetcher: # 确定需要获取的日期 if not force_update: # 从数据库获取已有的交易日期 - existing_dates = self.db_manager.get_existing_trade_dates(table_key=table_key) + existing_dates = self.db_manager.get_existing_trade_dates(table_name=table_name) # 筛选出需要新获取的日期 dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] else: @@ -86,7 +86,7 @@ class DataFetcher: if not dates_to_fetch: self.logger.info("所有数据已在数据库中,无需更新") - return self.db_manager.load_df_from_db(table_key=table_key) + return self.db_manager.load_df_from_db(table_name=table_name) self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") @@ -107,7 +107,7 @@ class DataFetcher: # 当累积的数据量达到batch_size时,进行一次批量写入 if total_rows >= batch_size: - self._process_batch(temp_batch, table_key, force_update) + self._process_batch(temp_batch, table_name, force_update) self.logger.info(f"已处理 {total_rows} 行数据") # 重置临时批次 temp_batch = [] @@ -120,18 +120,18 @@ class DataFetcher: # 处理剩余的数据 if temp_batch: - self._process_batch(temp_batch, table_key, force_update) + self._process_batch(temp_batch, table_name, force_update) self.logger.info(f"已处理剩余 {total_rows} 行数据") self.logger.info("数据获取与处理完成") - return self.db_manager.load_df_from_db(table_key=table_key) + return self.db_manager.load_df_from_db(table_name=table_name) - def _process_batch(self, batch_dfs, table_key, force_update): + def _process_batch(self, batch_dfs, table_name, force_update): """ 处理一批数据 参数: batch_dfs (list): DataFrame列表 - table_key (str): 数据库表名 + table_name (str): 数据库表名称 force_update (bool): 是否强制更新 """ if not batch_dfs: @@ -145,10 +145,10 @@ class DataFetcher: current_dates = batch_df['trade_date'].unique().tolist() # 删除这些日期的现有数据 - self.db_manager.delete_existing_data_by_dates(table_key, current_dates) + self.db_manager.delete_existing_data_by_dates(table_name, current_dates) # 插入新数据 - self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append') + self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append') else: # 普通追加模式 - self.db_manager.save_df_to_db(batch_df, table_key=table_key, if_exists='append') + self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append') diff --git a/database_manager.py b/database_manager.py index 78912f1..55b8693 100644 --- a/database_manager.py +++ b/database_manager.py @@ -32,37 +32,31 @@ class DatabaseManager: self._engine = create_engine(f'sqlite:///{db_path}', echo=False) return self._engine - def get_table_name(self, key): - """根据表键名获取实际表名""" - return self.config.get(f'sqlite.table_name.{key}', key) - - def table_exists(self, table_key): + def table_exists(self, table_name): """ 检查表是否存在 参数: - table_key (str): 数据表键名 + table_name (str): 数据表名称 返回: bool: 表是否存在 """ - table_name = self.get_table_name(table_key) engine = self.get_engine() inspector = inspect(engine) return inspector.has_table(table_name) - def get_existing_trade_dates(self, table_key): + def get_existing_trade_dates(self, table_name): """ 从数据库中获取已有的交易日期 参数: - table_key (str): 数据表键名 + table_name (str): 数据表名称 返回: set: 已存在于数据库中的交易日期集合 """ # 先检查表是否存在 - if not self.table_exists(table_key): - self.logger.debug(f"表 '{self.get_table_name(table_key)}' 不存在") + if not self.table_exists(table_name): + self.logger.debug(f"表 '{table_name}' 不存在") return set() - table_name = self.get_table_name(table_key) engine = self.get_engine() query = f"SELECT DISTINCT trade_date FROM {table_name}" try: @@ -74,16 +68,15 @@ class DatabaseManager: self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return set() - def load_df_from_db(self, table_key, conditions=None): + def load_df_from_db(self, table_name, conditions=None): """ 从数据库中加载数据 参数: - table_key (str): 表键名 + table_name (str): 数据表名称 conditions (str): 过滤条件,如 "trade_date > '20230101'" 返回: pandas.DataFrame: 查询结果 """ - table_name = self.get_table_name(table_key) engine = self.get_engine() query = f"SELECT * FROM {table_name}" if conditions: @@ -95,12 +88,12 @@ class DatabaseManager: self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return pd.DataFrame() - def save_df_to_db(self, df, table_key, if_exists='append'): + def save_df_to_db(self, df, table_name, if_exists='append'): """ 保存DataFrame到数据库 参数: df (pandas.DataFrame): 要保存的数据 - table_key (str): 表键名 + table_name (str): 数据表名称 if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append' 返回: bool: 操作是否成功 @@ -109,7 +102,6 @@ class DatabaseManager: self.logger.warning("警告: 尝试保存空的DataFrame到数据库") return False - table_name = self.get_table_name(table_key) engine = self.get_engine() try: df.to_sql(table_name, engine, if_exists=if_exists, index=False) @@ -119,12 +111,12 @@ class DatabaseManager: self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return False - def delete_existing_data_by_dates(self, table_key, trade_dates): + def delete_existing_data_by_dates(self, table_name, trade_dates): """ 从数据库表中删除指定交易日期的数据 参数: - table_key (str): 数据表键名 + table_name (str): 数据表名称 trade_dates (list): 需要删除的交易日期列表 返回: @@ -133,7 +125,6 @@ class DatabaseManager: if not trade_dates: return True # 如果没有日期需要删除,认为操作成功 - table_name = self.get_table_name(table_key) engine = self.get_engine() try: