import pandas as pd
import tushare as ts
import numpy as np
import os
import multiprocessing
from multiprocessing import Pool
from functools import partial
from sqlalchemy import create_engine, text
from datetime import datetime, timedelta

from utils import load_config


def create_engine_from_config(config):
    mysql = config['mysql']
    connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
    return create_engine(connection_string)


def create_metadata_table(engine):
    """创建股票元数据表"""
    # 从sql文件读取创建表的SQL
    with open('sql/meta.sql', 'r', encoding='utf-8') as f:
        create_table_sql = f.read()

    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print("股票元数据表创建成功")


def create_stock_table(engine, ts_code):
    """为指定股票创建数据表"""
    # 确保sql目录存在
    os.makedirs('sql', exist_ok=True)

    # 读取模板SQL
    with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as f:
        template_sql = f.read()

    # 替换模板中的symbol
    create_table_sql = template_sql.format(symbol=ts_code.split('.')[0])

    # 执行创建表的SQL
    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()

    print(f"股票({ts_code})数据表创建成功")


def download_stock_metadata(pro):
    """下载股票基础信息"""
    df = pro.stock_basic(exchange='', list_status='L',
                         fields='ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type')
    # 添加状态字段
    df['status'] = 0  # 未初始化
    df['remark'] = '自动导入'
    df['last_full_update'] = None
    df['last_incremental_update'] = None
    df['data_start_date'] = None
    df['data_end_date'] = None
    df['record_count'] = 0
    df['latest_price'] = None
    df['latest_date'] = None
    df['latest_pe_ttm'] = None
    df['latest_pb'] = None
    df['latest_total_mv'] = None

    return df


def save_metadata(df, engine):
    """保存元数据到数据库,只更新指定字段发生变化的记录,处理日期格式问题"""
    # 需要监控变化的基础元数据字段
    metadata_fields = [
        'ts_code', 'symbol', 'name', 'area', 'industry', 'fullname',
        'enname', 'cnspell', 'market', 'exchange', 'curr_type',
        'list_status', 'list_date', 'delist_date', 'is_hs',
        'act_name', 'act_ent_type'
    ]

    # 日期字段列表
    date_fields = ['list_date', 'delist_date']

    # 从数据库中读取现有的元数据
    try:
        with engine.connect() as conn:
            # 只查询我们关心的字段
            query = f"SELECT {', '.join(metadata_fields)} FROM stock_metadata"
            existing_df = pd.read_sql(query, conn)
    except Exception as e:
        print(f"读取现有元数据时出错: {str(e)}")
        existing_df = pd.DataFrame()  # 表不存在时创建空DataFrame

    if existing_df.empty:
        # 如果表是空的,直接插入所有数据
        result = df.to_sql('stock_metadata', engine, index=False,
                           if_exists='append', chunksize=1000)
        print(f"成功保存{result}条股票元数据(新建)")
        return result
    else:
        # 数据预处理:处理日期格式
        df_processed = df.copy()
        existing_processed = existing_df.copy()

        # 标准化日期格式进行比较
        def normalize_date(date_str):
            if pd.isna(date_str) or date_str == '' or date_str == 'None' or date_str is None:
                return ''

            # 去除所有非数字字符
            date_str = str(date_str).strip()
            digits_only = ''.join(c for c in date_str if c.isdigit())

            # 如果是8位数字形式的日期,返回标准格式
            if len(digits_only) == 8:
                return digits_only
            return ''

        # 对两个DataFrame应用日期标准化
        for field in date_fields:
            if field in df_processed.columns and field in existing_processed.columns:
                df_processed[field] = df_processed[field].apply(normalize_date)
                existing_processed[field] = existing_processed[field].apply(normalize_date)

        # 对其他非日期字段进行标准化处理
        for col in [f for f in metadata_fields if f not in date_fields and f != 'ts_code']:
            if col in df_processed.columns and col in existing_processed.columns:
                # 将两个DataFrame的相同列转换为相同的数据类型(字符串类型)
                df_processed[col] = df_processed[col].astype(str)
                existing_processed[col] = existing_processed[col].astype(str)

                # 处理NaN值
                df_processed[col] = df_processed[col].fillna('').str.strip()
                existing_processed[col] = existing_processed[col].fillna('').str.strip()

        # 找出新增的记录
        existing_ts_codes = set(existing_processed['ts_code'])
        new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()

        # 找出需要更新的记录
        changed_records = []
        unchanged_count = 0

        # 对于每个已存在的ts_code,检查是否有变化
        for ts_code in df_processed[df_processed['ts_code'].isin(existing_ts_codes)]['ts_code']:
            # 获取新旧记录(已经标准化后的)
            new_record = df_processed[df_processed['ts_code'] == ts_code].iloc[0]
            old_record = existing_processed[existing_processed['ts_code'] == ts_code].iloc[0]

            # 同时取原始记录用于更新操作
            original_record = df[df['ts_code'] == ts_code].iloc[0]

            # 检查是否有变化
            has_change = False
            changed_fields = []

            for field in metadata_fields:
                if field == 'ts_code':  # 跳过主键
                    continue

                new_val = new_record[field]
                old_val = old_record[field]

                if new_val != old_val:
                    has_change = True
                    changed_fields.append(field)
                    print(f"发现变化 - {ts_code} 的 {field}: '{old_val}' -> '{new_val}'")

            if has_change:
                changed_records.append({
                    'record': original_record,  # 使用原始记录,保持原始格式
                    'changed_fields': changed_fields
                })
            else:
                unchanged_count += 1

        # 插入新记录
        new_count = 0
        if not new_records.empty:
            new_count = new_records.to_sql('stock_metadata', engine, index=False,
                                           if_exists='append', chunksize=1000)

        # 更新变化的记录
        updated_count = 0
        for change_info in changed_records:
            record = change_info['record']
            fields = change_info['changed_fields']

            # 构建更新语句,只更新变化的字段
            fields_to_update = []
            params = {'ts_code': record['ts_code']}

            for field in fields:
                fields_to_update.append(f"{field} = :{field}")
                params[field] = record[field]

            if fields_to_update:
                update_stmt = text(f"""
                    UPDATE stock_metadata 
                    SET {', '.join(fields_to_update)}
                    WHERE ts_code = :ts_code
                """)

                with engine.connect() as conn:
                    conn.execute(update_stmt, params)
                    updated_count += 1
                    conn.commit()

        print(f"元数据更新统计:")
        print(f"  • 新增记录: {new_count}条")
        print(f"  • 更新记录: {updated_count}条")
        print(f"  • 无变化记录: {unchanged_count}条")
        print(f"  • 总处理记录: {new_count + updated_count + unchanged_count}条")

        return new_count + updated_count


def download_stock_data(pro, ts_code, start_date, end_date):
    """下载股票的所有类型数据并合并"""
    # 下载daily价格数据
    daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
    if daily_df.empty:
        print(f"警告:{ts_code}没有找到daily数据")
        return pd.DataFrame()

    # 下载daily_basic数据
    daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)

    # 下载moneyflow数据
    moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)

    # 确保每个DataFrame都有trade_date列作为合并键
    if 'trade_date' not in daily_df.columns:
        print(f"错误:{ts_code}的daily数据缺少trade_date列")
        return pd.DataFrame()

    # 为方便处理,确保所有日期列是字符串类型
    daily_df['trade_date'] = daily_df['trade_date'].astype(str)

    # 通过merge而不是join合并数据,这样可以更好地控制列
    result_df = daily_df

    # 合并daily_basic数据
    if not daily_basic_df.empty:
        daily_basic_df['trade_date'] = daily_basic_df['trade_date'].astype(str)

        # 识别重叠的列(除了ts_code和trade_date)
        overlap_cols = list(set(result_df.columns) & set(daily_basic_df.columns) - {'ts_code', 'trade_date'})

        # 从daily_basic中排除这些重叠列
        daily_basic_df_filtered = daily_basic_df.drop(columns=overlap_cols)

        # 合并数据
        result_df = pd.merge(
            result_df,
            daily_basic_df_filtered,
            on=['ts_code', 'trade_date'],
            how='left'
        )

    # 合并moneyflow数据
    if not moneyflow_df.empty:
        moneyflow_df['trade_date'] = moneyflow_df['trade_date'].astype(str)

        # 识别重叠的列(除了ts_code和trade_date)
        overlap_cols = list(set(result_df.columns) & set(moneyflow_df.columns) - {'ts_code', 'trade_date'})

        # 从moneyflow中排除这些重叠列
        moneyflow_df_filtered = moneyflow_df.drop(columns=overlap_cols)

        # 合并数据
        result_df = pd.merge(
            result_df,
            moneyflow_df_filtered,
            on=['ts_code', 'trade_date'],
            how='left'
        )

    # 将trade_date转换为datetime格式
    result_df['trade_date'] = pd.to_datetime(result_df['trade_date'])

    return result_df


def save_stock_data(df, engine, ts_code, if_exists='replace'):
    """
    保存股票数据到对应的表中

    Args:
        df: 股票数据DataFrame
        engine: 数据库引擎
        ts_code: 股票代码
        if_exists: 如果表已存在,处理方式。'replace': 替换现有表,'append': 附加到现有表

    Returns:
        int: 保存的记录数量
    """
    if df.empty:
        print(f"警告:{ts_code}没有数据可保存")
        return 0

    # 删除ts_code列,因为表中不需要
    if 'ts_code' in df.columns:
        df = df.drop(columns=['ts_code'])

    # 使用指定的模式保存数据
    try:
        symbol = ts_code.split('.')[0]
        result = df.to_sql(f'{symbol}', engine, index=False,
                           if_exists=if_exists, chunksize=1000)
        print(f"成功保存{result}条{ts_code}股票数据 (模式: {if_exists})")
        return result
    except Exception as e:
        print(f"保存{ts_code}数据时出错: {str(e)}")
        return 0


def split_list_into_chunks(data_list, num_chunks):
    """
    将列表分割成指定数量的块

    Args:
        data_list: 要分割的列表
        num_chunks: 块的数量

    Returns:
        list: 包含多个子列表的列表
    """
    avg = len(data_list) / float(num_chunks)
    chunks = []
    last = 0.0

    while last < len(data_list):
        chunks.append(data_list[int(last):int(last + avg)])
        last += avg

    return chunks


def process_stock_batch(batch, config, update_mode, start_date, end_date, update_time):
    """
    处理一批股票的工作函数(在单独进程中运行)

    Args:
        batch: 要处理的股票列表
        config: 配置信息
        update_mode: 更新模式 'full' 或 'incremental'
        start_date: 开始日期
        end_date: 结束日期
        update_time: 当前更新时间戳

    Returns:
        dict: 包含处理结果统计信息的字典
    """
    # 为当前进程创建新的数据库连接和API连接
    engine = create_engine_from_config(config)
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # 结果统计
    results = {
        'success_count': 0,
        'failed_count': 0,
        'no_new_data_count': 0,
        'skipped_count': 0,
        'processed_count': 0
    }

    # 获取当前进程ID
    process_id = multiprocessing.current_process().name
    total_stocks = len(batch)

    for i, stock_info in enumerate(batch):
        # 根据更新模式获取不同的参数
        if update_mode == 'full':
            ts_code = stock_info
            try:
                # 简化日志,只显示进度
                print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 全量更新 {ts_code}")

                # 检查股票表是否存在,不存在则创建
                symbol = ts_code.split('.')[0]
                with engine.connect() as conn:
                    table_exists_query = f"""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = '{config['mysql']['database']}'
                    AND table_name = '{symbol}'
                    """
                    table_exists = conn.execute(text(table_exists_query)).scalar() > 0

                if not table_exists:
                    create_stock_table(engine, ts_code)

                # 下载股票数据
                stock_data = download_stock_data(pro, ts_code, start_date, end_date)

                # 如果下载成功,保存数据并更新元数据
                if not stock_data.empty:
                    # 使用replace模式保存数据(全量替换)
                    records_saved = save_stock_data(stock_data, engine, ts_code)
                    # 更新元数据统计信息
                    update_metadata(engine, ts_code)
                    # 更新全量更新时间戳和状态为正常(1)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET last_full_update = '{update_time}',
                            status = 1
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['success_count'] += 1
                else:
                    # 更新状态为全量更新失败(2)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET status = 2
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['failed_count'] += 1

                # 防止API限流
                import time
                time.sleep(0.5)

            except Exception as e:
                # 发生异常时,更新状态为全量更新失败(2)
                try:
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET status = 2
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                except Exception:
                    pass  # 更新状态失败时,忽略继续执行

                results['failed_count'] += 1
                print(f"[{process_id}] 股票 {ts_code} 更新失败: {str(e)}")

        elif update_mode == 'incremental':
            ts_code = stock_info['ts_code']
            start_date = stock_info['start_date']

            try:
                # 简化日志,只显示进度
                print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 增量更新 {ts_code}")

                # 确保股票表存在
                symbol = ts_code.split('.')[0]
                with engine.connect() as conn:
                    table_exists_query = f"""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = '{config['mysql']['database']}'
                    AND table_name = '{symbol}'
                    """
                    table_exists = conn.execute(text(table_exists_query)).scalar() > 0

                if not table_exists:
                    results['skipped_count'] += 1
                    continue  # 表不存在则跳过

                # 下载增量数据
                new_data = download_stock_data(pro, ts_code, start_date, end_date)

                if not new_data.empty:
                    # 获取现有数据
                    try:
                        with engine.connect() as conn:
                            existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
                    except Exception as e:
                        results['skipped_count'] += 1
                        continue

                    if not existing_data.empty:
                        # 转换日期列为相同的格式以便合并
                        existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
                        new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])

                        # 删除可能重复的日期记录
                        existing_dates = set(existing_data['trade_date'])
                        new_data = new_data[~new_data['trade_date'].isin(existing_dates)]

                        if new_data.empty:
                            results['no_new_data_count'] += 1
                            continue

                        # 合并数据
                        combined_data = pd.concat([existing_data, new_data], ignore_index=True)

                        # 按日期排序
                        combined_data = combined_data.sort_values('trade_date')

                        # 保存合并后的数据
                        records_saved = save_stock_data(combined_data, engine, ts_code)
                    else:
                        # 如果表存在但为空,直接保存新数据
                        records_saved = save_stock_data(new_data, engine, ts_code)

                    # 更新元数据统计信息
                    update_metadata(engine, ts_code)

                    # 更新增量更新时间戳和状态为正常(1)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET last_incremental_update = '{update_time}',
                            status = 1
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()

                    results['success_count'] += 1
                else:
                    # 无新数据
                    results['no_new_data_count'] += 1

                # 防止API限流
                import time
                time.sleep(0.5)

            except Exception as e:
                # 发生异常时,更新状态为增量更新失败(3)
                try:
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET status = 3
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                except Exception:
                    pass  # 更新状态失败时,忽略继续执行

                results['failed_count'] += 1
                print(f"[{process_id}] 股票 {ts_code} 增量更新失败: {str(e)}")

        results['processed_count'] += 1

    return results


def perform_full_update(start_year=2020, processes=4, resume=False):
    """
    执行全量更新:从指定年份开始到今天的所有数据

    Args:
        start_year: 开始年份,默认为2020年
        processes: 进程数量,默认为4
        resume: 是否只更新之前全量更新失败的股票
    """
    update_type = "续传全量更新" if resume else "全量更新"
    print(f"开始执行{update_type} (从{start_year}年至今),使用 {processes} 个进程...")
    start_time = datetime.now()

    # 加载配置
    config = load_config()

    # 创建数据库引擎
    engine = create_engine_from_config(config)

    # 设置Tushare API
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # 设置数据获取的日期范围
    end_date = datetime.now().strftime('%Y%m%d')
    start_date = f"{start_year}0101"  # 从指定年份的1月1日开始
    print(f"数据范围: {start_date} 至 {end_date}")

    # 从元数据表获取股票代码
    with engine.connect() as conn:
        # 根据resume参数决定查询条件
        if resume:
            # 只查询全量更新失败的股票
            query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L' AND status = 2
            """
        else:
            # 查询所有已上市的股票
            query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L'
            """

        result = conn.execute(text(query))
        stock_codes = [row[0] for row in result]

    if not stock_codes:
        if resume:
            print("没有找到需要续传的股票")
        else:
            print("没有找到需要更新的股票,请先确保元数据表已初始化")
        return

    print(f"共找到 {len(stock_codes)} 只股票需要{update_type}")

    # 当前更新时间
    update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # 将股票列表分成多个批次
    batches = split_list_into_chunks(stock_codes, processes)

    # 创建进程池
    with Pool(processes=processes) as pool:
        # 使用partial函数固定部分参数
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='full',
            start_date=start_date,
            end_date=end_date,
            update_time=update_time
        )

        # 启动多进程处理并收集结果
        results = pool.map(process_func, batches)

    # 汇总所有进程的结果
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)

    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # 转换为分钟

    print(f"\n{update_type}完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {len(stock_codes)} 只股票")


def update_metadata(engine, ts_code):
    """
    更新股票元数据中的统计信息,从股票表中直接获取数据
    Args:
        engine: 数据库连接引擎
        ts_code: 股票代码(带后缀,如000001.SZ)
    Returns:
        bool: 更新成功返回True,否则返回False
    """
    try:
        # 提取股票代码(不含后缀)
        symbol = ts_code.split('.')[0]
        # 查询股票数据统计信息
        stats_query = f"""
        SELECT
            MIN(trade_date) as min_date,
            MAX(trade_date) as max_date,
            COUNT(*) as record_count
        FROM `{symbol}`
        """
        # 查询最新交易日数据
        latest_data_query = f"""
        SELECT
            close,
            pe_ttm,
            pb,
            total_mv
        FROM `{symbol}`
        WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
        LIMIT 1
        """
        # 执行查询
        with engine.connect() as conn:
            # 获取统计信息
            stats_result = conn.execute(text(stats_query)).fetchone()
            if not stats_result:
                # 更新状态为异常(2)
                update_empty_sql = f"""
                UPDATE stock_metadata
                SET last_full_update = NOW(),
                    record_count = 0,
                    status = 2
                WHERE ts_code = '{ts_code}'
                """
                conn.execute(text(update_empty_sql))
                conn.commit()
                return False

            # 处理日期字段
            try:
                data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
                data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
            except AttributeError:
                # 处理日期可能是字符串而不是datetime对象的情况
                data_start_date = stats_result[0] if stats_result[0] else 'NULL'
                data_end_date = stats_result[1] if stats_result[1] else 'NULL'

            record_count = stats_result[2] or 0

            # 获取最新交易日数据
            latest_data_result = conn.execute(text(latest_data_query)).fetchone()

            # 设置默认值并处理NULL
            default_value = 'NULL'
            latest_price = default_value
            latest_pe_ttm = default_value
            latest_pb = default_value
            latest_total_mv = default_value

            # 如果有最新数据,则更新相应字段
            if latest_data_result:
                latest_price = str(latest_data_result[0]) if latest_data_result[0] is not None else default_value
                latest_pe_ttm = str(latest_data_result[1]) if latest_data_result[1] is not None else default_value
                latest_pb = str(latest_data_result[2]) if latest_data_result[2] is not None else default_value
                latest_total_mv = str(latest_data_result[3]) if latest_data_result[3] is not None else default_value

            # 构建更新SQL
            update_sql = f"""
            UPDATE stock_metadata
            SET
                data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
                data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
                record_count = {record_count},
                status = CASE WHEN {record_count} > 0 THEN 1 ELSE 2 END,
                latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END,
                latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
                latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END,
                latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END,
                latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
            WHERE ts_code = '{ts_code}'
            """
            # 执行更新
            conn.execute(text(update_sql))
            conn.commit()

        return True
    except Exception as e:
        # 尝试将状态更新为元数据更新失败(-1)
        try:
            with engine.connect() as conn:
                error_update_sql = f"""
                UPDATE stock_metadata
                SET status = -1
                WHERE ts_code = '{ts_code}'
                """
                conn.execute(text(error_update_sql))
                conn.commit()
        except Exception:
            pass  # 更新状态失败时,忽略继续执行

        print(f"更新({ts_code})元数据时出错: {str(e)}")
        return False


def perform_incremental_update(processes=4):
    """
    执行增量更新:从上次增量或全量更新日期到今天的数据

    Args:
        processes: 进程数量,默认为4
    """
    print(f"开始执行增量更新,使用 {processes} 个进程...")
    start_time = datetime.now()

    # 加载配置
    config = load_config()

    # 创建数据库引擎
    engine = create_engine_from_config(config)

    # 设置Tushare API
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # 当前日期作为结束日期
    end_date = datetime.now().strftime('%Y%m%d')

    # 获取需要更新的股票及其上次更新日期
    stocks_to_update = []
    with engine.connect() as conn:
        query = """
        SELECT
            ts_code,
            last_incremental_update,
            last_full_update
        FROM
            stock_metadata
        WHERE
            list_status = 'L'
        """
        result = conn.execute(text(query))
        for row in result:
            ts_code = row[0]
            last_incr_update = row[1]  # 上次增量更新日期
            last_full_update = row[2]  # 上次全量更新日期

            # 确定起始日期:使用最近的更新日期
            latest_update = None

            # 检查增量更新日期
            if last_incr_update:
                latest_update = last_incr_update

            # 检查全量更新日期
            if last_full_update:
                if not latest_update or last_full_update > latest_update:
                    latest_update = last_full_update

            # 如果有更新日期,使用其后一天作为起始日期
            if latest_update:
                start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
            else:
                # 如果没有任何更新记录,默认从30天前开始
                start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')

            # 如果起始日期大于等于结束日期,则跳过此股票
            if start_date >= end_date:
                continue

            stocks_to_update.append({
                'ts_code': ts_code,
                'start_date': start_date
            })

    if not stocks_to_update:
        print("没有找到需要增量更新的股票")
        return

    print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")

    # 当前更新时间
    update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    # 将股票列表分成多个批次
    batches = split_list_into_chunks(stocks_to_update, processes)

    # 创建进程池
    with Pool(processes=processes) as pool:
        # 使用partial函数固定部分参数
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='incremental',
            start_date=None,  # 增量模式下使用每个股票自己的起始日期
            end_date=end_date,
            update_time=update_time
        )

        # 启动多进程处理并收集结果
        results = pool.map(process_func, batches)

    # 汇总所有进程的结果
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)
    no_new_data_count = sum(r['no_new_data_count'] for r in results)
    skipped_count = sum(r['skipped_count'] for r in results)

    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # 转换为分钟

    print("\n增量更新完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"无新数据: {no_new_data_count} 只股票")
    print(f"跳过股票: {skipped_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {len(stocks_to_update)} 只股票")


def main():
    """主函数,允许用户选择更新方式"""
    import argparse

    # 命令行参数解析
    parser = argparse.ArgumentParser(description='股票数据更新工具')
    parser.add_argument('--init', action='store_true', help='初始化元数据表')
    parser.add_argument('--mode', choices=['full', 'incremental', 'both'],
                        default='full', help='更新模式: full=全量更新, incremental=增量更新, both=两者都执行')
    parser.add_argument('--year', type=int, default=2020, help='全量更新的起始年份')
    parser.add_argument('--processes', type=int, default=8, help='使用的进程数量')
    parser.add_argument('--resume', action='store_true', help='仅更新全量更新失败的股票,即status=2部分')

    args = parser.parse_args()

    # 如果需要初始化元数据
    if args.init:
        # 加载配置
        config = load_config()
        # 创建数据库引擎
        engine = create_engine_from_config(config)
        # 设置Tushare API
        ts.set_token(config['tushare_token'])
        pro = ts.pro_api()

        # 创建元数据表
        create_metadata_table(engine)
        # 下载并保存股票元数据
        metadata_df = download_stock_metadata(pro)
        save_metadata(metadata_df, engine)
        print("元数据初始化完成")

    # 根据选择的模式执行更新
    if args.mode == 'full' or args.mode == 'both':
        perform_full_update(args.year, args.processes, args.resume)

    if args.mode == 'incremental' or args.mode == 'both':
        perform_incremental_update(args.processes)


if __name__ == '__main__':
    main()