backtrader/database_manager.py

147 lines
4.8 KiB
Python
Raw Normal View History

import os
import traceback
import pandas as pd
from sqlalchemy import create_engine, text, inspect
from config_manager import get_config_manager
from logger import get_logger
class DatabaseManager:
"""
数据库管理类负责数据的存储查询和管理
"""
def __init__(self):
self.config = get_config_manager()
self._engine = None
self.logger = get_logger()
def get_engine(self):
"""获取SQLite数据库引擎如果不存在则创建"""
if self._engine is not None:
return self._engine
db_path = self.config.get('sqlite.path')
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
# 创建SQLite数据库引擎
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
return self._engine
def table_exists(self, table_name):
"""
检查表是否存在
参数:
table_name (str): 数据表名称
返回:
bool: 表是否存在
"""
engine = self.get_engine()
inspector = inspect(engine)
return inspector.has_table(table_name)
def get_existing_trade_dates(self, table_name):
"""
从数据库中获取已有的交易日期
参数:
table_name (str): 数据表名称
返回:
set: 已存在于数据库中的交易日期集合
"""
# 先检查表是否存在
if not self.table_exists(table_name):
self.logger.debug(f"'{table_name}' 不存在")
return set()
engine = self.get_engine()
query = f"SELECT DISTINCT trade_date FROM {table_name}"
try:
with engine.connect() as connection:
result = connection.execute(text(query))
return {row[0] for row in result}
except Exception as e:
self.logger.error(f"获取已存在交易日期时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return set()
def load_df_from_db(self, table_name, conditions=None):
"""
从数据库中加载数据
参数:
table_name (str): 数据表名称
conditions (str): 过滤条件 "trade_date > '20230101'"
返回:
pandas.DataFrame: 查询结果
"""
engine = self.get_engine()
query = f"SELECT * FROM {table_name}"
if conditions:
query += f" WHERE {conditions}"
try:
return pd.read_sql(query, engine)
except Exception as e:
self.logger.error(f"从数据库加载数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()
def save_df_to_db(self, df, table_name, if_exists='append'):
"""
保存DataFrame到数据库
参数:
df (pandas.DataFrame): 要保存的数据
table_name (str): 数据表名称
if_exists (str): 如果表存在时的操作: 'fail', 'replace', 'append'
返回:
bool: 操作是否成功
"""
if df.empty:
self.logger.warning("警告: 尝试保存空的DataFrame到数据库")
return False
engine = self.get_engine()
try:
df.to_sql(table_name, engine, if_exists=if_exists, index=False)
return True
except Exception as e:
self.logger.error(f"保存数据到数据库时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False
def delete_existing_data_by_dates(self, table_name, trade_dates):
"""
从数据库表中删除指定交易日期的数据
参数:
table_name (str): 数据表名称
trade_dates (list): 需要删除的交易日期列表
返回:
bool: 操作是否成功
"""
if not trade_dates:
return True # 如果没有日期需要删除,认为操作成功
engine = self.get_engine()
try:
# 将日期列表转换为SQL安全的格式
date_strings = [f"'{date}'" for date in trade_dates]
dates_clause = ", ".join(date_strings)
delete_query = f"DELETE FROM {table_name} WHERE trade_date IN ({dates_clause})"
with engine.connect() as connection:
connection.execute(text(delete_query))
connection.commit()
self.logger.info(f"成功从表 '{table_name}' 中删除 {len(trade_dates)} 个日期的数据")
return True
except Exception as e:
self.logger.error(f"删除表 '{table_name}' 中的数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False