diff --git a/data_fetcher.py b/data_fetcher.py deleted file mode 100644 index 3be4b2c..0000000 --- a/data_fetcher.py +++ /dev/null @@ -1,154 +0,0 @@ -import traceback -from datetime import datetime, timedelta - -import pandas as pd -import tushare as ts -from tqdm import tqdm - -from config_manager import get_config_manager -from database_manager import DatabaseManager -from logger import get_logger - - -class DataFetcher: - """ - 数据获取器类,负责从Tushare获取各类数据并管理本地缓存 - """ - - def __init__(self): - # 加载配置并初始化tushare - self.config = get_config_manager() - ts.set_token(self.config.get('tushare_token')) - self.pro = ts.pro_api() - - # 初始化数据库管理器 - self.db_manager = DatabaseManager() - - # 获取日志器 - self.logger = get_logger() - - def get_trade_cal(self, start_date=None, end_date=None): - """ - 获取指定时间段内的交易日历 - 参数: - start_date (str): 开始日期,格式'YYYYMMDD' - end_date (str): 结束日期,格式'YYYYMMDD' - 返回: - list: 交易日期列表 - """ - if start_date is None: - start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') - if end_date is None: - end_date = datetime.now().strftime('%Y%m%d') - try: - trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date) - return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist() - except Exception as e: - self.logger.error(f"获取交易日历时出错: {e}") - self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") - return [] - - def get(self, api_name, start_date=None, end_date=None, force_update=False, batch_size=100000): - """ - 获取指定时间段内的Tushare数据,使用数据库缓存,并分批处理大量数据 - - 参数: - api_name (str): Tushare的API名称,例如'moneyflow'或'moneyflow_ind_dc' - start_date (str): 开始日期,格式'YYYYMMDD' - end_date (str): 结束日期,格式'YYYYMMDD' - force_update (bool): 是否强制更新所选区域数据,默认为False - batch_size (int): 每批处理的最大行数,默认为100000 - - 返回: - pandas.DataFrame: 请求的数据 - """ - # 使用api_name作为表key查询表名 - table_name = api_name - - # 确保self.pro中存在对应的API方法 - if not hasattr(self.pro, api_name): - self.logger.error(f"Tushare API '{api_name}'不存在") - return pd.DataFrame() - - # 获取目标交易日历 - all_trade_dates = self.get_trade_cal(start_date, end_date) - - # 确定需要获取的日期 - if not force_update: - # 从数据库获取已有的交易日期 - existing_dates = self.db_manager.get_existing_trade_dates(table_name=table_name) - # 筛选出需要新获取的日期 - dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] - else: - # 强制更新,获取所有日期数据 - dates_to_fetch = all_trade_dates - self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据") - - if not dates_to_fetch: - self.logger.info("所有数据已在数据库中,无需更新") - return self.db_manager.load_df_from_db(table_name=table_name) - - self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") - - # 分批处理数据 - temp_batch = [] - total_rows = 0 - - # 获取数据 - for trade_date in tqdm(dates_to_fetch): - try: - # 动态调用Tushare API - api_method = getattr(self.pro, api_name) - df = api_method(trade_date=trade_date) - - if not df.empty: - temp_batch.append(df) - total_rows += len(df) - - # 当累积的数据量达到batch_size时,进行一次批量写入 - if total_rows >= batch_size: - self._process_batch(temp_batch, table_name, force_update) - self.logger.info(f"已处理 {total_rows} 行数据") - # 重置临时批次 - temp_batch = [] - total_rows = 0 - else: - self.logger.info(f"日期 {trade_date} 无数据") - except Exception as e: - self.logger.error(f"获取 {trade_date} 的数据时出错: {e}") - self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") - - # 处理剩余的数据 - if temp_batch: - self._process_batch(temp_batch, table_name, force_update) - self.logger.info(f"已处理剩余 {total_rows} 行数据") - - self.logger.info("数据获取与处理完成") - return self.db_manager.load_df_from_db(table_name=table_name) - - def _process_batch(self, batch_dfs, table_name, force_update): - """ - 处理一批数据 - 参数: - batch_dfs (list): DataFrame列表 - table_name (str): 数据库表名称 - force_update (bool): 是否强制更新 - """ - if not batch_dfs: - return - - # 合并批次中的所有DataFrame - batch_df = pd.concat(batch_dfs, ignore_index=True) - - if force_update: - # 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据 - current_dates = batch_df['trade_date'].unique().tolist() - - # 删除这些日期的现有数据 - self.db_manager.delete_existing_data_by_dates(table_name, current_dates) - - # 插入新数据 - self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append') - else: - # 普通追加模式 - self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append') diff --git a/data_manager.py b/data_manager.py new file mode 100644 index 0000000..66fea4a --- /dev/null +++ b/data_manager.py @@ -0,0 +1,209 @@ +import traceback +from datetime import datetime, timedelta + +import pandas as pd +import tushare as ts +from tqdm import tqdm + +from config_manager import get_config_manager +from database_manager import DatabaseManager +from logger import get_logger + +# 模块级别的配置 +config = get_config_manager() +ts.set_token(config.get('tushare_token')) +pro = ts.pro_api() + +# 初始化数据库管理器 +db_manager = DatabaseManager() + +# 获取日志器 +logger = get_logger() + + +class DataFetcher: + """ + 数据获取器类,负责从Tushare获取各类数据并管理本地缓存 + """ + + @staticmethod + def get_basic(api_name): + """ + 获取基础数据,如股票列表等 + 参数: + api_name (str): Tushare的API名称,例如'stock_basic' + 返回: + pandas.DataFrame: 请求的数据 + """ + # 确保pro中存在对应的API方法 + if not hasattr(pro, api_name): + logger.error(f"Tushare API '{api_name}'不存在") + return pd.DataFrame() + + try: + df = getattr(pro, api_name)() + # 将数据保存到数据库 + db_manager.save_df_to_db(df, table_name=api_name, if_exists='replace') + return df + except Exception as e: + logger.error(f"获取基础数据时出错: {e}") + logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") + return pd.DataFrame() + + @staticmethod + def get_trade_date(api_name, start_date=None, end_date=None, force_update=False, batch_size=100000): + """ + 获取指定时间段内的Tushare数据,使用数据库缓存,并分批处理大量数据 + + 参数: + api_name (str): Tushare的API名称,例如'moneyflow'或'moneyflow_ind_dc' + start_date (str): 开始日期,格式'YYYYMMDD' + end_date (str): 结束日期,格式'YYYYMMDD' + force_update (bool): 是否强制更新所选区域数据,默认为False + batch_size (int): 每批处理的最大行数,默认为100000 + + 返回: + pandas.DataFrame: 请求的数据 + """ + # 使用api_name作为表key查询表名 + table_name = api_name + + # 确保pro中存在对应的API方法 + if not hasattr(pro, api_name): + logger.error(f"Tushare API '{api_name}'不存在") + return pd.DataFrame() + + # 获取目标交易日历 + all_trade_dates = DataReader.get_trade_cal(start_date, end_date) + + # 确定需要获取的日期 + if not force_update: + # 从数据库获取已有的交易日期 + existing_dates = DataReader.get_existing_trade_dates(table_name=table_name) + # 筛选出需要新获取的日期 + dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates] + else: + # 强制更新,获取所有日期数据 + dates_to_fetch = all_trade_dates + logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据") + + if not dates_to_fetch: + logger.info("所有数据已在数据库中,无需更新") + return db_manager.load_df_from_db(table_name=table_name) + logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据") + + # 分批处理数据 + temp_batch = [] + total_rows = 0 + + # 获取数据 + for trade_date in tqdm(dates_to_fetch): + try: + # 动态调用Tushare API + api_method = getattr(pro, api_name) + df = api_method(trade_date=trade_date) + + if not df.empty: + temp_batch.append(df) + total_rows += len(df) + + # 当累积的数据量达到batch_size时,进行一次批量写入 + if total_rows >= batch_size: + DataFetcher.process_batch(temp_batch, table_name, force_update) + logger.info(f"已处理 {total_rows} 行数据") + # 重置临时批次 + temp_batch = [] + total_rows = 0 + else: + logger.info(f"日期 {trade_date} 无数据") + except Exception as e: + logger.error(f"获取 {trade_date} 的数据时出错: {e}") + logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") + + # 处理剩余的数据 + if temp_batch: + DataFetcher.process_batch(temp_batch, table_name, force_update) + logger.info(f"已处理剩余 {total_rows} 行数据") + + logger.info("数据获取与处理完成") + return db_manager.load_df_from_db(table_name=table_name) + + @staticmethod + def process_batch(batch_dfs, table_name, force_update): + """ + 处理一批数据 + 参数: + batch_dfs (list): DataFrame列表 + table_name (str): 数据库表名称 + force_update (bool): 是否强制更新 + """ + if not batch_dfs: + return + # 合并批次中的所有DataFrame + batch_df = pd.concat(batch_dfs, ignore_index=True) + # 保存数据,传入force_update参数 + db_manager.save_df_to_db(batch_df, table_name=table_name, force_update=force_update) + + +class DataReader: + """ + 数据读取器类,负责从数据库读取数据 + """ + @staticmethod + def get_trade_cal(start_date=None, end_date=None, update=False): + """ + 获取指定时间段内的交易日历 + 参数: + start_date (str): 开始日期,格式'YYYYMMDD' + end_date (str): 结束日期,格式'YYYYMMDD' + update (bool): 是否更新交易日历,默认为False + 返回: + list: 交易日期列表 + """ + # 先检查表是否存在 + if not db_manager.table_exists('trade_cal') or update: + logger.debug(f"表 trade_cal 不存在") + # 自动拉取交易日历 + try: + DataFetcher.get_basic('trade_cal') + except Exception as e: + logger.error(f"自动拉取交易日历时出错: {e}") + logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") + return [] + + + if start_date is None: + start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') + if end_date is None: + end_date = datetime.now().strftime('%Y%m%d') + try: + df = db_manager.load_df_from_db('trade_cal', conditions=f"cal_date BETWEEN '{start_date}' AND '{end_date}'") + return df[df['is_open'] == 1]['cal_date'].tolist() + except Exception as e: + logger.error(f"获取交易日历时出错: {e}") + logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") + return [] + + @staticmethod + def get_existing_trade_dates(table_name): + """ + 从数据库中获取已有的交易日期 + 参数: + table_name (str): 数据表名称 + 返回: + set: 已存在于数据库中的交易日期集合 + """ + # 先检查表是否存在 + if not db_manager.table_exists(table_name): + logger.debug(f"表 '{table_name}' 不存在") + return set() + + try: + # 使用query方法获取不重复的交易日期 + query_result = db_manager.query(f"SELECT DISTINCT trade_date FROM {table_name}") + # 将查询结果转换为集合并返回 + return set(query_result['trade_date'].values) + except Exception as e: + logger.error(f"获取已存在交易日期时出错: {e}") + logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") + return set() \ No newline at end of file diff --git a/database_manager.py b/database_manager.py index 55b8693..454618b 100644 --- a/database_manager.py +++ b/database_manager.py @@ -44,30 +44,6 @@ class DatabaseManager: inspector = inspect(engine) return inspector.has_table(table_name) - def get_existing_trade_dates(self, table_name): - """ - 从数据库中获取已有的交易日期 - 参数: - table_name (str): 数据表名称 - 返回: - set: 已存在于数据库中的交易日期集合 - """ - # 先检查表是否存在 - if not self.table_exists(table_name): - self.logger.debug(f"表 '{table_name}' 不存在") - return set() - - engine = self.get_engine() - query = f"SELECT DISTINCT trade_date FROM {table_name}" - try: - with engine.connect() as connection: - result = connection.execute(text(query)) - return {row[0] for row in result} - except Exception as e: - self.logger.error(f"获取已存在交易日期时出错: {e}") - self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") - return set() - def load_df_from_db(self, table_name, conditions=None): """ 从数据库中加载数据 @@ -88,13 +64,16 @@ class DatabaseManager: self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return pd.DataFrame() - def save_df_to_db(self, df, table_name, if_exists='append'): + def save_df_to_db(self, df, table_name, if_exists='append', force_update=False): """ 保存DataFrame到数据库 + 参数: df (pandas.DataFrame): 要保存的数据 table_name (str): 数据表名称 if_exists (str): 如果表存在时的操作: 'fail', 'replace', 或 'append' + force_update (bool): 是否强制更新(会先删除相同日期的数据再插入) + 返回: bool: 操作是否成功 """ @@ -104,7 +83,17 @@ class DatabaseManager: engine = self.get_engine() try: - df.to_sql(table_name, engine, if_exists=if_exists, index=False) + if force_update and 'trade_date' in df.columns: + # 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据 + current_dates = df['trade_date'].unique().tolist() + # 删除这些日期的现有数据 + self.delete_existing_data_by_dates(table_name, current_dates) + # 插入新数据(强制更新时始终使用append,因为已经删除了相关数据) + df.to_sql(table_name, engine, if_exists='append', index=False) + else: + # 普通模式 + df.to_sql(table_name, engine, if_exists=if_exists, index=False) + return True except Exception as e: self.logger.error(f"保存数据到数据库时出错: {e}") @@ -144,3 +133,35 @@ class DatabaseManager: self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") return False + def query(self, sql_query, params=None): + """ + 执行自定义SQL查询 + + 参数: + sql_query (str): 要执行的SQL查询语句 + params (dict, 可选): SQL参数化查询的参数 + + 返回: + pandas.DataFrame 或 bool: 对于SELECT查询返回DataFrame,其他查询返回是否执行成功 + """ + engine = self.get_engine() + try: + # 判断是否是SELECT查询 + is_select = sql_query.strip().upper().startswith("SELECT") + + if is_select: + # 对于SELECT查询,返回DataFrame + return pd.read_sql(sql_query, engine, params=params) + else: + # 对于非SELECT查询,执行并返回成功状态 + with engine.connect() as connection: + connection.execute(text(sql_query), params) + connection.commit() + self.logger.info(f"成功执行SQL查询") + return True + except Exception as e: + self.logger.error(f"执行SQL查询时出错: {e}") + self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}") + if is_select: + return pd.DataFrame() + return False diff --git a/main.py b/main.py index 3e54e8b..12846bf 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,14 @@ -from data_fetcher import DataFetcher +from data_manager import DataFetcher from money_flow_analyzer import MoneyflowAnalyzer if __name__ == "__main__": # 指定日期范围 - start_date = '20250401' + start_date = '20250101' end_date = None - # 创建数据获取器实例 - data_fetcher = DataFetcher() - # 获取板块资金流向数据 # 可以通过force_update=True参数强制更新指定日期范围的数据 - df = data_fetcher.get('moneyflow',start_date, end_date, force_update=True) + df = DataFetcher.get_trade_date('top_list',start_date, end_date, force_update=False) # analyzer = MoneyflowAnalyzer() # analyzer.main_flow_analyze(days_forward=10,use_consistent_samples=True) \ No newline at end of file