backtrader/database_manager.py

168 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import traceback
import pandas as pd
from sqlalchemy import create_engine, text, inspect
from config_manager import get_config_manager
from logger_manager import get_logger
class DatabaseManager:
"""
数据库管理类,负责数据的存储、查询和管理
"""
def __init__(self):
self.config = get_config_manager()
self._engine = None
self.logger = get_logger()
def get_engine(self):
"""获取SQLite数据库引擎如果不存在则创建"""
if self._engine is not None:
return self._engine
db_path = self.config.get('sqlite.path')
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
# 创建SQLite数据库引擎
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
return self._engine
def table_exists(self, table_name):
"""
检查表是否存在
参数:
table_name (str): 数据表名称
返回:
bool: 表是否存在
"""
engine = self.get_engine()
inspector = inspect(engine)
return inspector.has_table(table_name)
def load_df_from_db(self, table_name, conditions=None):
"""
从数据库中加载数据
参数:
table_name (str): 数据表名称
conditions (str): 过滤条件,如 "trade_date > '20230101'"
返回:
pandas.DataFrame: 查询结果
"""
engine = self.get_engine()
query = f"SELECT * FROM {table_name}"
if conditions:
query += f" WHERE {conditions}"
try:
return pd.read_sql(query, engine)
except Exception as e:
self.logger.error(f"从数据库加载数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()
def save_df_to_db(self, df, table_name, if_exists='append', force_update=False):
"""
保存DataFrame到数据库
参数:
df (pandas.DataFrame): 要保存的数据
table_name (str): 数据表名称
if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append'
force_update (bool): 是否强制更新(会先删除相同日期的数据再插入)
返回:
bool: 操作是否成功
"""
if df.empty:
self.logger.warning("警告: 尝试保存空的DataFrame到数据库")
return False
engine = self.get_engine()
try:
if force_update and 'trade_date' in df.columns:
# 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据
current_dates = df['trade_date'].unique().tolist()
# 删除这些日期的现有数据
self.delete_existing_data_by_dates(table_name, current_dates)
# 插入新数据强制更新时始终使用append因为已经删除了相关数据
df.to_sql(table_name, engine, if_exists='append', index=False)
else:
# 普通模式
df.to_sql(table_name, engine, if_exists=if_exists, index=False)
return True
except Exception as e:
self.logger.error(f"保存数据到数据库时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False
def delete_existing_data_by_dates(self, table_name, trade_dates):
"""
从数据库表中删除指定交易日期的数据
参数:
table_name (str): 数据表名称
trade_dates (list): 需要删除的交易日期列表
返回:
bool: 操作是否成功
"""
if not trade_dates:
return True # 如果没有日期需要删除,认为操作成功
engine = self.get_engine()
try:
# 将日期列表转换为SQL安全的格式
date_strings = [f"'{date}'" for date in trade_dates]
dates_clause = ", ".join(date_strings)
delete_query = f"DELETE FROM {table_name} WHERE trade_date IN ({dates_clause})"
with engine.connect() as connection:
connection.execute(text(delete_query))
connection.commit()
self.logger.info(f"成功从表 '{table_name}' 中删除 {len(trade_dates)} 个日期的数据")
return True
except Exception as e:
self.logger.error(f"删除表 '{table_name}' 中的数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False
def query(self, sql_query, params=None):
"""
执行自定义SQL查询
参数:
sql_query (str): 要执行的SQL查询语句
params (dict, 可选): SQL参数化查询的参数
返回:
pandas.DataFrame 或 bool: 对于SELECT查询返回DataFrame其他查询返回是否执行成功
"""
engine = self.get_engine()
try:
# 判断是否是SELECT查询
is_select = sql_query.strip().upper().startswith("SELECT")
if is_select:
# 对于SELECT查询返回DataFrame
return pd.read_sql(sql_query, engine, params=params)
else:
# 对于非SELECT查询执行并返回成功状态
with engine.connect() as connection:
connection.execute(text(sql_query), params)
connection.commit()
self.logger.info(f"成功执行SQL查询")
return True
except Exception as e:
self.logger.error(f"执行SQL查询时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
if is_select:
return pd.DataFrame()
return False