1055 lines
38 KiB
Python
1055 lines
38 KiB
Python
import argparse
|
||
|
||
import pandas as pd
|
||
import tushare as ts
|
||
import numpy as np
|
||
import os
|
||
import multiprocessing
|
||
from multiprocessing import Pool
|
||
from functools import partial
|
||
from sqlalchemy import create_engine, text
|
||
from datetime import datetime, timedelta
|
||
|
||
from utils import load_config, create_engine_from_config
|
||
|
||
|
||
def create_metadata_table(engine):
|
||
"""创建股票元数据表"""
|
||
# 从sql文件读取创建表的SQL
|
||
with open('sql/meta.sql', 'r', encoding='utf-8') as f:
|
||
create_table_sql = f.read()
|
||
|
||
with engine.connect() as conn:
|
||
conn.execute(text(create_table_sql))
|
||
conn.commit()
|
||
print("股票元数据表创建成功")
|
||
|
||
|
||
def create_stock_table(engine, ts_code):
|
||
"""为指定股票创建数据表"""
|
||
# 确保sql目录存在
|
||
os.makedirs('sql', exist_ok=True)
|
||
|
||
# 读取模板SQL
|
||
with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as f:
|
||
template_sql = f.read()
|
||
|
||
# 替换模板中的symbol
|
||
create_table_sql = template_sql.format(symbol=ts_code.split('.')[0])
|
||
|
||
# 执行创建表的SQL
|
||
with engine.connect() as conn:
|
||
conn.execute(text(create_table_sql))
|
||
conn.commit()
|
||
|
||
print(f"股票({ts_code})数据表创建成功")
|
||
|
||
|
||
def download_stock_metadata(pro):
|
||
"""下载股票基础信息"""
|
||
df = pro.stock_basic(exchange='', list_status='L',
|
||
fields='ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type')
|
||
# 添加状态字段
|
||
df['status'] = 0 # 未初始化
|
||
df['remark'] = '自动导入'
|
||
df['last_full_update'] = None
|
||
df['last_incremental_update'] = None
|
||
df['data_start_date'] = None
|
||
df['data_end_date'] = None
|
||
df['record_count'] = 0
|
||
df['latest_price'] = None
|
||
df['latest_date'] = None
|
||
df['latest_pe_ttm'] = None
|
||
df['latest_pb'] = None
|
||
df['latest_total_mv'] = None
|
||
|
||
return df
|
||
|
||
|
||
def save_metadata(df, engine):
|
||
"""保存元数据到数据库,只更新指定字段发生变化的记录,处理日期格式问题"""
|
||
# 需要监控变化的基础元数据字段
|
||
metadata_fields = [
|
||
'ts_code', 'symbol', 'name', 'area', 'industry', 'fullname',
|
||
'enname', 'cnspell', 'market', 'exchange', 'curr_type',
|
||
'list_status', 'list_date', 'delist_date', 'is_hs',
|
||
'act_name', 'act_ent_type'
|
||
]
|
||
|
||
# 日期字段列表
|
||
date_fields = ['list_date', 'delist_date']
|
||
|
||
# 从数据库中读取现有的元数据
|
||
try:
|
||
with engine.connect() as conn:
|
||
# 只查询我们关心的字段
|
||
query = f"SELECT {', '.join(metadata_fields)} FROM stock_metadata"
|
||
existing_df = pd.read_sql(query, conn)
|
||
except Exception as e:
|
||
print(f"读取现有元数据时出错: {str(e)}")
|
||
existing_df = pd.DataFrame() # 表不存在时创建空DataFrame
|
||
|
||
if existing_df.empty:
|
||
# 如果表是空的,直接插入所有数据
|
||
result = df.to_sql('stock_metadata', engine, index=False,
|
||
if_exists='append', chunksize=1000)
|
||
print(f"成功保存{result}条股票元数据(新建)")
|
||
return result
|
||
else:
|
||
# 数据预处理:处理日期格式
|
||
df_processed = df.copy()
|
||
existing_processed = existing_df.copy()
|
||
|
||
# 标准化日期格式进行比较
|
||
def normalize_date(date_str):
|
||
if pd.isna(date_str) or date_str == '' or date_str == 'None' or date_str is None:
|
||
return ''
|
||
|
||
# 去除所有非数字字符
|
||
date_str = str(date_str).strip()
|
||
digits_only = ''.join(c for c in date_str if c.isdigit())
|
||
|
||
# 如果是8位数字形式的日期,返回标准格式
|
||
if len(digits_only) == 8:
|
||
return digits_only
|
||
return ''
|
||
|
||
# 对两个DataFrame应用日期标准化
|
||
for field in date_fields:
|
||
if field in df_processed.columns and field in existing_processed.columns:
|
||
df_processed[field] = df_processed[field].apply(normalize_date)
|
||
existing_processed[field] = existing_processed[field].apply(normalize_date)
|
||
|
||
# 对其他非日期字段进行标准化处理
|
||
for col in [f for f in metadata_fields if f not in date_fields and f != 'ts_code']:
|
||
if col in df_processed.columns and col in existing_processed.columns:
|
||
# 将两个DataFrame的相同列转换为相同的数据类型(字符串类型)
|
||
df_processed[col] = df_processed[col].astype(str)
|
||
existing_processed[col] = existing_processed[col].astype(str)
|
||
|
||
# 处理NaN值
|
||
df_processed[col] = df_processed[col].fillna('').str.strip()
|
||
existing_processed[col] = existing_processed[col].fillna('').str.strip()
|
||
|
||
# 找出新增的记录
|
||
existing_ts_codes = set(existing_processed['ts_code'])
|
||
new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()
|
||
|
||
# 找出需要更新的记录
|
||
changed_records = []
|
||
unchanged_count = 0
|
||
|
||
# 对于每个已存在的ts_code,检查是否有变化
|
||
for ts_code in df_processed[df_processed['ts_code'].isin(existing_ts_codes)]['ts_code']:
|
||
# 获取新旧记录(已经标准化后的)
|
||
new_record = df_processed[df_processed['ts_code'] == ts_code].iloc[0]
|
||
old_record = existing_processed[existing_processed['ts_code'] == ts_code].iloc[0]
|
||
|
||
# 同时取原始记录用于更新操作
|
||
original_record = df[df['ts_code'] == ts_code].iloc[0]
|
||
|
||
# 检查是否有变化
|
||
has_change = False
|
||
changed_fields = []
|
||
|
||
for field in metadata_fields:
|
||
if field == 'ts_code': # 跳过主键
|
||
continue
|
||
|
||
new_val = new_record[field]
|
||
old_val = old_record[field]
|
||
|
||
if new_val != old_val:
|
||
has_change = True
|
||
changed_fields.append(field)
|
||
print(f"发现变化 - {ts_code} 的 {field}: '{old_val}' -> '{new_val}'")
|
||
|
||
if has_change:
|
||
changed_records.append({
|
||
'record': original_record, # 使用原始记录,保持原始格式
|
||
'changed_fields': changed_fields
|
||
})
|
||
else:
|
||
unchanged_count += 1
|
||
|
||
# 插入新记录
|
||
new_count = 0
|
||
if not new_records.empty:
|
||
new_count = new_records.to_sql('stock_metadata', engine, index=False,
|
||
if_exists='append', chunksize=1000)
|
||
|
||
# 更新变化的记录
|
||
updated_count = 0
|
||
for change_info in changed_records:
|
||
record = change_info['record']
|
||
fields = change_info['changed_fields']
|
||
|
||
# 构建更新语句,只更新变化的字段
|
||
fields_to_update = []
|
||
params = {'ts_code': record['ts_code']}
|
||
|
||
for field in fields:
|
||
fields_to_update.append(f"{field} = :{field}")
|
||
params[field] = record[field]
|
||
|
||
if fields_to_update:
|
||
update_stmt = text(f"""
|
||
UPDATE stock_metadata
|
||
SET {', '.join(fields_to_update)}
|
||
WHERE ts_code = :ts_code
|
||
""")
|
||
|
||
with engine.connect() as conn:
|
||
conn.execute(update_stmt, params)
|
||
updated_count += 1
|
||
conn.commit()
|
||
|
||
print(f"元数据更新统计:")
|
||
print(f" • 新增记录: {new_count}条")
|
||
print(f" • 更新记录: {updated_count}条")
|
||
print(f" • 无变化记录: {unchanged_count}条")
|
||
print(f" • 总处理记录: {new_count + updated_count + unchanged_count}条")
|
||
|
||
return new_count + updated_count
|
||
|
||
|
||
def download_stock_data(pro, ts_code, start_date, end_date):
|
||
"""下载股票的所有类型数据并合并"""
|
||
# 下载daily价格数据
|
||
daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
|
||
|
||
# 下载daily_basic数据
|
||
daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
|
||
|
||
# 下载moneyflow数据
|
||
moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)
|
||
|
||
if daily_df.empty or daily_basic_df.empty or moneyflow_df.empty:
|
||
print(f"警告:{ts_code}daily数据不全,无法合并")
|
||
return pd.DataFrame()
|
||
|
||
# 确保每个DataFrame都有trade_date列作为合并键
|
||
if 'trade_date' not in daily_df.columns:
|
||
print(f"错误:{ts_code}的daily数据缺少trade_date列")
|
||
return pd.DataFrame()
|
||
|
||
# 为方便处理,确保所有日期列是字符串类型
|
||
daily_df['trade_date'] = daily_df['trade_date'].astype(str)
|
||
|
||
# 通过merge而不是join合并数据,这样可以更好地控制列
|
||
result_df = daily_df
|
||
|
||
# 合并daily_basic数据
|
||
if not daily_basic_df.empty:
|
||
daily_basic_df['trade_date'] = daily_basic_df['trade_date'].astype(str)
|
||
|
||
# 识别重叠的列(除了ts_code和trade_date)
|
||
overlap_cols = list(set(result_df.columns) & set(daily_basic_df.columns) - {'ts_code', 'trade_date'})
|
||
|
||
# 从daily_basic中排除这些重叠列
|
||
daily_basic_df_filtered = daily_basic_df.drop(columns=overlap_cols)
|
||
|
||
# 合并数据
|
||
result_df = pd.merge(
|
||
result_df,
|
||
daily_basic_df_filtered,
|
||
on=['ts_code', 'trade_date'],
|
||
how='left'
|
||
)
|
||
|
||
# 合并moneyflow数据
|
||
if not moneyflow_df.empty:
|
||
moneyflow_df['trade_date'] = moneyflow_df['trade_date'].astype(str)
|
||
|
||
# 识别重叠的列(除了ts_code和trade_date)
|
||
overlap_cols = list(set(result_df.columns) & set(moneyflow_df.columns) - {'ts_code', 'trade_date'})
|
||
|
||
# 从moneyflow中排除这些重叠列
|
||
moneyflow_df_filtered = moneyflow_df.drop(columns=overlap_cols)
|
||
|
||
# 合并数据
|
||
result_df = pd.merge(
|
||
result_df,
|
||
moneyflow_df_filtered,
|
||
on=['ts_code', 'trade_date'],
|
||
how='left'
|
||
)
|
||
|
||
# 将trade_date转换为datetime格式
|
||
result_df['trade_date'] = pd.to_datetime(result_df['trade_date'])
|
||
|
||
return result_df
|
||
|
||
|
||
def save_stock_data(df, engine, ts_code, if_exists='replace'):
|
||
"""
|
||
保存股票数据到对应的表中
|
||
|
||
Args:
|
||
df: 股票数据DataFrame
|
||
engine: 数据库引擎
|
||
ts_code: 股票代码
|
||
if_exists: 如果表已存在,处理方式。'replace': 替换现有表,'append': 附加到现有表
|
||
|
||
Returns:
|
||
int: 保存的记录数量
|
||
"""
|
||
if df.empty:
|
||
print(f"警告:{ts_code}没有数据可保存")
|
||
return 0
|
||
|
||
# 删除ts_code列,因为表中不需要
|
||
if 'ts_code' in df.columns:
|
||
df = df.drop(columns=['ts_code'])
|
||
|
||
# 使用指定的模式保存数据
|
||
try:
|
||
symbol = ts_code.split('.')[0]
|
||
result = df.to_sql(f'{symbol}', engine, index=False,
|
||
if_exists=if_exists, chunksize=1000)
|
||
print(f"成功保存{result}条{ts_code}股票数据 (模式: {if_exists})")
|
||
return result
|
||
except Exception as e:
|
||
print(f"保存{ts_code}数据时出错: {str(e)}")
|
||
return 0
|
||
|
||
|
||
def split_list_into_chunks(data_list, num_chunks):
|
||
"""
|
||
将列表分割成指定数量的块
|
||
|
||
Args:
|
||
data_list: 要分割的列表
|
||
num_chunks: 块的数量
|
||
|
||
Returns:
|
||
list: 包含多个子列表的列表
|
||
"""
|
||
avg = len(data_list) / float(num_chunks)
|
||
chunks = []
|
||
last = 0.0
|
||
|
||
while last < len(data_list):
|
||
chunks.append(data_list[int(last):int(last + avg)])
|
||
last += avg
|
||
|
||
return chunks
|
||
|
||
|
||
def process_stock_batch(batch, config, update_mode, start_date, end_date):
|
||
"""
|
||
处理一批股票的工作函数(在单独进程中运行)
|
||
|
||
Args:
|
||
batch: 要处理的股票列表
|
||
config: 配置信息
|
||
update_mode: 更新模式 'full' 或 'incremental'
|
||
start_date: 开始日期
|
||
end_date: 结束日期
|
||
|
||
Returns:
|
||
dict: 包含处理结果统计信息的字典
|
||
"""
|
||
# 为当前进程创建新的数据库连接和API连接
|
||
engine = create_engine_from_config(config)
|
||
ts.set_token(config['tushare_token'])
|
||
pro = ts.pro_api()
|
||
|
||
# 结果统计
|
||
results = {
|
||
'success_count': 0,
|
||
'failed_count': 0,
|
||
'no_new_data_count': 0,
|
||
'skipped_count': 0,
|
||
'processed_count': 0
|
||
}
|
||
|
||
# 获取当前进程ID
|
||
process_id = multiprocessing.current_process().name
|
||
total_stocks = len(batch)
|
||
|
||
for i, stock_info in enumerate(batch):
|
||
# 当前更新时间
|
||
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
# 根据更新模式获取不同的参数
|
||
if update_mode == 'full':
|
||
ts_code = stock_info
|
||
try:
|
||
# 简化日志,只显示进度
|
||
print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 全量更新 {ts_code}")
|
||
|
||
# 检查股票表是否存在,不存在则创建
|
||
symbol = ts_code.split('.')[0]
|
||
with engine.connect() as conn:
|
||
table_exists_query = f"""
|
||
SELECT COUNT(*)
|
||
FROM information_schema.tables
|
||
WHERE table_schema = '{config['mysql']['database']}'
|
||
AND table_name = '{symbol}'
|
||
"""
|
||
table_exists = conn.execute(text(table_exists_query)).scalar() > 0
|
||
|
||
if not table_exists:
|
||
create_stock_table(engine, ts_code)
|
||
|
||
# 下载股票数据
|
||
stock_data = download_stock_data(pro, ts_code, start_date, end_date)
|
||
|
||
# 如果下载成功,保存数据并更新元数据
|
||
if not stock_data.empty:
|
||
# 使用replace模式保存数据(全量替换)
|
||
records_saved = save_stock_data(stock_data, engine, ts_code)
|
||
# 更新元数据统计信息
|
||
update_metadata(engine, ts_code)
|
||
# 更新全量更新时间戳和状态为正常(1)
|
||
with engine.connect() as conn:
|
||
update_status_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET last_full_update = '{update_time}',
|
||
status = 1
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(update_status_sql))
|
||
conn.commit()
|
||
results['success_count'] += 1
|
||
else:
|
||
# 更新状态为全量更新失败(2)
|
||
with engine.connect() as conn:
|
||
update_status_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET status = 2
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(update_status_sql))
|
||
conn.commit()
|
||
results['failed_count'] += 1
|
||
|
||
# 防止API限流
|
||
import time
|
||
time.sleep(0.5)
|
||
|
||
except Exception as e:
|
||
# 发生异常时,更新状态为全量更新失败(2)
|
||
try:
|
||
with engine.connect() as conn:
|
||
update_status_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET status = 2
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(update_status_sql))
|
||
conn.commit()
|
||
except Exception:
|
||
pass # 更新状态失败时,忽略继续执行
|
||
|
||
results['failed_count'] += 1
|
||
print(f"[{process_id}] 股票 {ts_code} 更新失败: {str(e)}")
|
||
|
||
elif update_mode == 'incremental':
|
||
ts_code = stock_info['ts_code']
|
||
start_date = stock_info['start_date']
|
||
|
||
try:
|
||
# 简化日志,只显示进度
|
||
print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 增量更新 {ts_code}")
|
||
|
||
# 确保股票表存在
|
||
symbol = ts_code.split('.')[0]
|
||
with engine.connect() as conn:
|
||
table_exists_query = f"""
|
||
SELECT COUNT(*)
|
||
FROM information_schema.tables
|
||
WHERE table_schema = '{config['mysql']['database']}'
|
||
AND table_name = '{symbol}'
|
||
"""
|
||
table_exists = conn.execute(text(table_exists_query)).scalar() > 0
|
||
|
||
if not table_exists:
|
||
results['skipped_count'] += 1
|
||
continue # 表不存在则跳过
|
||
|
||
# 下载增量数据
|
||
new_data = download_stock_data(pro, ts_code, start_date, end_date)
|
||
|
||
if not new_data.empty:
|
||
# 获取现有数据
|
||
try:
|
||
with engine.connect() as conn:
|
||
existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
|
||
except Exception as e:
|
||
results['skipped_count'] += 1
|
||
continue
|
||
|
||
if not existing_data.empty:
|
||
# 转换日期列为相同的格式以便合并
|
||
existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
|
||
new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])
|
||
|
||
# 删除可能重复的日期记录
|
||
existing_dates = set(existing_data['trade_date'])
|
||
new_data = new_data[~new_data['trade_date'].isin(existing_dates)]
|
||
|
||
if new_data.empty:
|
||
results['no_new_data_count'] += 1
|
||
continue
|
||
|
||
# 合并数据
|
||
combined_data = pd.concat([existing_data, new_data], ignore_index=True)
|
||
|
||
# 按日期排序
|
||
combined_data = combined_data.sort_values('trade_date')
|
||
|
||
# 保存合并后的数据
|
||
records_saved = save_stock_data(combined_data, engine, ts_code)
|
||
else:
|
||
# 如果表存在但为空,直接保存新数据
|
||
records_saved = save_stock_data(new_data, engine, ts_code)
|
||
|
||
# 更新元数据统计信息
|
||
update_metadata(engine, ts_code)
|
||
|
||
# 更新增量更新时间戳和状态为正常(1)
|
||
with engine.connect() as conn:
|
||
update_status_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET last_incremental_update = '{update_time}',
|
||
status = 1
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(update_status_sql))
|
||
conn.commit()
|
||
|
||
results['success_count'] += 1
|
||
else:
|
||
# 无新数据
|
||
results['no_new_data_count'] += 1
|
||
|
||
# 防止API限流
|
||
import time
|
||
time.sleep(0.5)
|
||
|
||
except Exception as e:
|
||
# 发生异常时,更新状态为增量更新失败(3)
|
||
try:
|
||
with engine.connect() as conn:
|
||
update_status_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET status = 3
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(update_status_sql))
|
||
conn.commit()
|
||
except Exception:
|
||
pass # 更新状态失败时,忽略继续执行
|
||
|
||
results['failed_count'] += 1
|
||
print(f"[{process_id}] 股票 {ts_code} 增量更新失败: {str(e)}")
|
||
|
||
results['processed_count'] += 1
|
||
|
||
return results
|
||
|
||
|
||
def perform_full_update(start_year=2020, processes=4, resume=False):
|
||
"""
|
||
执行全量更新:从指定年份开始到今天的所有数据
|
||
|
||
Args:
|
||
start_year: 开始年份,默认为2020年
|
||
processes: 进程数量,默认为4
|
||
resume: 是否只更新之前全量更新失败的股票
|
||
"""
|
||
update_type = "续传全量更新" if resume else "全量更新"
|
||
print(f"开始执行{update_type} (从{start_year}年至今),使用 {processes} 个进程...")
|
||
start_time = datetime.now()
|
||
|
||
# 加载配置
|
||
config = load_config()
|
||
|
||
# 创建数据库引擎
|
||
engine = create_engine_from_config(config)
|
||
|
||
# 设置Tushare API
|
||
ts.set_token(config['tushare_token'])
|
||
pro = ts.pro_api()
|
||
|
||
# 设置数据获取的日期范围
|
||
end_date = datetime.now().strftime('%Y%m%d')
|
||
start_date = f"{start_year}0101" # 从指定年份的1月1日开始
|
||
print(f"数据范围: {start_date} 至 {end_date}")
|
||
|
||
# 从元数据表获取股票代码
|
||
with engine.connect() as conn:
|
||
# 根据resume参数决定查询条件
|
||
if resume:
|
||
# 只查询全量更新失败的股票
|
||
query = """
|
||
SELECT ts_code
|
||
FROM stock_metadata
|
||
WHERE list_status = 'L' AND status != 1
|
||
"""
|
||
else:
|
||
# 查询所有已上市的股票
|
||
query = """
|
||
SELECT ts_code
|
||
FROM stock_metadata
|
||
WHERE list_status = 'L'
|
||
"""
|
||
|
||
result = conn.execute(text(query))
|
||
stock_codes = [row[0] for row in result]
|
||
|
||
if not stock_codes:
|
||
if resume:
|
||
print("没有找到需要续传的股票")
|
||
else:
|
||
print("没有找到需要更新的股票,请先确保元数据表已初始化")
|
||
return
|
||
|
||
print(f"共找到 {len(stock_codes)} 只股票需要{update_type}")
|
||
|
||
# 将股票列表分成多个批次
|
||
batches = split_list_into_chunks(stock_codes, processes)
|
||
|
||
# 创建进程池
|
||
with Pool(processes=processes) as pool:
|
||
# 使用partial函数固定部分参数
|
||
process_func = partial(
|
||
process_stock_batch,
|
||
config=config,
|
||
update_mode='full',
|
||
start_date=start_date,
|
||
end_date=end_date
|
||
)
|
||
|
||
# 启动多进程处理并收集结果
|
||
results = pool.map(process_func, batches)
|
||
|
||
# 汇总所有进程的结果
|
||
success_count = sum(r['success_count'] for r in results)
|
||
failed_count = sum(r['failed_count'] for r in results)
|
||
|
||
end_time = datetime.now()
|
||
duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟
|
||
|
||
print(f"\n{update_type}完成!")
|
||
print(f"总耗时: {duration:.2f} 分钟")
|
||
print(f"成功更新: {success_count} 只股票")
|
||
print(f"更新失败: {failed_count} 只股票")
|
||
print(f"总计: {len(stock_codes)} 只股票")
|
||
|
||
|
||
def update_metadata(engine, ts_code):
|
||
"""
|
||
更新股票元数据中的统计信息,从股票表中直接获取数据
|
||
Args:
|
||
engine: 数据库连接引擎
|
||
ts_code: 股票代码(带后缀,如000001.SZ)
|
||
Returns:
|
||
bool: 更新成功返回True,否则返回False
|
||
"""
|
||
try:
|
||
# 提取股票代码(不含后缀)
|
||
symbol = ts_code.split('.')[0]
|
||
# 查询股票数据统计信息
|
||
stats_query = f"""
|
||
SELECT
|
||
MIN(trade_date) as min_date,
|
||
MAX(trade_date) as max_date,
|
||
COUNT(*) as record_count
|
||
FROM `{symbol}`
|
||
"""
|
||
# 查询最新交易日数据
|
||
latest_data_query = f"""
|
||
SELECT
|
||
close,
|
||
pe_ttm,
|
||
pb,
|
||
total_mv
|
||
FROM `{symbol}`
|
||
WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
|
||
LIMIT 1
|
||
"""
|
||
# 执行查询
|
||
with engine.connect() as conn:
|
||
# 获取统计信息
|
||
stats_result = conn.execute(text(stats_query)).fetchone()
|
||
if not stats_result:
|
||
# 更新状态为异常(2)
|
||
update_empty_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET last_full_update = NOW(),
|
||
record_count = 0,
|
||
status = 2
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(update_empty_sql))
|
||
conn.commit()
|
||
return False
|
||
|
||
# 处理日期字段
|
||
try:
|
||
data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
|
||
data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
|
||
except AttributeError:
|
||
# 处理日期可能是字符串而不是datetime对象的情况
|
||
data_start_date = stats_result[0] if stats_result[0] else 'NULL'
|
||
data_end_date = stats_result[1] if stats_result[1] else 'NULL'
|
||
|
||
record_count = stats_result[2] or 0
|
||
|
||
# 获取最新交易日数据
|
||
latest_data_result = conn.execute(text(latest_data_query)).fetchone()
|
||
|
||
# 设置默认值并处理NULL
|
||
default_value = 'NULL'
|
||
latest_price = default_value
|
||
latest_pe_ttm = default_value
|
||
latest_pb = default_value
|
||
latest_total_mv = default_value
|
||
|
||
# 如果有最新数据,则更新相应字段
|
||
if latest_data_result:
|
||
latest_price = str(latest_data_result[0]) if latest_data_result[0] is not None else default_value
|
||
latest_pe_ttm = str(latest_data_result[1]) if latest_data_result[1] is not None else default_value
|
||
latest_pb = str(latest_data_result[2]) if latest_data_result[2] is not None else default_value
|
||
latest_total_mv = str(latest_data_result[3]) if latest_data_result[3] is not None else default_value
|
||
|
||
# 构建更新SQL
|
||
update_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET
|
||
data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
|
||
data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
|
||
record_count = {record_count},
|
||
status = CASE WHEN {record_count} > 0 THEN 1 ELSE 2 END,
|
||
latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END,
|
||
latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
|
||
latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END,
|
||
latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END,
|
||
latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
# 执行更新
|
||
conn.execute(text(update_sql))
|
||
conn.commit()
|
||
|
||
return True
|
||
except Exception as e:
|
||
# 尝试将状态更新为元数据更新失败(-1)
|
||
try:
|
||
with engine.connect() as conn:
|
||
error_update_sql = f"""
|
||
UPDATE stock_metadata
|
||
SET status = -1
|
||
WHERE ts_code = '{ts_code}'
|
||
"""
|
||
conn.execute(text(error_update_sql))
|
||
conn.commit()
|
||
except Exception:
|
||
pass # 更新状态失败时,忽略继续执行
|
||
|
||
print(f"更新({ts_code})元数据时出错: {str(e)}")
|
||
return False
|
||
|
||
|
||
def perform_incremental_update(processes=4):
|
||
"""
|
||
执行增量更新:从上次增量或全量更新日期到今天的数据
|
||
|
||
Args:
|
||
processes: 进程数量,默认为4
|
||
"""
|
||
print(f"开始执行增量更新,使用 {processes} 个进程...")
|
||
start_time = datetime.now()
|
||
|
||
# 加载配置
|
||
config = load_config()
|
||
|
||
# 创建数据库引擎
|
||
engine = create_engine_from_config(config)
|
||
|
||
# 设置Tushare API
|
||
ts.set_token(config['tushare_token'])
|
||
pro = ts.pro_api()
|
||
|
||
# 当前日期作为结束日期
|
||
end_date = datetime.now().strftime('%Y%m%d')
|
||
|
||
# 获取需要更新的股票及其上次更新日期
|
||
stocks_to_update = []
|
||
with engine.connect() as conn:
|
||
query = """
|
||
SELECT
|
||
ts_code,
|
||
last_incremental_update,
|
||
last_full_update
|
||
FROM
|
||
stock_metadata
|
||
WHERE
|
||
list_status = 'L'
|
||
"""
|
||
result = conn.execute(text(query))
|
||
for row in result:
|
||
ts_code = row[0]
|
||
last_incr_update = row[1] # 上次增量更新日期
|
||
last_full_update = row[2] # 上次全量更新日期
|
||
|
||
# 确定起始日期:使用最近的更新日期
|
||
latest_update = None
|
||
|
||
# 检查增量更新日期
|
||
if last_incr_update:
|
||
latest_update = last_incr_update
|
||
|
||
# 检查全量更新日期
|
||
if last_full_update:
|
||
if not latest_update or last_full_update > latest_update:
|
||
latest_update = last_full_update
|
||
|
||
# 如果有更新日期,使用其后一天作为起始日期
|
||
if latest_update:
|
||
start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
|
||
else:
|
||
# 如果没有任何更新记录,默认从30天前开始
|
||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
|
||
|
||
# 如果起始日期大于等于结束日期,则跳过此股票
|
||
if start_date >= end_date:
|
||
continue
|
||
|
||
stocks_to_update.append({
|
||
'ts_code': ts_code,
|
||
'start_date': start_date
|
||
})
|
||
|
||
if not stocks_to_update:
|
||
print("没有找到需要增量更新的股票")
|
||
return
|
||
|
||
print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
|
||
|
||
# 当前更新时间
|
||
update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
|
||
# 将股票列表分成多个批次
|
||
batches = split_list_into_chunks(stocks_to_update, processes)
|
||
|
||
# 创建进程池
|
||
with Pool(processes=processes) as pool:
|
||
# 使用partial函数固定部分参数
|
||
process_func = partial(
|
||
process_stock_batch,
|
||
config=config,
|
||
update_mode='incremental',
|
||
start_date=None, # 增量模式下使用每个股票自己的起始日期
|
||
end_date=end_date,
|
||
update_time=update_time
|
||
)
|
||
|
||
# 启动多进程处理并收集结果
|
||
results = pool.map(process_func, batches)
|
||
|
||
# 汇总所有进程的结果
|
||
success_count = sum(r['success_count'] for r in results)
|
||
failed_count = sum(r['failed_count'] for r in results)
|
||
no_new_data_count = sum(r['no_new_data_count'] for r in results)
|
||
skipped_count = sum(r['skipped_count'] for r in results)
|
||
|
||
end_time = datetime.now()
|
||
duration = (end_time - start_time).total_seconds() / 60 # 转换为分钟
|
||
|
||
print("\n增量更新完成!")
|
||
print(f"总耗时: {duration:.2f} 分钟")
|
||
print(f"成功更新: {success_count} 只股票")
|
||
print(f"无新数据: {no_new_data_count} 只股票")
|
||
print(f"跳过股票: {skipped_count} 只股票")
|
||
print(f"更新失败: {failed_count} 只股票")
|
||
print(f"总计: {len(stocks_to_update)} 只股票")
|
||
|
||
|
||
# 添加创建日历表的函数
|
||
def create_calendar_table(engine):
|
||
"""创建交易日历表"""
|
||
# 从sql文件读取创建表的SQL
|
||
with open('sql/calendar.sql', 'r', encoding='utf-8') as f:
|
||
create_table_sql = f.read()
|
||
|
||
with engine.connect() as conn:
|
||
conn.execute(text(create_table_sql))
|
||
conn.commit()
|
||
print("交易日历表创建成功")
|
||
|
||
|
||
# 添加下载日历数据的函数
|
||
def download_calendar_data(pro, start_date=None, end_date=None):
|
||
"""
|
||
下载交易日历数据
|
||
|
||
Args:
|
||
pro: Tushare API对象
|
||
start_date: 开始日期,格式YYYYMMDD,默认为2000年初
|
||
end_date: 结束日期,格式YYYYMMDD,默认为当前日期之后的一年
|
||
|
||
Returns:
|
||
pandas.DataFrame: 交易日历数据
|
||
"""
|
||
# 默认获取从2000年至今后一年的日历数据
|
||
if start_date is None:
|
||
start_date = '20000101'
|
||
|
||
if end_date is None:
|
||
# 获取当前日期后一年的数据,以便预先有足够的交易日历
|
||
current_date = datetime.now()
|
||
next_year = current_date + timedelta(days=365)
|
||
end_date = next_year.strftime('%Y%m%d')
|
||
|
||
print(f"正在下载从 {start_date} 到 {end_date} 的交易日历数据...")
|
||
|
||
try:
|
||
# 调用Tushare的trade_cal接口获取交易日历
|
||
df = pro.trade_cal(exchange='SSE', start_date=start_date, end_date=end_date)
|
||
|
||
# 重命名列以符合数据库表结构
|
||
df = df.rename(columns={
|
||
'cal_date': 'trade_date',
|
||
'exchange': 'exchange',
|
||
'is_open': 'is_open',
|
||
'pretrade_date': 'pretrade_date'
|
||
})
|
||
|
||
# 将日期格式从YYYYMMDD转换为YYYY-MM-DD
|
||
df['trade_date'] = pd.to_datetime(df['trade_date']).dt.strftime('%Y-%m-%d')
|
||
if 'pretrade_date' in df.columns:
|
||
df['pretrade_date'] = pd.to_datetime(df['pretrade_date']).dt.strftime('%Y-%m-%d')
|
||
|
||
# 添加创建和更新时间
|
||
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
df['create_time'] = current_time
|
||
df['update_time'] = current_time
|
||
|
||
print(f"成功下载 {len(df)} 条交易日历记录")
|
||
return df
|
||
|
||
except Exception as e:
|
||
print(f"下载交易日历数据时出错: {e}")
|
||
return pd.DataFrame()
|
||
|
||
|
||
# 添加保存日历数据的函数
|
||
def save_calendar_data(calendar_df, engine):
|
||
"""
|
||
将交易日历数据保存到数据库
|
||
|
||
Args:
|
||
calendar_df: 包含交易日历数据的DataFrame
|
||
engine: SQLAlchemy引擎
|
||
"""
|
||
if calendar_df.empty:
|
||
print("没有交易日历数据可保存")
|
||
return
|
||
|
||
try:
|
||
# 使用to_sql将数据写入数据库,如果表存在则替换
|
||
calendar_df.to_sql('calendar', engine, if_exists='replace', index=False)
|
||
print(f"成功保存 {len(calendar_df)} 条交易日历数据到数据库")
|
||
except Exception as e:
|
||
print(f"保存交易日历数据时出错: {e}")
|
||
|
||
|
||
# 添加交易日历更新函数
|
||
def update_calendar(engine, pro):
|
||
"""
|
||
更新交易日历数据
|
||
|
||
Args:
|
||
engine: SQLAlchemy引擎
|
||
pro: Tushare API对象
|
||
"""
|
||
try:
|
||
# 检查calendar表是否存在
|
||
with engine.connect() as conn:
|
||
result = conn.execute(text("SHOW TABLES LIKE 'calendar'"))
|
||
table_exists = result.fetchone() is not None
|
||
|
||
if not table_exists:
|
||
# 如果表不存在,创建表并下载所有历史数据
|
||
create_calendar_table(engine)
|
||
calendar_df = download_calendar_data(pro)
|
||
save_calendar_data(calendar_df, engine)
|
||
else:
|
||
# 如果表存在,查询最新的交易日期
|
||
with engine.connect() as conn:
|
||
result = conn.execute(text("SELECT MAX(trade_date) FROM calendar"))
|
||
latest_date = result.fetchone()[0]
|
||
|
||
if latest_date:
|
||
# 将日期转换为YYYYMMDD格式
|
||
start_date = (datetime.strptime(str(latest_date), '%Y-%m-%d') + timedelta(days=1)).strftime('%Y%m%d')
|
||
# 下载新数据
|
||
calendar_df = download_calendar_data(pro, start_date=start_date)
|
||
|
||
if not calendar_df.empty:
|
||
# 将新数据追加到表中
|
||
calendar_df.to_sql('calendar', engine, if_exists='append', index=False)
|
||
print(f"成功更新 {len(calendar_df)} 条新的交易日历数据")
|
||
else:
|
||
print("没有新的交易日历数据需要更新")
|
||
else:
|
||
# 如果表为空,下载所有历史数据
|
||
calendar_df = download_calendar_data(pro)
|
||
save_calendar_data(calendar_df, engine)
|
||
|
||
except Exception as e:
|
||
print(f"更新交易日历时出错: {e}")
|
||
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description='股票数据下载器')
|
||
parser.add_argument('--init', action='store_true', help='初始化元数据')
|
||
parser.add_argument('--mode', choices=['full', 'incremental', 'both'], default='full',
|
||
help='更新模式: full(全量), incremental(增量), both(两者都)')
|
||
parser.add_argument('--year', type=int, default=2020, help='起始年份(用于全量更新)')
|
||
parser.add_argument('--processes', type=int, default=4, help='并行进程数')
|
||
parser.add_argument('--resume', action='store_true', help='是否从中断处继续')
|
||
parser.add_argument('--update-calendar', action='store_true', help='更新交易日历数据')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 加载配置
|
||
config = load_config()
|
||
# 创建数据库引擎
|
||
engine = create_engine_from_config(config)
|
||
# 设置Tushare API
|
||
ts.set_token(config['tushare_token'])
|
||
pro = ts.pro_api()
|
||
|
||
# 如果需要初始化元数据
|
||
if args.init:
|
||
# 创建元数据表
|
||
create_metadata_table(engine)
|
||
# 下载并保存股票元数据
|
||
metadata_df = download_stock_metadata(pro)
|
||
save_metadata(metadata_df, engine)
|
||
print("元数据初始化完成")
|
||
|
||
# 初始化交易日历数据
|
||
update_calendar(engine, pro)
|
||
print("交易日历初始化完成")
|
||
|
||
# 如果需要更新交易日历
|
||
if args.update_calendar:
|
||
update_calendar(engine, pro)
|
||
print("交易日历更新完成")
|
||
|
||
# 根据选择的模式执行更新
|
||
if args.mode == 'full' or args.mode == 'both':
|
||
perform_full_update(args.year, args.processes, args.resume)
|
||
|
||
if args.mode == 'incremental' or args.mode == 'both':
|
||
perform_incremental_update(args.processes)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|