import os
from datetime import datetime, timedelta
from urllib.parse import quote_plus

import numpy as np
import pandas as pd
import tushare as ts
from sqlalchemy import create_engine, text

from utils import load_config

# Columns of stock_metadata that mirror tushare's stock_basic output. The same
# list is requested from tushare and compared in save_metadata to decide
# whether an existing row needs updating.
_METADATA_FIELDS = (
    'ts_code', 'symbol', 'name', 'area', 'industry', 'fullname', 'enname',
    'cnspell', 'market', 'exchange', 'curr_type', 'list_status', 'list_date',
    'delist_date', 'is_hs', 'act_name', 'act_ent_type',
)

# Metadata fields holding dates; they are normalised to 'YYYYMMDD' before
# comparison because the DB and tushare may represent them differently.
_DATE_FIELDS = ('list_date', 'delist_date')


def _rows_written(result, frame):
    """Return the row count reported by ``DataFrame.to_sql``.

    ``to_sql`` may return None (pandas < 1.4, or custom insertion methods);
    fall back to ``len(frame)`` so callers always get a number they can add
    up and print.
    """
    return len(frame) if result is None else result


def create_engine_from_config(config):
    """Build a SQLAlchemy engine for the MySQL database in ``config['mysql']``.

    Reads the keys user, password, host, port, database and charset. The
    user name and password are URL-escaped so credentials containing
    characters such as '@', ':' or '/' do not corrupt the connection URL.
    """
    mysql = config['mysql']
    user = quote_plus(str(mysql['user']))
    password = quote_plus(str(mysql['password']))
    connection_string = (
        f"mysql+pymysql://{user}:{password}@{mysql['host']}:{mysql['port']}"
        f"/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
    )
    return create_engine(connection_string)


def create_metadata_table(engine):
    """Create the stock metadata table by executing the DDL in sql/meta.sql."""
    with open('sql/meta.sql', 'r', encoding='utf-8') as f:
        create_table_sql = f.read()

    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print("股票元数据表创建成功")


def create_stock_table(engine, ts_code):
    """Create the per-stock data table from sql/stock_table_template.sql.

    The template is rendered with ``str.format(symbol=...)`` where symbol is
    the part of ``ts_code`` before the first '.', so any literal braces in
    the template must be doubled.
    """
    symbol = ts_code.split('.')[0]
    with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as f:
        create_table_sql = f.read().format(symbol=symbol)

    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print(f"股票({ts_code})数据表创建成功")


def download_stock_metadata(pro):
    """Download basic stock info from tushare and add bookkeeping columns.

    Args:
        pro: tushare pro API client (as returned by ``ts.pro_api()``).

    Returns:
        DataFrame with the ``stock_basic`` fields for ``list_status='L'``
        plus the bookkeeping columns stored in stock_metadata. ``status`` is
        0 (not yet initialised), ``record_count`` is 0 and the rest are None.
    """
    df = pro.stock_basic(exchange='', list_status='L', fields=','.join(_METADATA_FIELDS))

    df['status'] = 0  # 0 = not yet initialised
    df['remark'] = '自动导入'
    df['last_full_update'] = None
    df['last_incremental_update'] = None
    df['data_start_date'] = None
    df['data_end_date'] = None
    df['record_count'] = 0
    df['latest_price'] = None
    df['latest_date'] = None
    df['latest_pe_ttm'] = None
    df['latest_pb'] = None
    df['latest_total_mv'] = None
    return df


def _normalize_date(value):
    """Normalise a date-like value to an 8-digit 'YYYYMMDD' string.

    Missing values (None/NaN/NaT, '' or the string 'None') and anything that
    does not reduce to exactly 8 digits become ''. date/datetime objects are
    formatted directly so a DATETIME column ('YYYY-MM-DD HH:MM:SS') compares
    equal to tushare's 'YYYYMMDD' strings.
    """
    if pd.isna(value) or value in ('', 'None'):
        return ''
    if hasattr(value, 'strftime'):
        return value.strftime('%Y%m%d')
    digits_only = ''.join(c for c in str(value).strip() if c.isdigit())
    return digits_only if len(digits_only) == 8 else ''


def _normalize_text(series):
    """Map a column to stripped strings, with missing values mapped to ''.

    ``fillna`` must run before ``astype(str)``: after the cast, None and NaN
    have already become the strings 'None' / 'nan' and would compare unequal.
    """
    return series.fillna('').astype(str).str.strip()


def _to_db_value(value):
    """Convert a pandas/numpy scalar into a value the DB driver can bind.

    Missing values become None (SQL NULL) since MySQL cannot store NaN;
    numpy scalars are unwrapped to the equivalent Python object.
    """
    try:
        if pd.isna(value):
            return None
    except (TypeError, ValueError):
        pass
    if isinstance(value, np.generic):
        return value.item()
    return value


def save_metadata(df, engine):
    """Save stock metadata, inserting new stocks and updating changed ones.

    Only the fields in ``_METADATA_FIELDS`` are compared. Date fields are
    normalised to 'YYYYMMDD' and other fields to stripped strings before
    comparison, so representation differences between the DB and tushare do
    not count as changes. Only the fields that actually changed are updated,
    using the original (un-normalised) values from ``df``.

    If stock_metadata is empty or cannot be read, every row of ``df`` is
    appended.

    Returns:
        Number of inserted rows plus number of updated rows.
    """
    try:
        with engine.connect() as conn:
            query = f"SELECT {', '.join(_METADATA_FIELDS)} FROM stock_metadata"
            existing_df = pd.read_sql(query, conn)
    except Exception as e:
        print(f"读取现有元数据时出错: {str(e)}")
        existing_df = pd.DataFrame()  # treat an unreadable/missing table as empty

    if existing_df.empty:
        result = _rows_written(
            df.to_sql('stock_metadata', engine, index=False, if_exists='append', chunksize=1000),
            df,
        )
        print(f"成功保存{result}条股票元数据(新建)")
        return result

    # Normalise both sides identically so only real value changes are detected.
    df_processed = df.copy()
    existing_processed = existing_df.copy()
    for field in _DATE_FIELDS:
        if field in df_processed.columns and field in existing_processed.columns:
            df_processed[field] = df_processed[field].apply(_normalize_date)
            existing_processed[field] = existing_processed[field].apply(_normalize_date)
    for col in _METADATA_FIELDS:
        if col == 'ts_code' or col in _DATE_FIELDS:
            continue
        if col in df_processed.columns and col in existing_processed.columns:
            df_processed[col] = _normalize_text(df_processed[col])
            existing_processed[col] = _normalize_text(existing_processed[col])

    existing_ts_codes = set(existing_processed['ts_code'])
    new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()

    # Index each frame by ts_code once so per-stock lookups are O(1) rather
    # than a full boolean scan per stock; keep='first' matches taking the
    # first matching row.
    new_by_code = df_processed.drop_duplicates('ts_code').set_index('ts_code')
    old_by_code = existing_processed.drop_duplicates('ts_code').set_index('ts_code')
    raw_by_code = df.drop_duplicates('ts_code').set_index('ts_code')

    changed_records = []  # (ts_code, original row, changed field names)
    unchanged_count = 0
    existing_mask = df_processed['ts_code'].isin(existing_ts_codes)
    for ts_code in df_processed.loc[existing_mask, 'ts_code']:
        new_record = new_by_code.loc[ts_code]
        old_record = old_by_code.loc[ts_code]

        changed_fields = []
        for field in _METADATA_FIELDS:
            if field == 'ts_code':  # primary key, never updated
                continue
            new_val = new_record[field]
            old_val = old_record[field]
            if new_val != old_val:
                changed_fields.append(field)
                print(f"发现变化 - {ts_code} 的 {field}: '{old_val}' -> '{new_val}'")

        if changed_fields:
            changed_records.append((ts_code, raw_by_code.loc[ts_code], changed_fields))
        else:
            unchanged_count += 1

    new_count = 0
    if not new_records.empty:
        new_count = _rows_written(
            new_records.to_sql('stock_metadata', engine, index=False, if_exists='append', chunksize=1000),
            new_records,
        )

    # Apply all updates in one transaction; column names come from
    # _METADATA_FIELDS (trusted), values are bound as parameters.
    updated_count = 0
    if changed_records:
        with engine.begin() as conn:
            for ts_code, record, fields in changed_records:
                assignments = ', '.join(f"{field} = :{field}" for field in fields)
                params = {field: _to_db_value(record[field]) for field in fields}
                params['ts_code'] = ts_code
                update_stmt = text(f"""
                    UPDATE stock_metadata
                    SET {assignments}
                    WHERE ts_code = :ts_code
                """)
                conn.execute(update_stmt, params)
                updated_count += 1

    print(f"元数据更新统计:")
    print(f" • 新增记录: {new_count}条")
    print(f" • 更新记录: {updated_count}条")
    print(f" • 无变化记录: {unchanged_count}条")
    print(f" • 总处理记录: {new_count + updated_count + unchanged_count}条")

    return new_count + updated_count


def _merge_on_trade_date(left, right):
    """Left-join ``right`` onto ``left`` on (ts_code, trade_date).

    ``right``'s trade_date is cast to str to match ``left``. Columns that
    ``right`` shares with ``left`` (other than the join keys) are dropped
    from ``right`` first, so ``left``'s values win and no _x/_y suffixed
    duplicates appear. An empty ``right`` returns ``left`` unchanged.
    """
    if right.empty:
        return left
    right = right.assign(trade_date=right['trade_date'].astype(str))
    overlap_cols = (set(left.columns) & set(right.columns)) - {'ts_code', 'trade_date'}
    return pd.merge(
        left,
        right.drop(columns=list(overlap_cols)),
        on=['ts_code', 'trade_date'],
        how='left',
    )


def download_stock_data(pro, ts_code, start_date, end_date):
    """Download daily, daily_basic and moneyflow data for one stock and merge them.

    Args:
        pro: tushare pro API client.
        ts_code: tushare stock code, e.g. '000001.SZ'.
        start_date, end_date: 'YYYYMMDD' strings passed straight to tushare.

    Returns:
        The daily rows left-joined with daily_basic and moneyflow columns,
        with ``trade_date`` converted to datetime. On overlapping columns the
        daily (then daily_basic) values are kept. Returns an empty DataFrame
        when there is no daily data or it lacks a trade_date column.
    """
    daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
    if daily_df.empty:
        print(f"警告:{ts_code}没有找到daily数据")
        return pd.DataFrame()
    # Check before issuing the other two API calls, whose data is useless without it.
    if 'trade_date' not in daily_df.columns:
        print(f"错误:{ts_code}的daily数据缺少trade_date列")
        return pd.DataFrame()

    daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
    moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)

    # Join on string dates so all three sources share the same key type.
    result_df = daily_df.assign(trade_date=daily_df['trade_date'].astype(str))
    result_df = _merge_on_trade_date(result_df, daily_basic_df)
    result_df = _merge_on_trade_date(result_df, moneyflow_df)

    result_df['trade_date'] = pd.to_datetime(result_df['trade_date'])
    return result_df


def save_stock_data(df, engine, ts_code):
    """Write a stock's data to its own table, named after the code's symbol part.

    The ts_code column is dropped because the table is per-stock. Errors
    during the write are printed and reported by returning 0.

    Returns:
        Number of rows written, or 0 when ``df`` is empty or the write failed.
    """
    if df.empty:
        print(f"警告:{ts_code}没有数据可保存")
        return 0

    if 'ts_code' in df.columns:
        df = df.drop(columns=['ts_code'])

    symbol = ts_code.split('.')[0]
    try:
        # 'replace' keeps only the latest download in the table.
        # NOTE(review): 'replace' drops and recreates the table with
        # pandas-inferred types, discarding the schema created from
        # sql/stock_table_template.sql — confirm this is intended.
        result = _rows_written(
            df.to_sql(symbol, engine, index=False, if_exists='replace', chunksize=1000),
            df,
        )
        print(f"成功保存{result}条{ts_code}股票数据")
        return result
    except Exception as e:
        print(f"保存{ts_code}数据时出错: {str(e)}")
        return 0


def update_metadata(engine, ts_code):
    """Refresh a stock's statistics in stock_metadata from its data table.

    Reads the date range and row count of the stock's table plus close,
    pe_ttm, pb and total_mv from the latest trade_date. The values are
    rendered into sql/update_metadata.sql via ``str.format``, with 'NULL'
    for missing values. They come from our own database, not external input.
    Does nothing (besides a warning) when the table has no rows.
    """
    symbol = ts_code.split('.')[0]

    stats_query = f"""
        SELECT
            MIN(trade_date) as min_date,
            MAX(trade_date) as max_date,
            COUNT(*) as record_count
        FROM `{symbol}`
    """
    latest_data_query = f"""
        SELECT close, pe_ttm, pb, total_mv
        FROM `{symbol}`
        WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
        LIMIT 1
    """

    with engine.connect() as conn:
        stats_result = conn.execute(text(stats_query)).fetchone()
        # An aggregate query always yields one row, so check COUNT(*) as
        # well; otherwise an empty table would overwrite the stats with NULLs.
        if not stats_result or not stats_result[2]:
            print(f"警告:{ts_code}没有数据,无法更新元数据")
            return

        min_date, max_date, record_count = stats_result
        data_start_date = min_date.strftime('%Y-%m-%d') if min_date else 'NULL'
        data_end_date = max_date.strftime('%Y-%m-%d') if max_date else 'NULL'

        latest_row = conn.execute(text(latest_data_query)).fetchone()
        latest_values = tuple(latest_row) if latest_row else (None, None, None, None)
        latest_price, latest_pe_ttm, latest_pb, latest_total_mv = (
            'NULL' if value is None else f"{value}" for value in latest_values
        )

        with open('sql/update_metadata.sql', 'r', encoding='utf-8') as f:
            update_sql_template = f.read()

        update_sql = update_sql_template.format(
            data_start_date=data_start_date,
            data_end_date=data_end_date,
            record_count=record_count,
            latest_price=latest_price,
            latest_pe_ttm=latest_pe_ttm,
            latest_pb=latest_pb,
            latest_total_mv=latest_total_mv,
            ts_code=ts_code,
        )

        conn.execute(text(update_sql))
        conn.commit()
        print(f"已更新({ts_code})的元数据信息")


def process_stock(engine, pro, ts_code, start_date, end_date):
    """Create, download, save and summarise one stock's data.

    Metadata is refreshed only when the data was actually written, so a
    failed save does not record statistics of a stale or empty table.
    """
    create_stock_table(engine, ts_code)

    stock_data = download_stock_data(pro, ts_code, start_date, end_date)
    if stock_data.empty:
        print(f"警告:({ts_code})没有获取到数据")
        return

    if save_stock_data(stock_data, engine, ts_code):
        update_metadata(engine, ts_code)


def main():
    """Script entry point: sync stock metadata, then load one example stock."""
    config = load_config()
    engine = create_engine_from_config(config)

    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    create_metadata_table(engine)
    metadata_df = download_stock_metadata(pro)
    save_metadata(metadata_df, engine)

    # Example: Ping An Bank.
    ts_code = '000001.SZ'

    # Date range: roughly the last 15 years (15 * 365 days) up to today.
    # Read the clock once so both bounds share the same reference time.
    now = datetime.now()
    end_date = now.strftime('%Y%m%d')
    start_date = (now - timedelta(days=15 * 365)).strftime('%Y%m%d')

    process_stock(engine, pro, ts_code, start_date, end_date)


if __name__ == '__main__':
    main()