🔨 refactor(data): 重构数据管理模块,提取独立的DataReader类并优化API接口
为了提升代码结构和可维护性,将data_fetcher.py重命名为data_manager.py,并进行以下重构: 1. 将实例变量移至模块级别配置 2. 将实例方法转换为静态方法 3. 提取新的DataReader类用于数据读取操作 4. 在DatabaseManager中添加通用查询方法 5. 优化数据获取与缓存逻辑
This commit is contained in:
parent
2c0eebf923
commit
5c22360b58
154
data_fetcher.py
154
data_fetcher.py
@@ -1,154 +0,0 @@
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pandas as pd
|
||||
import tushare as ts
|
||||
from tqdm import tqdm
|
||||
|
||||
from config_manager import get_config_manager
|
||||
from database_manager import DatabaseManager
|
||||
from logger import get_logger
|
||||
|
||||
|
||||
class DataFetcher:
    """Fetches data from Tushare APIs and caches it in the local database."""

    def __init__(self):
        # Load configuration and initialise the Tushare client.
        self.config = get_config_manager()
        ts.set_token(self.config.get('tushare_token'))
        self.pro = ts.pro_api()

        # Database manager backing the local cache.
        self.db_manager = DatabaseManager()

        # Shared application logger.
        self.logger = get_logger()

    def get_trade_cal(self, start_date=None, end_date=None):
        """Return the open trading dates within the given period.

        Parameters:
            start_date (str): start date, 'YYYYMMDD'; defaults to 30 days ago.
            end_date (str): end date, 'YYYYMMDD'; defaults to today.

        Returns:
            list: trading dates ('YYYYMMDD' strings); empty list on error.
        """
        if start_date is None:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')
        try:
            trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
            # Keep only days the exchange was actually open.
            return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
        except Exception as e:
            self.logger.error(f"获取交易日历时出错: {e}")
            self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
            return []

    def get(self, api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
        """Fetch data for a period from a Tushare API, using the database cache.

        Large result sets are flushed to the database in batches of at most
        ``batch_size`` accumulated rows.

        Parameters:
            api_name (str): Tushare API name, e.g. 'moneyflow' or 'moneyflow_ind_dc'.
                Also used as the cache table name.
            start_date (str): start date, 'YYYYMMDD'.
            end_date (str): end date, 'YYYYMMDD'.
            force_update (bool): refetch and overwrite the selected range. Default False.
            batch_size (int): max accumulated rows per database write. Default 100000.

        Returns:
            pandas.DataFrame: the cached table contents after fetching.
        """
        # The API name doubles as the cache table name.
        table_name = api_name

        # Verify the API exists before doing any work.
        if not hasattr(self.pro, api_name):
            self.logger.error(f"Tushare API '{api_name}'不存在")
            return pd.DataFrame()

        # Trading dates covering the requested period.
        all_trade_dates = self.get_trade_cal(start_date, end_date)

        # Decide which dates actually need fetching.
        if not force_update:
            # Skip dates already present in the cache.
            existing_dates = self.db_manager.get_existing_trade_dates(table_name=table_name)
            dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
        else:
            # Force update: refetch every date in the range.
            dates_to_fetch = all_trade_dates
            self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")

        if not dates_to_fetch:
            self.logger.info("所有数据已在数据库中,无需更新")
            return self.db_manager.load_df_from_db(table_name=table_name)

        self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")

        # Accumulate per-date frames and flush in batches.
        temp_batch = []
        total_rows = 0

        # The API method is loop-invariant; resolve it once.
        api_method = getattr(self.pro, api_name)

        for trade_date in tqdm(dates_to_fetch):
            try:
                df = api_method(trade_date=trade_date)

                if not df.empty:
                    temp_batch.append(df)
                    total_rows += len(df)

                    # Flush once the accumulated rows reach batch_size.
                    if total_rows >= batch_size:
                        self._process_batch(temp_batch, table_name, force_update)
                        self.logger.info(f"已处理 {total_rows} 行数据")
                        temp_batch = []
                        total_rows = 0
                else:
                    self.logger.info(f"日期 {trade_date} 无数据")
            except Exception as e:
                self.logger.error(f"获取 {trade_date} 的数据时出错: {e}")
                self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")

        # Flush whatever is left after the loop.
        if temp_batch:
            self._process_batch(temp_batch, table_name, force_update)
            self.logger.info(f"已处理剩余 {total_rows} 行数据")

        self.logger.info("数据获取与处理完成")
        return self.db_manager.load_df_from_db(table_name=table_name)

    def _process_batch(self, batch_dfs, table_name, force_update):
        """Persist one batch of fetched frames to the cache table.

        Parameters:
            batch_dfs (list): list of pandas.DataFrame to persist.
            table_name (str): target database table.
            force_update (bool): when True, delete existing rows for the
                batch's trade dates before inserting.
        """
        if not batch_dfs:
            return

        # Merge all frames of the batch into one.
        batch_df = pd.concat(batch_dfs, ignore_index=True)

        if force_update:
            # Force-update: drop the cached rows for the dates in this batch
            # so the append below effectively replaces them.
            current_dates = batch_df['trade_date'].unique().tolist()
            self.db_manager.delete_existing_data_by_dates(table_name, current_dates)

        # Both modes append; in force-update mode the stale rows were just
        # deleted above, so a single save call covers both branches.
        self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')
|
209
data_manager.py
Normal file
209
data_manager.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pandas as pd
|
||||
import tushare as ts
|
||||
from tqdm import tqdm
|
||||
|
||||
from config_manager import get_config_manager
|
||||
from database_manager import DatabaseManager
|
||||
from logger import get_logger
|
||||
|
||||
# Module-level configuration shared by all classes in this module.
config = get_config_manager()
ts.set_token(config.get('tushare_token'))
pro = ts.pro_api()

# Database manager backing the local cache.
db_manager = DatabaseManager()

# Shared application logger.
logger = get_logger()
|
||||
|
||||
|
||||
class DataFetcher:
    """Fetches data from Tushare APIs and caches it in the local database.

    All methods are static; the module-level ``pro``, ``db_manager`` and
    ``logger`` objects provide the shared state.
    """

    @staticmethod
    def get_basic(api_name):
        """Fetch reference data (e.g. the stock list) and cache it.

        The cache table is fully replaced on every call.

        Parameters:
            api_name (str): Tushare API name, e.g. 'stock_basic'.
                Also used as the cache table name.

        Returns:
            pandas.DataFrame: the fetched data; empty DataFrame on error.
        """
        # Verify the API exists before calling it.
        if not hasattr(pro, api_name):
            logger.error(f"Tushare API '{api_name}'不存在")
            return pd.DataFrame()

        try:
            df = getattr(pro, api_name)()
            # Replace the cached copy wholesale.
            db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace')
            return df
        except Exception as e:
            logger.error(f"获取基础数据时出错: {e}")
            logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
            return pd.DataFrame()

    @staticmethod
    def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
        """Fetch per-trade-date data for a period, using the database cache.

        Large result sets are flushed to the database in batches of at most
        ``batch_size`` accumulated rows.

        Parameters:
            api_name (str): Tushare API name, e.g. 'moneyflow' or 'moneyflow_ind_dc'.
                Also used as the cache table name.
            start_date (str): start date, 'YYYYMMDD'.
            end_date (str): end date, 'YYYYMMDD'.
            force_update (bool): refetch and overwrite the selected range. Default False.
            batch_size (int): max accumulated rows per database write. Default 100000.

        Returns:
            pandas.DataFrame: the cached table contents after fetching.
        """
        # The API name doubles as the cache table name.
        table_name = api_name

        # Verify the API exists before doing any work.
        if not hasattr(pro, api_name):
            logger.error(f"Tushare API '{api_name}'不存在")
            return pd.DataFrame()

        # Trading dates covering the requested period.
        all_trade_dates = DataReader.get_trade_cal(start_date, end_date)

        # Decide which dates actually need fetching.
        if not force_update:
            # Skip dates already present in the cache.
            existing_dates = DataReader.get_existing_trade_dates(table_name=table_name)
            dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
        else:
            # Force update: refetch every date in the range.
            dates_to_fetch = all_trade_dates
            logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")

        if not dates_to_fetch:
            logger.info("所有数据已在数据库中,无需更新")
            return db_manager.load_df_from_db(table_name=table_name)

        logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")

        # Accumulate per-date frames and flush in batches.
        temp_batch = []
        total_rows = 0

        # The API method is loop-invariant; resolve it once instead of
        # re-running getattr on every iteration.
        api_method = getattr(pro, api_name)

        for trade_date in tqdm(dates_to_fetch):
            try:
                df = api_method(trade_date=trade_date)

                if not df.empty:
                    temp_batch.append(df)
                    total_rows += len(df)

                    # Flush once the accumulated rows reach batch_size.
                    if total_rows >= batch_size:
                        DataFetcher.process_batch(temp_batch, table_name, force_update)
                        logger.info(f"已处理 {total_rows} 行数据")
                        temp_batch = []
                        total_rows = 0
                else:
                    logger.info(f"日期 {trade_date} 无数据")
            except Exception as e:
                logger.error(f"获取 {trade_date} 的数据时出错: {e}")
                logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")

        # Flush whatever is left after the loop.
        if temp_batch:
            DataFetcher.process_batch(temp_batch, table_name, force_update)
            logger.info(f"已处理剩余 {total_rows} 行数据")

        logger.info("数据获取与处理完成")
        return db_manager.load_df_from_db(table_name=table_name)

    @staticmethod
    def process_batch(batch_dfs, table_name, force_update):
        """Persist one batch of fetched frames to the cache table.

        Parameters:
            batch_dfs (list): list of pandas.DataFrame to persist.
            table_name (str): target database table.
            force_update (bool): forwarded to save_df_to_db, which deletes
                rows with matching trade dates before inserting.
        """
        if not batch_dfs:
            return
        # Merge all frames of the batch into one.
        batch_df = pd.concat(batch_dfs, ignore_index=True)
        # The database layer handles delete-then-insert when force_update is set.
        db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update)
|
||||
|
||||
|
||||
class DataReader:
    """Reads cached data from the local database."""

    @staticmethod
    def get_trade_cal(start_date=None, end_date=None, update=False):
        """Return the open trading dates within the given period.

        Pulls the trade calendar from Tushare when the local table is missing
        or when ``update`` is requested, then answers from the database.

        Parameters:
            start_date (str): start date, 'YYYYMMDD'; defaults to 30 days ago.
            end_date (str): end date, 'YYYYMMDD'; defaults to today.
            update (bool): force a refresh of the cached calendar. Default False.

        Returns:
            list: trading dates ('YYYYMMDD' strings); empty list on error.
        """
        # Refresh when explicitly asked, or when the table does not exist.
        table_missing = not db_manager.table_exists('trade_cal')
        if table_missing:
            # Only log "missing" when the table is actually absent, not on
            # every explicit update request.
            logger.debug("表 trade_cal 不存在")
        if table_missing or update:
            # Pull the calendar from Tushare into the cache.
            try:
                DataFetcher.get_basic('trade_cal')
            except Exception as e:
                logger.error(f"自动拉取交易日历时出错: {e}")
                logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
                return []

        if start_date is None:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')
        try:
            # Dates are internally generated 'YYYYMMDD' strings, not user input.
            df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'")
            # Keep only days the exchange was actually open.
            return df[df['is_open'] == 1]['cal_date'].tolist()
        except Exception as e:
            logger.error(f"获取交易日历时出错: {e}")
            logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
            return []

    @staticmethod
    def get_existing_trade_dates(table_name):
        """Return the trade dates already cached in a table.

        Parameters:
            table_name (str): database table to inspect.

        Returns:
            set: trade dates present in the table; empty set when the table
            is missing or the query fails.
        """
        # A missing table simply means nothing is cached yet.
        if not db_manager.table_exists(table_name):
            logger.debug(f"表 '{table_name}' 不存在")
            return set()

        try:
            # Distinct dates only; duplicates per date are irrelevant here.
            query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}")
            return set(query_result['trade_date'].values)
        except Exception as e:
            logger.error(f"获取已存在交易日期时出错: {e}")
            logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
            return set()
|
@@ -44,30 +44,6 @@ class DatabaseManager:
|
||||
inspector = inspect(engine)
|
||||
return inspector.has_table(table_name)
|
||||
|
||||
def get_existing_trade_dates(self, table_name):
|
||||
"""
|
||||
从数据库中获取已有的交易日期
|
||||
参数:
|
||||
table_name (str): 数据表名称
|
||||
返回:
|
||||
set: 已存在于数据库中的交易日期集合
|
||||
"""
|
||||
# 先检查表是否存在
|
||||
if not self.table_exists(table_name):
|
||||
self.logger.debug(f"表 '{table_name}' 不存在")
|
||||
return set()
|
||||
|
||||
engine = self.get_engine()
|
||||
query = f"SELECT DISTINCT trade_date FROM {table_name}"
|
||||
try:
|
||||
with engine.connect() as connection:
|
||||
result = connection.execute(text(query))
|
||||
return {row[0] for row in result}
|
||||
except Exception as e:
|
||||
self.logger.error(f"获取已存在交易日期时出错: {e}")
|
||||
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||||
return set()
|
||||
|
||||
def load_df_from_db(self, table_name, conditions=None):
|
||||
"""
|
||||
从数据库中加载数据
|
||||
@ -88,13 +64,16 @@ class DatabaseManager:
|
||||
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||||
return pd.DataFrame()
|
||||
|
||||
def save_df_to_db(self, df, table_name, if_exists='append'):
|
||||
def save_df_to_db(self, df, table_name, if_exists='append', force_update=False):
|
||||
"""
|
||||
保存DataFrame到数据库
|
||||
|
||||
参数:
|
||||
df (pandas.DataFrame): 要保存的数据
|
||||
table_name (str): 数据表名称
|
||||
if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append'
|
||||
force_update (bool): 是否强制更新(会先删除相同日期的数据再插入)
|
||||
|
||||
返回:
|
||||
bool: 操作是否成功
|
||||
"""
|
||||
@ -104,7 +83,17 @@ class DatabaseManager:
|
||||
|
||||
engine = self.get_engine()
|
||||
try:
|
||||
if force_update and 'trade_date' in df.columns:
|
||||
# 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据
|
||||
current_dates = df['trade_date'].unique().tolist()
|
||||
# 删除这些日期的现有数据
|
||||
self.delete_existing_data_by_dates(table_name, current_dates)
|
||||
# 插入新数据(强制更新时始终使用append,因为已经删除了相关数据)
|
||||
df.to_sql(table_name, engine, if_exists='append', index=False)
|
||||
else:
|
||||
# 普通模式
|
||||
df.to_sql(table_name, engine, if_exists=if_exists, index=False)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存数据到数据库时出错: {e}")
|
||||
@ -144,3 +133,35 @@ class DatabaseManager:
|
||||
self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
def query(self, sql_query, params=None):
    """Execute an arbitrary SQL statement.

    Parameters:
        sql_query (str): the SQL statement to execute.
        params (dict, optional): parameters for a parameterized query.

    Returns:
        pandas.DataFrame or bool: a DataFrame for SELECT statements
        (empty on error); for other statements, True on success and
        False on error.
    """
    engine = self.get_engine()
    # Classify the statement BEFORE entering the try-block: the previous
    # version assigned is_select inside try, so an early failure (e.g. a
    # non-string sql_query) made the except-branch raise NameError on
    # is_select instead of reporting the real error.
    is_select = sql_query.strip().upper().startswith("SELECT")
    try:
        if is_select:
            # SELECT: return the result set as a DataFrame.
            return pd.read_sql(sql_query, engine, params=params)
        # Non-SELECT: execute, commit, and report success.
        with engine.connect() as connection:
            connection.execute(text(sql_query), params)
            connection.commit()
        self.logger.info("成功执行SQL查询")
        return True
    except Exception as e:
        self.logger.error(f"执行SQL查询时出错: {e}")
        self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
        # Keep the documented return type for each statement kind.
        return pd.DataFrame() if is_select else False
|
||||
|
9
main.py
9
main.py
@ -1,17 +1,14 @@
|
||||
from data_manager import DataFetcher
from money_flow_analyzer import MoneyflowAnalyzer

if __name__ == "__main__":
    # Date range to fetch; end_date=None means "up to today".
    start_date = '20250101'
    end_date = None

    # Fetch per-trade-date data via the static API.
    # Pass force_update=True to refetch and overwrite the selected range.
    df = DataFetcher.get_trade_date('top_list', start_date, end_date, force_update=False)

    # analyzer = MoneyflowAnalyzer()
    # analyzer.main_flow_analyze(days_forward=10,use_consistent_samples=True)
|
Loading…
Reference in New Issue
Block a user