120 lines
4.0 KiB
Python
120 lines
4.0 KiB
Python
import os
|
||
import traceback
|
||
|
||
import pandas as pd
|
||
from sqlalchemy import create_engine, text, inspect
|
||
|
||
from config_manager import get_config_manager
|
||
from logger import get_logger
|
||
|
||
|
||
class DatabaseManager:
|
||
"""
|
||
数据库管理类,负责数据的存储、查询和管理
|
||
"""
|
||
|
||
def __init__(self):
|
||
self.config = get_config_manager()
|
||
self._engine = None
|
||
self.logger = get_logger()
|
||
|
||
def get_engine(self):
|
||
"""获取SQLite数据库引擎,如果不存在则创建"""
|
||
if self._engine is not None:
|
||
return self._engine
|
||
|
||
db_path = self.config.get('sqlite.path')
|
||
|
||
# 确保数据库目录存在
|
||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||
|
||
# 创建SQLite数据库引擎
|
||
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
|
||
return self._engine
|
||
|
||
def get_table_name(self, key):
|
||
"""根据表键名获取实际表名"""
|
||
return self.config.get(f'sqlite.table_name.{key}', key)
|
||
|
||
def table_exists(self, table_key):
|
||
"""
|
||
检查表是否存在
|
||
参数:
|
||
table_key (str): 数据表键名
|
||
返回:
|
||
bool: 表是否存在
|
||
"""
|
||
table_name = self.get_table_name(table_key)
|
||
engine = self.get_engine()
|
||
inspector = inspect(engine)
|
||
return inspector.has_table(table_name)
|
||
|
||
def get_existing_trade_dates(self, table_key):
|
||
"""
|
||
从数据库中获取已有的交易日期
|
||
参数:
|
||
table_key (str): 数据表键名
|
||
返回:
|
||
set: 已存在于数据库中的交易日期集合
|
||
"""
|
||
# 先检查表是否存在
|
||
if not self.table_exists(table_key):
|
||
self.logger.debug(f"表 '{self.get_table_name(table_key)}' 不存在")
|
||
return set()
|
||
|
||
table_name = self.get_table_name(table_key)
|
||
engine = self.get_engine()
|
||
query = f"SELECT DISTINCT trade_date FROM {table_name}"
|
||
try:
|
||
with engine.connect() as connection:
|
||
result = connection.execute(text(query))
|
||
return {row[0] for row in result}
|
||
except Exception as e:
|
||
self.logger.error(f"获取已存在交易日期时出错: {e}")
|
||
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return set()
|
||
|
||
def load_df_from_db(self, table_key, conditions=None):
|
||
"""
|
||
从数据库中加载数据
|
||
参数:
|
||
table_key (str): 表键名
|
||
conditions (str): 过滤条件,如 "trade_date > '20230101'"
|
||
返回:
|
||
pandas.DataFrame: 查询结果
|
||
"""
|
||
table_name = self.get_table_name(table_key)
|
||
engine = self.get_engine()
|
||
query = f"SELECT * FROM {table_name}"
|
||
if conditions:
|
||
query += f" WHERE {conditions}"
|
||
try:
|
||
return pd.read_sql(query, engine)
|
||
except Exception as e:
|
||
self.logger.error(f"从数据库加载数据时出错: {e}")
|
||
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return pd.DataFrame()
|
||
|
||
def save_df_to_db(self, df, table_key, if_exists='append'):
|
||
"""
|
||
保存DataFrame到数据库
|
||
参数:
|
||
df (pandas.DataFrame): 要保存的数据
|
||
table_key (str): 表键名
|
||
if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append'
|
||
返回:
|
||
bool: 操作是否成功
|
||
"""
|
||
if df.empty:
|
||
self.logger.warning("警告: 尝试保存空的DataFrame到数据库")
|
||
return False
|
||
|
||
table_name = self.get_table_name(table_key)
|
||
engine = self.get_engine()
|
||
try:
|
||
df.to_sql(table_name, engine, if_exists=if_exists, index=False)
|
||
return True
|
||
except Exception as e:
|
||
self.logger.error(f"保存数据到数据库时出错: {e}")
|
||
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||
return False |