backtrader/data_downloader.py
2025-04-05 22:39:24 +08:00

1060 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import pandas as pd
import tushare as ts
import numpy as np
import os
import multiprocessing
from multiprocessing import Pool
from functools import partial
from sqlalchemy import create_engine, text
from datetime import datetime, timedelta
from utils import load_config
def create_engine_from_config(config):
    """Build a SQLAlchemy engine from the 'mysql' section of the config dict."""
    db = config['mysql']
    # Assemble the PyMySQL connection URL; extra keys in db are ignored by format().
    url = (
        "mysql+pymysql://{user}:{password}@{host}:{port}/{database}"
        "?charset={charset}&use_unicode=1"
    ).format(**db)
    return create_engine(url)
def create_metadata_table(engine):
    """Create the stock metadata table using the DDL stored in sql/meta.sql."""
    with open('sql/meta.sql', 'r', encoding='utf-8') as sql_file:
        ddl = sql_file.read()
    with engine.connect() as connection:
        connection.execute(text(ddl))
        connection.commit()
    print("股票元数据表创建成功")
def create_stock_table(engine, ts_code):
    """Create the per-stock data table for ts_code from the SQL template file."""
    # Make sure the sql directory exists before reading the template.
    os.makedirs('sql', exist_ok=True)
    with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as template_file:
        template = template_file.read()
    # The table is named after the numeric symbol (ts_code without exchange suffix).
    ddl = template.format(symbol=ts_code.split('.')[0])
    with engine.connect() as connection:
        connection.execute(text(ddl))
        connection.commit()
    print(f"股票({ts_code})数据表创建成功")
def download_stock_metadata(pro):
    """Fetch basic info for all listed stocks and append bookkeeping columns."""
    fields = ('ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,'
              'exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type')
    df = pro.stock_basic(exchange='', list_status='L', fields=fields)
    # Tracking columns start out unset; column order matters for table creation.
    df['status'] = 0  # 0 = not yet initialized
    df['remark'] = '自动导入'
    for column in ('last_full_update', 'last_incremental_update',
                   'data_start_date', 'data_end_date'):
        df[column] = None
    df['record_count'] = 0
    for column in ('latest_price', 'latest_date',
                   'latest_pe_ttm', 'latest_pb', 'latest_total_mv'):
        df[column] = None
    return df
def save_metadata(df, engine):
    """
    Save stock metadata to the database, updating only the records whose
    monitored fields actually changed, normalizing date formats first.

    Args:
        df: freshly downloaded metadata DataFrame (one row per ts_code)
        engine: SQLAlchemy engine

    Returns:
        int: number of rows inserted plus rows updated
    """
    # Base metadata fields whose changes we monitor.
    metadata_fields = [
        'ts_code', 'symbol', 'name', 'area', 'industry', 'fullname',
        'enname', 'cnspell', 'market', 'exchange', 'curr_type',
        'list_status', 'list_date', 'delist_date', 'is_hs',
        'act_name', 'act_ent_type'
    ]
    # Date-typed fields (compared after normalization to 8-digit YYYYMMDD).
    date_fields = ['list_date', 'delist_date']
    # Read the existing metadata from the database.
    try:
        with engine.connect() as conn:
            # Only query the fields we care about.
            query = f"SELECT {', '.join(metadata_fields)} FROM stock_metadata"
            existing_df = pd.read_sql(query, conn)
    except Exception as e:
        print(f"读取现有元数据时出错: {str(e)}")
        existing_df = pd.DataFrame()  # table may not exist yet; treat as empty
    if existing_df.empty:
        # Empty (or missing) table: bulk-insert everything.
        result = df.to_sql('stock_metadata', engine, index=False,
                           if_exists='append', chunksize=1000)
        print(f"成功保存{result}条股票元数据(新建)")
        return result
    else:
        # Work on copies so the caller's frames are untouched.
        df_processed = df.copy()
        existing_processed = existing_df.copy()
        def normalize_date(date_str):
            # Reduce any date representation to an 8-digit YYYYMMDD string,
            # or '' when missing/unparseable, so differing formats compare equal.
            if pd.isna(date_str) or date_str == '' or date_str == 'None' or date_str is None:
                return ''
            date_str = str(date_str).strip()
            digits_only = ''.join(c for c in date_str if c.isdigit())
            if len(digits_only) == 8:
                return digits_only
            return ''
        # Normalize date columns on both frames.
        for field in date_fields:
            if field in df_processed.columns and field in existing_processed.columns:
                df_processed[field] = df_processed[field].apply(normalize_date)
                existing_processed[field] = existing_processed[field].apply(normalize_date)
        # Normalize the remaining (non-date, non-key) columns to stripped strings.
        for col in [f for f in metadata_fields if f not in date_fields and f != 'ts_code']:
            if col in df_processed.columns and col in existing_processed.columns:
                # Force both sides to string dtype so comparisons are type-stable.
                df_processed[col] = df_processed[col].astype(str)
                existing_processed[col] = existing_processed[col].astype(str)
                # NOTE(review): astype(str) already turns NaN into the string
                # 'nan', so this fillna('') is mostly a no-op — confirm intended.
                df_processed[col] = df_processed[col].fillna('').str.strip()
                existing_processed[col] = existing_processed[col].fillna('').str.strip()
        # Rows whose ts_code is not in the database yet.
        existing_ts_codes = set(existing_processed['ts_code'])
        new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()
        # Rows that already exist: detect per-field changes.
        changed_records = []
        unchanged_count = 0
        for ts_code in df_processed[df_processed['ts_code'].isin(existing_ts_codes)]['ts_code']:
            # Normalized new/old rows used only for the comparison.
            new_record = df_processed[df_processed['ts_code'] == ts_code].iloc[0]
            old_record = existing_processed[existing_processed['ts_code'] == ts_code].iloc[0]
            # Original (un-normalized) row, used for the actual UPDATE.
            original_record = df[df['ts_code'] == ts_code].iloc[0]
            has_change = False
            changed_fields = []
            for field in metadata_fields:
                if field == 'ts_code':  # skip the primary key
                    continue
                new_val = new_record[field]
                old_val = old_record[field]
                if new_val != old_val:
                    has_change = True
                    changed_fields.append(field)
                    print(f"发现变化 - {ts_code}的{field}: '{old_val}' -> '{new_val}'")
            if has_change:
                changed_records.append({
                    'record': original_record,  # original row keeps source formatting
                    'changed_fields': changed_fields
                })
            else:
                unchanged_count += 1
        # Insert the brand-new rows.
        new_count = 0
        if not new_records.empty:
            new_count = new_records.to_sql('stock_metadata', engine, index=False,
                                           if_exists='append', chunksize=1000)
        # Apply per-row UPDATEs, touching only the changed fields.
        updated_count = 0
        for change_info in changed_records:
            record = change_info['record']
            fields = change_info['changed_fields']
            # Build a parameterized UPDATE restricted to the changed columns.
            fields_to_update = []
            params = {'ts_code': record['ts_code']}
            for field in fields:
                fields_to_update.append(f"{field} = :{field}")
                params[field] = record[field]
            if fields_to_update:
                update_stmt = text(f"""
                UPDATE stock_metadata
                SET {', '.join(fields_to_update)}
                WHERE ts_code = :ts_code
                """)
                with engine.connect() as conn:
                    conn.execute(update_stmt, params)
                    updated_count += 1
                    conn.commit()
        print(f"元数据更新统计:")
        print(f" • 新增记录: {new_count}")
        print(f" • 更新记录: {updated_count}")
        print(f" • 无变化记录: {unchanged_count}")
        print(f" • 总处理记录: {new_count + updated_count + unchanged_count}")
        return new_count + updated_count
def download_stock_data(pro, ts_code, start_date, end_date):
    """
    Download daily, daily_basic and moneyflow data for one stock and merge them
    on (ts_code, trade_date), with the daily frame taking precedence.

    Args:
        pro: Tushare API object
        ts_code: stock code such as '000001.SZ'
        start_date: start date (YYYYMMDD)
        end_date: end date (YYYYMMDD)

    Returns:
        pandas.DataFrame: merged rows with trade_date as datetime; empty on failure
    """
    daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
    if daily_df.empty:
        print(f"警告:{ts_code}没有找到daily数据")
        return pd.DataFrame()
    daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
    moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)
    # trade_date is the merge key; daily must provide it.
    if 'trade_date' not in daily_df.columns:
        print(f"错误:{ts_code}的daily数据缺少trade_date列")
        return pd.DataFrame()
    # Compare dates as strings during the merges.
    daily_df['trade_date'] = daily_df['trade_date'].astype(str)
    merged = daily_df
    # Left-merge each auxiliary frame, dropping columns daily already provides
    # (except the join keys) so daily's values win.
    for extra_df in (daily_basic_df, moneyflow_df):
        if extra_df.empty:
            continue
        extra_df['trade_date'] = extra_df['trade_date'].astype(str)
        duplicated = list(set(merged.columns) & set(extra_df.columns) - {'ts_code', 'trade_date'})
        merged = pd.merge(
            merged,
            extra_df.drop(columns=duplicated),
            on=['ts_code', 'trade_date'],
            how='left'
        )
    # Hand back trade_date as real datetimes.
    merged['trade_date'] = pd.to_datetime(merged['trade_date'])
    return merged
def save_stock_data(df, engine, ts_code, if_exists='replace'):
    """
    Persist one stock's rows into its own table (named by the numeric symbol).

    Args:
        df: stock data DataFrame
        engine: database engine/connection accepted by DataFrame.to_sql
        ts_code: stock code such as '000001.SZ'
        if_exists: 'replace' overwrites the table, 'append' adds rows

    Returns:
        int: number of records written (0 when nothing was saved)
    """
    if df.empty:
        print(f"警告:{ts_code}没有数据可保存")
        return 0
    # The per-stock table does not carry ts_code; drop it when present
    # (rebinds a local copy — the caller's frame is untouched).
    if 'ts_code' in df.columns:
        df = df.drop(columns=['ts_code'])
    try:
        symbol = ts_code.split('.')[0]
        written = df.to_sql(f'{symbol}', engine, index=False,
                            if_exists=if_exists, chunksize=1000)
        print(f"成功保存{written}条{ts_code}股票数据 (模式: {if_exists})")
        return written
    except Exception as e:
        print(f"保存{ts_code}数据时出错: {str(e)}")
        return 0
def split_list_into_chunks(data_list, num_chunks):
    """
    Split a list into roughly equal, consecutive chunks.

    Args:
        data_list: the list to split
        num_chunks: desired number of chunks

    Returns:
        list: sub-lists that cover data_list in order
    """
    # Fractional chunk size; int() truncation of the running position keeps
    # the chunk boundaries balanced across the whole list.
    chunk_size = len(data_list) / float(num_chunks)
    result = []
    position = 0.0
    total = len(data_list)
    while position < total:
        result.append(data_list[int(position):int(position + chunk_size)])
        position += chunk_size
    return result
def process_stock_batch(batch, config, update_mode, start_date, end_date):
    """
    Worker function that processes one batch of stocks (runs in a child process).

    Args:
        batch: stocks to process; plain ts_code strings in 'full' mode,
               dicts with 'ts_code' and 'start_date' in 'incremental' mode
        config: configuration dict (database settings + tushare token)
        update_mode: 'full' or 'incremental'
        start_date: download start date (YYYYMMDD); in incremental mode each
                    stock carries its own start date and this value is ignored
        end_date: download end date (YYYYMMDD)

    Returns:
        dict: counters describing the batch outcome
    """
    # Each worker process needs its own DB engine and API client.
    engine = create_engine_from_config(config)
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()
    # Outcome counters returned to the parent process.
    results = {
        'success_count': 0,
        'failed_count': 0,
        'no_new_data_count': 0,
        'skipped_count': 0,
        'processed_count': 0
    }
    process_id = multiprocessing.current_process().name
    total_stocks = len(batch)
    for i, stock_info in enumerate(batch):
        # Timestamp recorded in stock_metadata for this stock's update.
        update_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if update_mode == 'full':
            ts_code = stock_info
            try:
                # Terse progress log only.
                print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 全量更新 {ts_code}")
                # Create the per-stock table if it does not exist yet.
                # NOTE(review): identifiers/values are interpolated into SQL
                # directly; they come from config and the metadata table, not
                # user input, but parameter binding would be safer.
                symbol = ts_code.split('.')[0]
                with engine.connect() as conn:
                    table_exists_query = f"""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = '{config['mysql']['database']}'
                    AND table_name = '{symbol}'
                    """
                    table_exists = conn.execute(text(table_exists_query)).scalar() > 0
                    if not table_exists:
                        create_stock_table(engine, ts_code)
                # Download the full history window for this stock.
                stock_data = download_stock_data(pro, ts_code, start_date, end_date)
                if not stock_data.empty:
                    # Default save mode is 'replace': full refresh of the table.
                    records_saved = save_stock_data(stock_data, engine, ts_code)
                    # Refresh derived statistics in the metadata table.
                    update_metadata(engine, ts_code)
                    # Mark full update success (status 1) with its timestamp.
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET last_full_update = '{update_time}',
                        status = 1
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['success_count'] += 1
                else:
                    # No data came back: mark full-update failure (status 2).
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET status = 2
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['failed_count'] += 1
                # Throttle to stay under the API rate limit.
                import time
                time.sleep(0.5)
            except Exception as e:
                # On any error, best-effort mark the stock as failed (status 2).
                try:
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET status = 2
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                except Exception:
                    pass  # status write failed; keep processing the batch
                results['failed_count'] += 1
                print(f"[{process_id}] 股票 {ts_code} 更新失败: {str(e)}")
        elif update_mode == 'incremental':
            ts_code = stock_info['ts_code']
            # NOTE(review): this rebinds the start_date parameter on every
            # iteration — harmless because incremental mode always supplies a
            # per-stock start date, but worth confirming.
            start_date = stock_info['start_date']
            try:
                # Terse progress log only.
                print(f"[{process_id}] 进度 {i + 1}/{total_stocks}: 增量更新 {ts_code}")
                # Incremental mode requires the per-stock table to exist.
                symbol = ts_code.split('.')[0]
                with engine.connect() as conn:
                    table_exists_query = f"""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = '{config['mysql']['database']}'
                    AND table_name = '{symbol}'
                    """
                    table_exists = conn.execute(text(table_exists_query)).scalar() > 0
                    if not table_exists:
                        results['skipped_count'] += 1
                        continue  # NOTE: also skips the processed_count increment below
                # Download only the new window of data.
                new_data = download_stock_data(pro, ts_code, start_date, end_date)
                if not new_data.empty:
                    # Load the existing rows so duplicate dates can be dropped.
                    try:
                        with engine.connect() as conn:
                            existing_data = pd.read_sql(f"SELECT * FROM `{symbol}`", conn)
                    except Exception as e:
                        results['skipped_count'] += 1
                        continue
                    if not existing_data.empty:
                        # Align date dtypes so the duplicate filter compares correctly.
                        existing_data['trade_date'] = pd.to_datetime(existing_data['trade_date'])
                        new_data['trade_date'] = pd.to_datetime(new_data['trade_date'])
                        # Drop rows whose trade_date is already stored.
                        existing_dates = set(existing_data['trade_date'])
                        new_data = new_data[~new_data['trade_date'].isin(existing_dates)]
                        if new_data.empty:
                            results['no_new_data_count'] += 1
                            continue
                        # Merge old + new and sort chronologically.
                        combined_data = pd.concat([existing_data, new_data], ignore_index=True)
                        combined_data = combined_data.sort_values('trade_date')
                        # Saved with the default 'replace' mode: the combined
                        # frame rewrites the whole table.
                        records_saved = save_stock_data(combined_data, engine, ts_code)
                    else:
                        # Table exists but is empty: just save the new rows.
                        records_saved = save_stock_data(new_data, engine, ts_code)
                    # Refresh metadata stats and mark incremental success (status 1).
                    update_metadata(engine, ts_code)
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET last_incremental_update = '{update_time}',
                        status = 1
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                    results['success_count'] += 1
                else:
                    # Nothing new in the requested window.
                    results['no_new_data_count'] += 1
                # Throttle to stay under the API rate limit.
                import time
                time.sleep(0.5)
            except Exception as e:
                # On any error, best-effort mark incremental failure (status 3).
                try:
                    with engine.connect() as conn:
                        update_status_sql = f"""
                        UPDATE stock_metadata
                        SET status = 3
                        WHERE ts_code = '{ts_code}'
                        """
                        conn.execute(text(update_status_sql))
                        conn.commit()
                except Exception:
                    pass  # status write failed; keep processing the batch
                results['failed_count'] += 1
                print(f"[{process_id}] 股票 {ts_code} 增量更新失败: {str(e)}")
        results['processed_count'] += 1
    return results
def perform_full_update(start_year=2020, processes=4, resume=False):
    """
    Run a full update: download all data from the given year through today.

    Args:
        start_year: first year to download (default 2020)
        processes: number of worker processes (default 4)
        resume: when True, only retry stocks whose previous full update failed
    """
    update_type = "续传全量更新" if resume else "全量更新"
    print(f"开始执行{update_type} (从{start_year}年至今),使用 {processes} 个进程...")
    start_time = datetime.now()
    # Load config and connect.
    config = load_config()
    engine = create_engine_from_config(config)
    # NOTE(review): pro is created but never used in this function — each
    # worker builds its own API client in process_stock_batch.
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()
    # Download window: Jan 1 of start_year through today.
    end_date = datetime.now().strftime('%Y%m%d')
    start_date = f"{start_year}0101"
    print(f"数据范围: {start_date} 至 {end_date}")
    # Pick the target stock codes from the metadata table.
    with engine.connect() as conn:
        if resume:
            # Resume mode: only stocks not currently marked successful (status != 1).
            query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L' AND status != 1
            """
        else:
            # All listed stocks.
            query = """
            SELECT ts_code
            FROM stock_metadata
            WHERE list_status = 'L'
            """
        result = conn.execute(text(query))
        stock_codes = [row[0] for row in result]
    if not stock_codes:
        if resume:
            print("没有找到需要续传的股票")
        else:
            print("没有找到需要更新的股票,请先确保元数据表已初始化")
        return
    print(f"共找到 {len(stock_codes)} 只股票需要{update_type}")
    # Fan the codes out across worker processes.
    batches = split_list_into_chunks(stock_codes, processes)
    with Pool(processes=processes) as pool:
        # Freeze the shared arguments; each worker receives one batch.
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='full',
            start_date=start_date,
            end_date=end_date
        )
        results = pool.map(process_func, batches)
    # Aggregate per-process counters.
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)
    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # minutes
    print(f"\n{update_type}完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {len(stock_codes)} 只股票")
def update_metadata(engine, ts_code):
    """
    Refresh a stock's derived statistics in stock_metadata from its data table.

    Args:
        engine: database engine
        ts_code: stock code with exchange suffix, e.g. 000001.SZ

    Returns:
        bool: True on success, False otherwise
    """
    try:
        # Table name is the numeric symbol (exchange suffix stripped).
        symbol = ts_code.split('.')[0]
        # Date range and row count of the stock's own table.
        stats_query = f"""
        SELECT
        MIN(trade_date) as min_date,
        MAX(trade_date) as max_date,
        COUNT(*) as record_count
        FROM `{symbol}`
        """
        # Valuation snapshot of the most recent trading day.
        latest_data_query = f"""
        SELECT
        close,
        pe_ttm,
        pb,
        total_mv
        FROM `{symbol}`
        WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
        LIMIT 1
        """
        with engine.connect() as conn:
            stats_result = conn.execute(text(stats_query)).fetchone()
            if not stats_result:
                # No stats row at all: zero the counters, flag status 2.
                update_empty_sql = f"""
                UPDATE stock_metadata
                SET last_full_update = NOW(),
                record_count = 0,
                status = 2
                WHERE ts_code = '{ts_code}'
                """
                conn.execute(text(update_empty_sql))
                conn.commit()
                return False
            # Dates may arrive as datetime objects or as plain strings.
            try:
                data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
                data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
            except AttributeError:
                # Values are strings already (driver-dependent).
                data_start_date = stats_result[0] if stats_result[0] else 'NULL'
                data_end_date = stats_result[1] if stats_result[1] else 'NULL'
            record_count = stats_result[2] or 0
            latest_data_result = conn.execute(text(latest_data_query)).fetchone()
            # The sentinel string 'NULL' is translated to SQL NULL below.
            default_value = 'NULL'
            latest_price = default_value
            latest_pe_ttm = default_value
            latest_pb = default_value
            latest_total_mv = default_value
            if latest_data_result:
                latest_price = str(latest_data_result[0]) if latest_data_result[0] is not None else default_value
                latest_pe_ttm = str(latest_data_result[1]) if latest_data_result[1] is not None else default_value
                latest_pb = str(latest_data_result[2]) if latest_data_result[2] is not None else default_value
                latest_total_mv = str(latest_data_result[3]) if latest_data_result[3] is not None else default_value
            # NOTE(review): values are interpolated directly into SQL. They come
            # from this same database, but parameter binding would be safer.
            update_sql = f"""
            UPDATE stock_metadata
            SET
            data_start_date = CASE WHEN '{data_start_date}' = 'NULL' THEN NULL ELSE '{data_start_date}' END,
            data_end_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
            record_count = {record_count},
            status = CASE WHEN {record_count} > 0 THEN 1 ELSE 2 END,
            latest_price = CASE WHEN '{latest_price}' = 'NULL' THEN NULL ELSE {latest_price} END,
            latest_date = CASE WHEN '{data_end_date}' = 'NULL' THEN NULL ELSE '{data_end_date}' END,
            latest_pe_ttm = CASE WHEN '{latest_pe_ttm}' = 'NULL' THEN NULL ELSE {latest_pe_ttm} END,
            latest_pb = CASE WHEN '{latest_pb}' = 'NULL' THEN NULL ELSE {latest_pb} END,
            latest_total_mv = CASE WHEN '{latest_total_mv}' = 'NULL' THEN NULL ELSE {latest_total_mv} END
            WHERE ts_code = '{ts_code}'
            """
            conn.execute(text(update_sql))
            conn.commit()
            return True
    except Exception as e:
        # Best-effort: flag metadata-update failure (status -1).
        try:
            with engine.connect() as conn:
                error_update_sql = f"""
                UPDATE stock_metadata
                SET status = -1
                WHERE ts_code = '{ts_code}'
                """
                conn.execute(text(error_update_sql))
                conn.commit()
        except Exception:
            pass  # status write failed; nothing more we can do
        print(f"更新({ts_code})元数据时出错: {str(e)}")
        return False
def perform_incremental_update(processes=4):
    """
    Run an incremental update: fetch each stock's data from the day after its
    most recent (full or incremental) update through today.

    Args:
        processes: number of worker processes (default 4)
    """
    print(f"开始执行增量更新,使用 {processes} 个进程...")
    start_time = datetime.now()
    # Load config and connect; workers create their own engines/API clients.
    config = load_config()
    engine = create_engine_from_config(config)
    ts.set_token(config['tushare_token'])
    # Today is the shared end of every download window.
    end_date = datetime.now().strftime('%Y%m%d')
    # Collect the stocks to update, each with its own start date.
    stocks_to_update = []
    with engine.connect() as conn:
        query = """
        SELECT
        ts_code,
        last_incremental_update,
        last_full_update
        FROM
        stock_metadata
        WHERE
        list_status = 'L'
        """
        result = conn.execute(text(query))
        for row in result:
            ts_code = row[0]
            last_incr_update = row[1]  # last incremental-update timestamp
            last_full_update = row[2]  # last full-update timestamp
            # Start from the day after the most recent of the two timestamps.
            latest_update = None
            if last_incr_update:
                latest_update = last_incr_update
            if last_full_update:
                if not latest_update or last_full_update > latest_update:
                    latest_update = last_full_update
            if latest_update:
                start_date = (latest_update + timedelta(days=1)).strftime('%Y%m%d')
            else:
                # No update history: default to the last 30 days.
                start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d')
            # Nothing to fetch when the window is empty.
            if start_date >= end_date:
                continue
            stocks_to_update.append({
                'ts_code': ts_code,
                'start_date': start_date
            })
    if not stocks_to_update:
        print("没有找到需要增量更新的股票")
        return
    print(f"共找到 {len(stocks_to_update)} 只股票需要增量更新")
    # Fan the stocks out across worker processes.
    batches = split_list_into_chunks(stocks_to_update, processes)
    with Pool(processes=processes) as pool:
        # BUG FIX: the original also passed update_time=..., but
        # process_stock_batch(batch, config, update_mode, start_date, end_date)
        # has no such parameter, so every worker call raised TypeError and
        # pool.map failed. Workers compute their own timestamps per stock,
        # so the argument is simply dropped.
        process_func = partial(
            process_stock_batch,
            config=config,
            update_mode='incremental',
            start_date=None,  # incremental mode reads each stock's own start date
            end_date=end_date
        )
        results = pool.map(process_func, batches)
    # Aggregate per-process counters.
    success_count = sum(r['success_count'] for r in results)
    failed_count = sum(r['failed_count'] for r in results)
    no_new_data_count = sum(r['no_new_data_count'] for r in results)
    skipped_count = sum(r['skipped_count'] for r in results)
    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds() / 60  # minutes
    print("\n增量更新完成!")
    print(f"总耗时: {duration:.2f} 分钟")
    print(f"成功更新: {success_count} 只股票")
    print(f"无新数据: {no_new_data_count} 只股票")
    print(f"跳过股票: {skipped_count} 只股票")
    print(f"更新失败: {failed_count} 只股票")
    print(f"总计: {len(stocks_to_update)} 只股票")
# 添加创建日历表的函数
def create_calendar_table(engine):
    """Create the trading-calendar table using the DDL in sql/calendar.sql."""
    with open('sql/calendar.sql', 'r', encoding='utf-8') as sql_file:
        ddl = sql_file.read()
    with engine.connect() as connection:
        connection.execute(text(ddl))
        connection.commit()
    print("交易日历表创建成功")
# 添加下载日历数据的函数
def download_calendar_data(pro, start_date=None, end_date=None):
    """
    Download trading-calendar data from Tushare.

    Args:
        pro: Tushare API object
        start_date: start date (YYYYMMDD); defaults to 20000101
        end_date: end date (YYYYMMDD); defaults to one year from today

    Returns:
        pandas.DataFrame: calendar rows; empty DataFrame on failure
    """
    if start_date is None:
        start_date = '20000101'
    if end_date is None:
        # Fetch a year ahead so upcoming trading days are already known.
        end_date = (datetime.now() + timedelta(days=365)).strftime('%Y%m%d')
    print(f"正在下载从 {start_date} 到 {end_date} 的交易日历数据...")
    try:
        df = pro.trade_cal(exchange='SSE', start_date=start_date, end_date=end_date)
        # Map Tushare column names onto the database schema.
        df = df.rename(columns={
            'cal_date': 'trade_date',
            'exchange': 'exchange',
            'is_open': 'is_open',
            'pretrade_date': 'pretrade_date'
        })
        # Convert YYYYMMDD strings to YYYY-MM-DD.
        df['trade_date'] = pd.to_datetime(df['trade_date']).dt.strftime('%Y-%m-%d')
        if 'pretrade_date' in df.columns:
            df['pretrade_date'] = pd.to_datetime(df['pretrade_date']).dt.strftime('%Y-%m-%d')
        # Stamp creation/update times on every row.
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        df['create_time'] = timestamp
        df['update_time'] = timestamp
        print(f"成功下载 {len(df)} 条交易日历记录")
        return df
    except Exception as e:
        print(f"下载交易日历数据时出错: {e}")
        return pd.DataFrame()
# 添加保存日历数据的函数
def save_calendar_data(calendar_df, engine):
    """
    Write trading-calendar rows to the 'calendar' table, replacing any
    existing table.

    Args:
        calendar_df: DataFrame of calendar rows
        engine: database engine/connection accepted by DataFrame.to_sql
    """
    if calendar_df.empty:
        print("没有交易日历数据可保存")
        return
    try:
        calendar_df.to_sql('calendar', engine, if_exists='replace', index=False)
        print(f"成功保存 {len(calendar_df)} 条交易日历数据到数据库")
    except Exception as e:
        print(f"保存交易日历数据时出错: {e}")
# 添加交易日历更新函数
def update_calendar(engine, pro):
    """
    Bring the trading-calendar table up to date, creating it when missing.

    Args:
        engine: SQLAlchemy engine
        pro: Tushare API object
    """
    try:
        # Does the calendar table exist at all?
        with engine.connect() as connection:
            table_exists = connection.execute(
                text("SHOW TABLES LIKE 'calendar'")
            ).fetchone() is not None
        if not table_exists:
            # First run: create the table and load the full history.
            create_calendar_table(engine)
            save_calendar_data(download_calendar_data(pro), engine)
            return
        # Table exists: find the most recent stored trading date.
        with engine.connect() as connection:
            latest_date = connection.execute(
                text("SELECT MAX(trade_date) FROM calendar")
            ).fetchone()[0]
        if not latest_date:
            # Table exists but is empty: load the full history.
            save_calendar_data(download_calendar_data(pro), engine)
            return
        # Fetch everything after the stored maximum (converted to YYYYMMDD).
        next_day = datetime.strptime(str(latest_date), '%Y-%m-%d') + timedelta(days=1)
        fresh_rows = download_calendar_data(pro, start_date=next_day.strftime('%Y%m%d'))
        if fresh_rows.empty:
            print("没有新的交易日历数据需要更新")
        else:
            # Append only the new rows to the existing table.
            fresh_rows.to_sql('calendar', engine, if_exists='append', index=False)
            print(f"成功更新 {len(fresh_rows)} 条新的交易日历数据")
    except Exception as e:
        print(f"更新交易日历时出错: {e}")
def main():
    """CLI entry point: parse arguments and run the requested maintenance tasks."""
    parser = argparse.ArgumentParser(description='股票数据下载器')
    parser.add_argument('--init', action='store_true', help='初始化元数据')
    parser.add_argument('--mode', choices=['full', 'incremental', 'both'], default='full',
                        help='更新模式: full(全量), incremental(增量), both(两者都)')
    parser.add_argument('--year', type=int, default=2020, help='起始年份(用于全量更新)')
    parser.add_argument('--processes', type=int, default=4, help='并行进程数')
    parser.add_argument('--resume', action='store_true', help='是否从中断处继续')
    parser.add_argument('--update-calendar', action='store_true', help='更新交易日历数据')
    args = parser.parse_args()
    # Shared setup: config, DB engine, Tushare client.
    config = load_config()
    engine = create_engine_from_config(config)
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()
    if args.init:
        # Bootstrap the metadata table and the trading calendar.
        create_metadata_table(engine)
        metadata_df = download_stock_metadata(pro)
        save_metadata(metadata_df, engine)
        print("元数据初始化完成")
        update_calendar(engine, pro)
        print("交易日历初始化完成")
    if args.update_calendar:
        update_calendar(engine, pro)
        print("交易日历更新完成")
    # Run the requested update passes.
    if args.mode in ('full', 'both'):
        perform_full_update(args.year, args.processes, args.resume)
    if args.mode in ('incremental', 'both'):
        perform_incremental_update(args.processes)
# Script entry point.
if __name__ == '__main__':
    main()