backtrader/data_downloader.py

424 lines
15 KiB
Python
Raw Normal View History

import pandas as pd
import tushare as ts
import numpy as np
import os
from sqlalchemy import create_engine, text
from datetime import datetime, timedelta
from utils import load_config
def create_engine_from_config(config):
mysql = config['mysql']
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
return create_engine(connection_string)
def create_metadata_table(engine):
"""创建股票元数据表"""
# 从sql文件读取创建表的SQL
with open('sql/meta.sql', 'r', encoding='utf-8') as f:
create_table_sql = f.read()
with engine.connect() as conn:
conn.execute(text(create_table_sql))
conn.commit()
print("股票元数据表创建成功")
def create_stock_table(engine, ts_code):
"""为指定股票创建数据表"""
# 确保sql目录存在
os.makedirs('sql', exist_ok=True)
# 读取模板SQL
with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as f:
template_sql = f.read()
# 替换模板中的symbol
create_table_sql = template_sql.format(symbol=ts_code.split('.')[0])
# 执行创建表的SQL
with engine.connect() as conn:
conn.execute(text(create_table_sql))
conn.commit()
print(f"股票({ts_code})数据表创建成功")
def download_stock_metadata(pro):
"""下载股票基础信息"""
df = pro.stock_basic(exchange='', list_status='L',
fields='ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type')
# 添加状态字段
df['status'] = 0 # 未初始化
df['remark'] = '自动导入'
df['last_full_update'] = None
df['last_incremental_update'] = None
df['data_start_date'] = None
df['data_end_date'] = None
df['record_count'] = 0
df['latest_price'] = None
df['latest_date'] = None
df['latest_pe_ttm'] = None
df['latest_pb'] = None
df['latest_total_mv'] = None
return df
def save_metadata(df, engine):
"""保存元数据到数据库,只更新指定字段发生变化的记录,处理日期格式问题"""
# 需要监控变化的基础元数据字段
metadata_fields = [
'ts_code', 'symbol', 'name', 'area', 'industry', 'fullname',
'enname', 'cnspell', 'market', 'exchange', 'curr_type',
'list_status', 'list_date', 'delist_date', 'is_hs',
'act_name', 'act_ent_type'
]
# 日期字段列表
date_fields = ['list_date', 'delist_date']
# 从数据库中读取现有的元数据
try:
with engine.connect() as conn:
# 只查询我们关心的字段
query = f"SELECT {', '.join(metadata_fields)} FROM stock_metadata"
existing_df = pd.read_sql(query, conn)
except Exception as e:
print(f"读取现有元数据时出错: {str(e)}")
existing_df = pd.DataFrame() # 表不存在时创建空DataFrame
if existing_df.empty:
# 如果表是空的,直接插入所有数据
result = df.to_sql('stock_metadata', engine, index=False,
if_exists='append', chunksize=1000)
print(f"成功保存{result}条股票元数据(新建)")
return result
else:
# 数据预处理:处理日期格式
df_processed = df.copy()
existing_processed = existing_df.copy()
# 标准化日期格式进行比较
def normalize_date(date_str):
if pd.isna(date_str) or date_str == '' or date_str == 'None' or date_str is None:
return ''
# 去除所有非数字字符
date_str = str(date_str).strip()
digits_only = ''.join(c for c in date_str if c.isdigit())
# 如果是8位数字形式的日期返回标准格式
if len(digits_only) == 8:
return digits_only
return ''
# 对两个DataFrame应用日期标准化
for field in date_fields:
if field in df_processed.columns and field in existing_processed.columns:
df_processed[field] = df_processed[field].apply(normalize_date)
existing_processed[field] = existing_processed[field].apply(normalize_date)
# 对其他非日期字段进行标准化处理
for col in [f for f in metadata_fields if f not in date_fields and f != 'ts_code']:
if col in df_processed.columns and col in existing_processed.columns:
# 将两个DataFrame的相同列转换为相同的数据类型字符串类型
df_processed[col] = df_processed[col].astype(str)
existing_processed[col] = existing_processed[col].astype(str)
# 处理NaN值
df_processed[col] = df_processed[col].fillna('').str.strip()
existing_processed[col] = existing_processed[col].fillna('').str.strip()
# 找出新增的记录
existing_ts_codes = set(existing_processed['ts_code'])
new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()
# 找出需要更新的记录
changed_records = []
unchanged_count = 0
# 对于每个已存在的ts_code检查是否有变化
for ts_code in df_processed[df_processed['ts_code'].isin(existing_ts_codes)]['ts_code']:
# 获取新旧记录(已经标准化后的)
new_record = df_processed[df_processed['ts_code'] == ts_code].iloc[0]
old_record = existing_processed[existing_processed['ts_code'] == ts_code].iloc[0]
# 同时取原始记录用于更新操作
original_record = df[df['ts_code'] == ts_code].iloc[0]
# 检查是否有变化
has_change = False
changed_fields = []
for field in metadata_fields:
if field == 'ts_code': # 跳过主键
continue
new_val = new_record[field]
old_val = old_record[field]
if new_val != old_val:
has_change = True
changed_fields.append(field)
print(f"发现变化 - {ts_code}{field}: '{old_val}' -> '{new_val}'")
if has_change:
changed_records.append({
'record': original_record, # 使用原始记录,保持原始格式
'changed_fields': changed_fields
})
else:
unchanged_count += 1
# 插入新记录
new_count = 0
if not new_records.empty:
new_count = new_records.to_sql('stock_metadata', engine, index=False,
if_exists='append', chunksize=1000)
# 更新变化的记录
updated_count = 0
for change_info in changed_records:
record = change_info['record']
fields = change_info['changed_fields']
# 构建更新语句,只更新变化的字段
fields_to_update = []
params = {'ts_code': record['ts_code']}
for field in fields:
fields_to_update.append(f"{field} = :{field}")
params[field] = record[field]
if fields_to_update:
update_stmt = text(f"""
UPDATE stock_metadata
SET {', '.join(fields_to_update)}
WHERE ts_code = :ts_code
""")
with engine.connect() as conn:
conn.execute(update_stmt, params)
updated_count += 1
conn.commit()
print(f"元数据更新统计:")
print(f" • 新增记录: {new_count}")
print(f" • 更新记录: {updated_count}")
print(f" • 无变化记录: {unchanged_count}")
print(f" • 总处理记录: {new_count + updated_count + unchanged_count}")
return new_count + updated_count
def download_stock_data(pro, ts_code, start_date, end_date):
"""下载股票的所有类型数据并合并"""
# 下载daily价格数据
daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
if daily_df.empty:
print(f"警告:{ts_code}没有找到daily数据")
return pd.DataFrame()
# 下载daily_basic数据
daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
# 下载moneyflow数据
moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)
# 确保每个DataFrame都有trade_date列作为合并键
if 'trade_date' not in daily_df.columns:
print(f"错误:{ts_code}的daily数据缺少trade_date列")
return pd.DataFrame()
# 为方便处理,确保所有日期列是字符串类型
daily_df['trade_date'] = daily_df['trade_date'].astype(str)
# 通过merge而不是join合并数据这样可以更好地控制列
result_df = daily_df
# 合并daily_basic数据
if not daily_basic_df.empty:
daily_basic_df['trade_date'] = daily_basic_df['trade_date'].astype(str)
# 识别重叠的列除了ts_code和trade_date
overlap_cols = list(set(result_df.columns) & set(daily_basic_df.columns) - {'ts_code', 'trade_date'})
# 从daily_basic中排除这些重叠列
daily_basic_df_filtered = daily_basic_df.drop(columns=overlap_cols)
# 合并数据
result_df = pd.merge(
result_df,
daily_basic_df_filtered,
on=['ts_code', 'trade_date'],
how='left'
)
# 合并moneyflow数据
if not moneyflow_df.empty:
moneyflow_df['trade_date'] = moneyflow_df['trade_date'].astype(str)
# 识别重叠的列除了ts_code和trade_date
overlap_cols = list(set(result_df.columns) & set(moneyflow_df.columns) - {'ts_code', 'trade_date'})
# 从moneyflow中排除这些重叠列
moneyflow_df_filtered = moneyflow_df.drop(columns=overlap_cols)
# 合并数据
result_df = pd.merge(
result_df,
moneyflow_df_filtered,
on=['ts_code', 'trade_date'],
how='left'
)
# 将trade_date转换为datetime格式
result_df['trade_date'] = pd.to_datetime(result_df['trade_date'])
return result_df
def save_stock_data(df, engine, ts_code):
"""保存股票数据到对应的表中"""
if df.empty:
print(f"警告:{ts_code}没有数据可保存")
return 0
# 删除ts_code列因为表中不需要
if 'ts_code' in df.columns:
df = df.drop(columns=['ts_code'])
# 使用replace模式确保表中只包含最新数据
try:
symbol = ts_code.split('.')[0]
result = df.to_sql(f'{symbol}', engine, index=False,
if_exists='replace', chunksize=1000)
print(f"成功保存{result}{ts_code}股票数据")
return result
except Exception as e:
print(f"保存{ts_code}数据时出错: {str(e)}")
return 0
def update_metadata(engine, ts_code):
"""更新股票元数据中的统计信息,从股票表中直接获取数据"""
# 提取股票代码(不含后缀)
symbol = ts_code.split('.')[0]
# 查询股票数据统计信息
stats_query = f"""
SELECT
MIN(trade_date) as min_date,
MAX(trade_date) as max_date,
COUNT(*) as record_count
FROM `{symbol}`
"""
# 查询最新交易日数据
latest_data_query = f"""
SELECT
close,
pe_ttm,
pb,
total_mv
FROM `{symbol}`
WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
LIMIT 1
"""
# 执行查询
with engine.connect() as conn:
# 获取统计信息
stats_result = conn.execute(text(stats_query)).fetchone()
if not stats_result:
print(f"警告:{ts_code}没有数据,无法更新元数据")
return
data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
record_count = stats_result[2]
# 获取最新交易日数据
latest_data_result = conn.execute(text(latest_data_query)).fetchone()
latest_price = f"{latest_data_result[0]}" if latest_data_result and latest_data_result[
0] is not None else 'NULL'
latest_pe_ttm = f"{latest_data_result[1]}" if latest_data_result and latest_data_result[
1] is not None else 'NULL'
latest_pb = f"{latest_data_result[2]}" if latest_data_result and latest_data_result[2] is not None else 'NULL'
latest_total_mv = f"{latest_data_result[3]}" if latest_data_result and latest_data_result[
3] is not None else 'NULL'
# 加载更新SQL模板
with open('sql/update_metadata.sql', 'r', encoding='utf-8') as f:
update_sql_template = f.read()
# 填充模板
update_sql = update_sql_template.format(
data_start_date=data_start_date,
data_end_date=data_end_date,
record_count=record_count,
latest_price=latest_price,
latest_pe_ttm=latest_pe_ttm,
latest_pb=latest_pb,
latest_total_mv=latest_total_mv,
ts_code=ts_code
)
# 执行更新
conn.execute(text(update_sql))
conn.commit()
print(f"已更新({ts_code})的元数据信息")
def process_stock(engine, pro, ts_code, start_date, end_date):
"""处理单个股票的完整流程"""
# 创建该股票的表
create_stock_table(engine, ts_code)
# 下载该股票的数据
stock_data = download_stock_data(pro, ts_code, start_date, end_date)
# 保存股票数据
if not stock_data.empty:
save_stock_data(stock_data, engine, ts_code)
update_metadata(engine, ts_code)
else:
print(f"警告:({ts_code})没有获取到数据")
def main():
# 加载配置
config = load_config()
# 创建数据库引擎
engine = create_engine_from_config(config)
# 设置Tushare API
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
# 创建元数据表
create_metadata_table(engine)
# 下载并保存股票元数据
metadata_df = download_stock_metadata(pro)
save_metadata(metadata_df, engine)
# 示例处理000001股票
ts_code = '000001.SZ' # 平安银行的ts_code
# 设置数据获取的日期范围示例过去5年的数据
end_date = datetime.now().strftime('%Y%m%d')
start_date = (datetime.now() - timedelta(days=15 * 365)).strftime('%Y%m%d')
# 处理该股票
process_stock(engine, pro, ts_code, start_date, end_date)
if __name__ == '__main__':
main()