backtrader/data_downloader.py

import pandas as pd
import tushare as ts
import numpy as np
import os
from sqlalchemy import create_engine, text
from datetime import datetime, timedelta
from utils import load_config


def create_engine_from_config(config):
    """Create a SQLAlchemy engine from the MySQL section of the config."""
    mysql = config['mysql']
    connection_string = (
        f"mysql+pymysql://{mysql['user']}:{mysql['password']}"
        f"@{mysql['host']}:{mysql['port']}/{mysql['database']}"
        f"?charset={mysql['charset']}&use_unicode=1"
    )
    return create_engine(connection_string)
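
# A minimal sketch of the config dict this module assumes. The 'mysql' sub-keys and
# 'tushare_token' are exactly the keys read in this file; the concrete values below are
# placeholders only, and how utils.load_config() produces them is not specified here.
#
#   config = {
#       'tushare_token': '<your tushare token>',
#       'mysql': {
#           'user': 'stock', 'password': '<password>',
#           'host': '127.0.0.1', 'port': 3306,
#           'database': 'stock_data', 'charset': 'utf8mb4',
#       },
#   }
#   engine = create_engine_from_config(config)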


def create_metadata_table(engine):
    """Create the stock metadata table."""
    # Read the CREATE TABLE statement from the SQL file
    with open('sql/meta.sql', 'r', encoding='utf-8') as f:
        create_table_sql = f.read()
    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print("Stock metadata table created successfully")


def create_stock_table(engine, ts_code):
    """Create the data table for a single stock."""
    # Make sure the sql directory exists
    os.makedirs('sql', exist_ok=True)
    # Read the template SQL
    with open('sql/stock_table_template.sql', 'r', encoding='utf-8') as f:
        template_sql = f.read()
    # Fill the symbol placeholder in the template
    create_table_sql = template_sql.format(symbol=ts_code.split('.')[0])
    # Execute the CREATE TABLE statement
    with engine.connect() as conn:
        conn.execute(text(create_table_sql))
        conn.commit()
    print(f"Data table for {ts_code} created successfully")


def download_stock_metadata(pro):
    """Download basic stock information."""
    df = pro.stock_basic(exchange='', list_status='L',
                         fields='ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type')
    # Add bookkeeping fields
    df['status'] = 0  # not yet initialized
    df['remark'] = 'auto import'
    df['last_full_update'] = None
    df['last_incremental_update'] = None
    df['data_start_date'] = None
    df['data_end_date'] = None
    df['record_count'] = 0
    df['latest_price'] = None
    df['latest_date'] = None
    df['latest_pe_ttm'] = None
    df['latest_pb'] = None
    df['latest_total_mv'] = None
    return df


def save_metadata(df, engine):
    """Save metadata to the database, updating only records whose monitored
    fields have changed, and normalizing date formats before comparison."""
    # Basic metadata fields monitored for changes
    metadata_fields = [
        'ts_code', 'symbol', 'name', 'area', 'industry', 'fullname',
        'enname', 'cnspell', 'market', 'exchange', 'curr_type',
        'list_status', 'list_date', 'delist_date', 'is_hs',
        'act_name', 'act_ent_type'
    ]
    # Date fields
    date_fields = ['list_date', 'delist_date']

    # Read the existing metadata from the database
    try:
        with engine.connect() as conn:
            # Query only the fields we care about
            query = f"SELECT {', '.join(metadata_fields)} FROM stock_metadata"
            existing_df = pd.read_sql(query, conn)
    except Exception as e:
        print(f"Error reading existing metadata: {str(e)}")
        existing_df = pd.DataFrame()  # fall back to an empty DataFrame if the table does not exist

    if existing_df.empty:
        # If the table is empty, insert everything directly
        result = df.to_sql('stock_metadata', engine, index=False,
                           if_exists='append', chunksize=1000)
        print(f"Saved {result} stock metadata records (initial load)")
        return result
    else:
        # Preprocessing: normalize date formats
        df_processed = df.copy()
        existing_processed = existing_df.copy()

        def normalize_date(date_str):
            """Normalize a date value to an 8-digit YYYYMMDD string, or '' if invalid."""
            if pd.isna(date_str) or date_str == '' or date_str == 'None' or date_str is None:
                return ''
            # Strip everything that is not a digit
            date_str = str(date_str).strip()
            digits_only = ''.join(c for c in date_str if c.isdigit())
            # An 8-digit value is treated as a valid YYYYMMDD date
            if len(digits_only) == 8:
                return digits_only
            return ''

        # Apply date normalization to both DataFrames
        for field in date_fields:
            if field in df_processed.columns and field in existing_processed.columns:
                df_processed[field] = df_processed[field].apply(normalize_date)
                existing_processed[field] = existing_processed[field].apply(normalize_date)

        # Normalize the remaining non-date fields
        for col in [f for f in metadata_fields if f not in date_fields and f != 'ts_code']:
            if col in df_processed.columns and col in existing_processed.columns:
                # Cast the same column in both DataFrames to string
                df_processed[col] = df_processed[col].astype(str)
                existing_processed[col] = existing_processed[col].astype(str)
                # Handle NaN values and stray whitespace
                df_processed[col] = df_processed[col].fillna('').str.strip()
                existing_processed[col] = existing_processed[col].fillna('').str.strip()

        # Find newly added records
        existing_ts_codes = set(existing_processed['ts_code'])
        new_records = df[~df['ts_code'].isin(existing_ts_codes)].copy()

        # Find records that need to be updated
        changed_records = []
        unchanged_count = 0

        # For every ts_code that already exists, check whether anything changed
        for ts_code in df_processed[df_processed['ts_code'].isin(existing_ts_codes)]['ts_code']:
            # Get the normalized new and old records
            new_record = df_processed[df_processed['ts_code'] == ts_code].iloc[0]
            old_record = existing_processed[existing_processed['ts_code'] == ts_code].iloc[0]
            # Also keep the original (un-normalized) record for the update itself
            original_record = df[df['ts_code'] == ts_code].iloc[0]

            # Compare field by field
            has_change = False
            changed_fields = []
            for field in metadata_fields:
                if field == 'ts_code':  # skip the primary key
                    continue
                new_val = new_record[field]
                old_val = old_record[field]
                if new_val != old_val:
                    has_change = True
                    changed_fields.append(field)
                    print(f"Change detected - {ts_code} {field}: '{old_val}' -> '{new_val}'")

            if has_change:
                changed_records.append({
                    'record': original_record,  # keep the original record in its original format
                    'changed_fields': changed_fields
                })
            else:
                unchanged_count += 1

        # Insert new records
        new_count = 0
        if not new_records.empty:
            new_count = new_records.to_sql('stock_metadata', engine, index=False,
                                           if_exists='append', chunksize=1000)

        # Update changed records
        updated_count = 0
        for change_info in changed_records:
            record = change_info['record']
            fields = change_info['changed_fields']

            # Build an UPDATE statement covering only the changed fields
            fields_to_update = []
            params = {'ts_code': record['ts_code']}
            for field in fields:
                fields_to_update.append(f"{field} = :{field}")
                params[field] = record[field]

            if fields_to_update:
                update_stmt = text(f"""
                    UPDATE stock_metadata
                    SET {', '.join(fields_to_update)}
                    WHERE ts_code = :ts_code
                """)
                with engine.connect() as conn:
                    conn.execute(update_stmt, params)
                    conn.commit()
                updated_count += 1

        print("Metadata update summary:")
        print(f"  • new records: {new_count}")
        print(f"  • updated records: {updated_count}")
        print(f"  • unchanged records: {unchanged_count}")
        print(f"  • total processed: {new_count + updated_count + unchanged_count}")
        return new_count + updated_count


def download_stock_data(pro, ts_code, start_date, end_date):
    """Download all data types for a stock and merge them on trade_date."""
    # Download daily price data
    daily_df = pro.daily(ts_code=ts_code, start_date=start_date, end_date=end_date)
    if daily_df.empty:
        print(f"Warning: no daily data found for {ts_code}")
        return pd.DataFrame()

    # Download daily_basic data
    daily_basic_df = pro.daily_basic(ts_code=ts_code, start_date=start_date, end_date=end_date)
    # Download moneyflow data
    moneyflow_df = pro.moneyflow(ts_code=ts_code, start_date=start_date, end_date=end_date)

    # Every DataFrame needs a trade_date column to merge on
    if 'trade_date' not in daily_df.columns:
        print(f"Error: daily data for {ts_code} is missing the trade_date column")
        return pd.DataFrame()

    # Normalize the date column to string for merging
    daily_df['trade_date'] = daily_df['trade_date'].astype(str)

    # Merge via pd.merge rather than join for better control over the columns
    result_df = daily_df

    # Merge daily_basic data
    if not daily_basic_df.empty:
        daily_basic_df['trade_date'] = daily_basic_df['trade_date'].astype(str)
        # Identify overlapping columns other than ts_code and trade_date
        overlap_cols = list((set(result_df.columns) & set(daily_basic_df.columns)) - {'ts_code', 'trade_date'})
        # Drop the overlapping columns from daily_basic
        daily_basic_df_filtered = daily_basic_df.drop(columns=overlap_cols)
        # Merge
        result_df = pd.merge(
            result_df,
            daily_basic_df_filtered,
            on=['ts_code', 'trade_date'],
            how='left'
        )

    # Merge moneyflow data
    if not moneyflow_df.empty:
        moneyflow_df['trade_date'] = moneyflow_df['trade_date'].astype(str)
        # Identify overlapping columns other than ts_code and trade_date
        overlap_cols = list((set(result_df.columns) & set(moneyflow_df.columns)) - {'ts_code', 'trade_date'})
        # Drop the overlapping columns from moneyflow
        moneyflow_df_filtered = moneyflow_df.drop(columns=overlap_cols)
        # Merge
        result_df = pd.merge(
            result_df,
            moneyflow_df_filtered,
            on=['ts_code', 'trade_date'],
            how='left'
        )

    # Convert trade_date to datetime
    result_df['trade_date'] = pd.to_datetime(result_df['trade_date'])
    return result_df


def save_stock_data(df, engine, ts_code):
    """Save stock data into the per-stock table."""
    if df.empty:
        print(f"Warning: no data to save for {ts_code}")
        return 0
    # Drop the ts_code column, since the per-stock table does not need it
    if 'ts_code' in df.columns:
        df = df.drop(columns=['ts_code'])
    # Use replace mode so the table only ever holds the latest download
    try:
        symbol = ts_code.split('.')[0]
        result = df.to_sql(f'{symbol}', engine, index=False,
                           if_exists='replace', chunksize=1000)
        print(f"Saved {result} rows of data for {ts_code}")
        return result
    except Exception as e:
        print(f"Error saving data for {ts_code}: {str(e)}")
        return 0


def update_metadata(engine, ts_code):
    """Update the statistics in the stock metadata, reading them directly from the per-stock table."""
    # Extract the symbol (without the exchange suffix)
    symbol = ts_code.split('.')[0]

    # Query overall statistics for the stock table
    stats_query = f"""
        SELECT
            MIN(trade_date) as min_date,
            MAX(trade_date) as max_date,
            COUNT(*) as record_count
        FROM `{symbol}`
    """
    # Query the latest trading day's data
    latest_data_query = f"""
        SELECT
            close,
            pe_ttm,
            pb,
            total_mv
        FROM `{symbol}`
        WHERE trade_date = (SELECT MAX(trade_date) FROM `{symbol}`)
        LIMIT 1
    """

    # Run the queries
    with engine.connect() as conn:
        # Fetch the statistics
        stats_result = conn.execute(text(stats_query)).fetchone()
        if not stats_result:
            print(f"Warning: {ts_code} has no data, metadata not updated")
            return

        data_start_date = stats_result[0].strftime('%Y-%m-%d') if stats_result[0] else 'NULL'
        data_end_date = stats_result[1].strftime('%Y-%m-%d') if stats_result[1] else 'NULL'
        record_count = stats_result[2]

        # Fetch the latest trading day's data
        latest_data_result = conn.execute(text(latest_data_query)).fetchone()
        latest_price = f"{latest_data_result[0]}" if latest_data_result and latest_data_result[0] is not None else 'NULL'
        latest_pe_ttm = f"{latest_data_result[1]}" if latest_data_result and latest_data_result[1] is not None else 'NULL'
        latest_pb = f"{latest_data_result[2]}" if latest_data_result and latest_data_result[2] is not None else 'NULL'
        latest_total_mv = f"{latest_data_result[3]}" if latest_data_result and latest_data_result[3] is not None else 'NULL'

        # Load the UPDATE SQL template
        with open('sql/update_metadata.sql', 'r', encoding='utf-8') as f:
            update_sql_template = f.read()

        # Fill in the template
        update_sql = update_sql_template.format(
            data_start_date=data_start_date,
            data_end_date=data_end_date,
            record_count=record_count,
            latest_price=latest_price,
            latest_pe_ttm=latest_pe_ttm,
            latest_pb=latest_pb,
            latest_total_mv=latest_total_mv,
            ts_code=ts_code
        )

        # Execute the update
        conn.execute(text(update_sql))
        conn.commit()
        print(f"Metadata updated for {ts_code}")


def process_stock(engine, pro, ts_code, start_date, end_date):
    """Run the full pipeline for a single stock."""
    # Create the per-stock table
    create_stock_table(engine, ts_code)
    # Download the stock data
    stock_data = download_stock_data(pro, ts_code, start_date, end_date)
    # Save the data and refresh the metadata
    if not stock_data.empty:
        save_stock_data(stock_data, engine, ts_code)
        update_metadata(engine, ts_code)
    else:
        print(f"Warning: no data retrieved for {ts_code}")


def main():
    # Load the configuration
    config = load_config()
    # Create the database engine
    engine = create_engine_from_config(config)
    # Set up the Tushare API
    ts.set_token(config['tushare_token'])
    pro = ts.pro_api()

    # Create the metadata table
    create_metadata_table(engine)
    # Download and save the stock metadata
    metadata_df = download_stock_metadata(pro)
    save_metadata(metadata_df, engine)

    # Example: process stock 000001
    ts_code = '000001.SZ'  # ts_code for Ping An Bank
    # Date range for the download (here, roughly the past 15 years)
    end_date = datetime.now().strftime('%Y%m%d')
    start_date = (datetime.now() - timedelta(days=15 * 365)).strftime('%Y%m%d')
    # Process the stock
    process_stock(engine, pro, ts_code, start_date, end_date)


if __name__ == '__main__':
    main()
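
# A hedged sketch (not part of the original script) of how process_stock() could be driven
# from the stock_metadata table to cover every listed stock rather than the single example
# in main(). The table and column names match those used above; the batching loop itself is
# an assumed extension, not existing behaviour.
#
#   def process_all_stocks(engine, pro, start_date, end_date):
#       with engine.connect() as conn:
#           rows = conn.execute(text("SELECT ts_code FROM stock_metadata WHERE list_status = 'L'"))
#           codes = [row[0] for row in rows]
#       for ts_code in codes:
#           process_stock(engine, pro, ts_code, start_date, end_date)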