import os import traceback import pandas as pd from sqlalchemy import create_engine, text, inspect from config_manager import get_config_manager from logger_manager import get_logger class DatabaseManager: """ 数据库管理类,负责数据的存储、查询和管理 """ def __init__(self): self.config = get_config_manager() self._engine = None self.logger = get_logger() def get_engine(self): """获取SQLite数据库引擎,如果不存在则创建""" if self._engine is not None: return self._engine db_path = self.config.get('sqlite.path') # 确保数据库目录存在 os.makedirs(os.path.dirname(db_path), exist_ok=True) # 创建SQLite数据库引擎 self._engine = create_engine(f'sqlite:///{db_path}', echo=False) return self._engine def table_exists(self, table_name): """ 检查表是否存在 参数: table_name (str): 数据表名称 返回: bool: 表是否存在 """ engine = self.get_engine() inspector = inspect(engine) return inspector.has_table(table_name) def load_df_from_db(self, table_name, conditions=None): """ 从数据库中加载数据 参数: table_name (str): 数据表名称 conditions (str): 过滤条件,如 "trade_date > '20230101'" 返回: pandas.DataFrame: 查询结果 """ engine = self.get_engine() query = f"SELECT * FROM {table_name}" if conditions: query += f" WHERE {conditions}" try: return pd.read_sql(query, engine) except Exception as e: self.logger.error(f"从数据库加载数据时出错: {e}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return pd.DataFrame() def save_df_to_db(self, df, table_name, if_exists='append', force_update=False): """ 保存DataFrame到数据库 参数: df (pandas.DataFrame): 要保存的数据 table_name (str): 数据表名称 if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append' force_update (bool): 是否强制更新(会先删除相同日期的数据再插入) 返回: bool: 操作是否成功 """ if df.empty: self.logger.warning("警告: 尝试保存空的DataFrame到数据库") return False engine = self.get_engine() try: if force_update and 'trade_date' in df.columns: # 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据 current_dates = df['trade_date'].unique().tolist() # 删除这些日期的现有数据 self.delete_existing_data_by_dates(table_name, current_dates) # 插入新数据(强制更新时始终使用append,因为已经删除了相关数据) df.to_sql(table_name, engine, if_exists='append', index=False) else: # 普通模式 df.to_sql(table_name, engine, if_exists=if_exists, index=False) return True except Exception as e: self.logger.error(f"保存数据到数据库时出错: {e}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return False def delete_existing_data_by_dates(self, table_name, trade_dates): """ 从数据库表中删除指定交易日期的数据 参数: table_name (str): 数据表名称 trade_dates (list): 需要删除的交易日期列表 返回: bool: 操作是否成功 """ if not trade_dates: return True # 如果没有日期需要删除,认为操作成功 engine = self.get_engine() try: # 将日期列表转换为SQL安全的格式 date_strings = [f"'{date}'" for date in trade_dates] dates_clause = ", ".join(date_strings) delete_query = f"DELETE FROM {table_name} WHERE trade_date IN ({dates_clause})" with engine.connect() as connection: connection.execute(text(delete_query)) connection.commit() self.logger.info(f"成功从表 '{table_name}' 中删除 {len(trade_dates)} 个日期的数据") return True except Exception as e: self.logger.error(f"删除表 '{table_name}' 中的数据时出错: {e}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return False def query(self, sql_query, params=None): """ 执行自定义SQL查询 参数: sql_query (str): 要执行的SQL查询语句 params (dict, 可选): SQL参数化查询的参数 返回: pandas.DataFrame 或 bool: 对于SELECT查询返回DataFrame,其他查询返回是否执行成功 """ engine = self.get_engine() try: # 判断是否是SELECT查询 is_select = sql_query.strip().upper().startswith("SELECT") if is_select: # 对于SELECT查询,返回DataFrame return pd.read_sql(sql_query, engine, params=params) else: # 对于非SELECT查询,执行并返回成功状态 with engine.connect() as connection: connection.execute(text(sql_query), params) connection.commit() self.logger.info(f"成功执行SQL查询") return True except Exception as e: self.logger.error(f"执行SQL查询时出错: {e}") self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") if is_select: return pd.DataFrame() return False