backtrader/database_manager.py

120 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import traceback
import pandas as pd
from sqlalchemy import create_engine, text, inspect
from config_manager import get_config_manager
from logger import get_logger
class DatabaseManager:
"""
数据库管理类,负责数据的存储、查询和管理
"""
def __init__(self):
self.config = get_config_manager()
self._engine = None
self.logger = get_logger()
def get_engine(self):
"""获取SQLite数据库引擎如果不存在则创建"""
if self._engine is not None:
return self._engine
db_path = self.config.get('sqlite.path')
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
# 创建SQLite数据库引擎
self._engine = create_engine(f'sqlite:///{db_path}', echo=False)
return self._engine
def get_table_name(self, key):
"""根据表键名获取实际表名"""
return self.config.get(f'sqlite.table_name.{key}', key)
def table_exists(self, table_key):
"""
检查表是否存在
参数:
table_key (str): 数据表键名
返回:
bool: 表是否存在
"""
table_name = self.get_table_name(table_key)
engine = self.get_engine()
inspector = inspect(engine)
return inspector.has_table(table_name)
def get_existing_trade_dates(self, table_key):
"""
从数据库中获取已有的交易日期
参数:
table_key (str): 数据表键名
返回:
set: 已存在于数据库中的交易日期集合
"""
# 先检查表是否存在
if not self.table_exists(table_key):
self.logger.debug(f"'{self.get_table_name(table_key)}' 不存在")
return set()
table_name = self.get_table_name(table_key)
engine = self.get_engine()
query = f"SELECT DISTINCT trade_date FROM {table_name}"
try:
with engine.connect() as connection:
result = connection.execute(text(query))
return {row[0] for row in result}
except Exception as e:
self.logger.error(f"获取已存在交易日期时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return set()
def load_df_from_db(self, table_key, conditions=None):
"""
从数据库中加载数据
参数:
table_key (str): 表键名
conditions (str): 过滤条件,如 "trade_date > '20230101'"
返回:
pandas.DataFrame: 查询结果
"""
table_name = self.get_table_name(table_key)
engine = self.get_engine()
query = f"SELECT * FROM {table_name}"
if conditions:
query += f" WHERE {conditions}"
try:
return pd.read_sql(query, engine)
except Exception as e:
self.logger.error(f"从数据库加载数据时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return pd.DataFrame()
def save_df_to_db(self, df, table_key, if_exists='append'):
"""
保存DataFrame到数据库
参数:
df (pandas.DataFrame): 要保存的数据
table_key (str): 表键名
if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append'
返回:
bool: 操作是否成功
"""
if df.empty:
self.logger.warning("警告: 尝试保存空的DataFrame到数据库")
return False
table_name = self.get_table_name(table_key)
engine = self.get_engine()
try:
df.to_sql(table_name, engine, if_exists=if_exists, index=False)
return True
except Exception as e:
self.logger.error(f"保存数据到数据库时出错: {e}")
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
return False