import traceback
from datetime import datetime, timedelta

import pandas as pd
import tushare as ts
from tqdm import tqdm

from config_manager import get_config_manager
from database_manager import DatabaseManager
from logger import get_logger


class DataFetcher:
    """
    数据获取器类,负责从Tushare获取各类数据并管理本地缓存
    """

    def __init__(self):
        # 加载配置并初始化tushare
        self.config = get_config_manager()
        ts.set_token(self.config.get('tushare_token'))
        self.pro = ts.pro_api()

        # 初始化数据库管理器
        self.db_manager = DatabaseManager()

        # 获取日志器
        self.logger = get_logger()

    def get_trade_cal(self, start_date=None, end_date=None):
        """
        获取指定时间段内的交易日历
        参数:
        start_date (str): 开始日期,格式'YYYYMMDD'
        end_date (str): 结束日期,格式'YYYYMMDD'
        返回:
        list: 交易日期列表
        """
        if start_date is None:
            start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
        if end_date is None:
            end_date = datetime.now().strftime('%Y%m%d')
        try:
            trade_cal_df = self.pro.trade_cal(exchange='', start_date=start_date, end_date=end_date)
            return trade_cal_df[trade_cal_df['is_open'] == 1]['cal_date'].tolist()
        except Exception as e:
            self.logger.error(f"获取交易日历时出错: {e}")
            self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")
            return []

    def get(self, api_name, start_date=None, end_date=None, force_update=False, batch_size=100000):
        """
        获取指定时间段内的Tushare数据,使用数据库缓存,并分批处理大量数据

        参数:
        api_name (str): Tushare的API名称,例如'moneyflow'或'moneyflow_ind_dc'
        start_date (str): 开始日期,格式'YYYYMMDD'
        end_date (str): 结束日期,格式'YYYYMMDD'
        force_update (bool): 是否强制更新所选区域数据,默认为False
        batch_size (int): 每批处理的最大行数,默认为100000

        返回:
        pandas.DataFrame: 请求的数据
        """
        # 使用api_name作为表key查询表名
        table_name = api_name

        # 确保self.pro中存在对应的API方法
        if not hasattr(self.pro, api_name):
            self.logger.error(f"Tushare API '{api_name}'不存在")
            return pd.DataFrame()

        # 获取目标交易日历
        all_trade_dates = self.get_trade_cal(start_date, end_date)

        # 确定需要获取的日期
        if not force_update:
            # 从数据库获取已有的交易日期
            existing_dates = self.db_manager.get_existing_trade_dates(table_name=table_name)
            # 筛选出需要新获取的日期
            dates_to_fetch = [date for date in all_trade_dates if date not in existing_dates]
        else:
            # 强制更新,获取所有日期数据
            dates_to_fetch = all_trade_dates
            self.logger.info(f"强制更新模式: 将更新 {len(dates_to_fetch)} 个交易日的数据")

        if not dates_to_fetch:
            self.logger.info("所有数据已在数据库中,无需更新")
            return self.db_manager.load_df_from_db(table_name=table_name)

        self.logger.info(f"需要获取 {len(dates_to_fetch)} 个交易日的数据")

        # 分批处理数据
        temp_batch = []
        total_rows = 0

        # 获取数据
        for trade_date in tqdm(dates_to_fetch):
            try:
                # 动态调用Tushare API
                api_method = getattr(self.pro, api_name)
                df = api_method(trade_date=trade_date)

                if not df.empty:
                    temp_batch.append(df)
                    total_rows += len(df)

                    # 当累积的数据量达到batch_size时,进行一次批量写入
                    if total_rows >= batch_size:
                        self._process_batch(temp_batch, table_name, force_update)
                        self.logger.info(f"已处理 {total_rows} 行数据")
                        # 重置临时批次
                        temp_batch = []
                        total_rows = 0
                else:
                    self.logger.info(f"日期 {trade_date} 无数据")
            except Exception as e:
                self.logger.error(f"获取 {trade_date} 的数据时出错: {e}")
                self.logger.debug(f"完整的错误追踪信息:\n{traceback.format_exc()}")

        # 处理剩余的数据
        if temp_batch:
            self._process_batch(temp_batch, table_name, force_update)
            self.logger.info(f"已处理剩余 {total_rows} 行数据")

        self.logger.info("数据获取与处理完成")
        return self.db_manager.load_df_from_db(table_name=table_name)

    def _process_batch(self, batch_dfs, table_name, force_update):
        """
        处理一批数据
        参数:
        batch_dfs (list): DataFrame列表
        table_name (str): 数据库表名称
        force_update (bool): 是否强制更新
        """
        if not batch_dfs:
            return

        # 合并批次中的所有DataFrame
        batch_df = pd.concat(batch_dfs, ignore_index=True)

        if force_update:
            # 强制更新模式:先删除当前批次涉及的日期数据,然后插入新数据
            current_dates = batch_df['trade_date'].unique().tolist()

            # 删除这些日期的现有数据
            self.db_manager.delete_existing_data_by_dates(table_name, current_dates)

            # 插入新数据
            self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')
        else:
            # 普通追加模式
            self.db_manager.save_df_to_db(batch_df, table_name=table_name, if_exists='append')