2025-04-04 15:24:02 +08:00
|
|
|
import os
|
|
|
|
import yaml
|
|
|
|
import tushare as ts
|
|
|
|
import pandas as pd
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
2025-04-05 22:48:31 +08:00
|
|
|
from sqlalchemy import create_engine
|
|
|
|
|
2025-04-04 15:24:02 +08:00
|
|
|
|
|
|
|
def load_config():
    """Load settings from ``config.yaml``, creating a template on first run.

    Returns:
        dict: parsed configuration with 'tushare_token' and 'mysql' keys.

    Raises:
        SystemExit: when the config file does not exist — a template is
            written first so the user can fill in real credentials.
    """
    config_path = 'config.yaml'
    if not os.path.exists(config_path):
        # First run: write a template the user must edit before re-running.
        config = {
            'tushare_token': 'your_token_here',
            'mysql': {
                'user': 'user',
                'password': 'password',
                'host': '127.0.0.1',
                'port': 3306,
                'database': 'tushare',
                'charset': 'utf8'
            }
        }
        # Explicit encoding so the file is read/written identically on
        # every platform (Windows defaults to a legacy codec otherwise).
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(config, f)
        print(f"请在 {config_path} 中填入您的 tushare token 和 MySQL 连接信息")
        # raise SystemExit instead of the site-provided exit() builtin,
        # which is absent when running with `python -S`.
        raise SystemExit(1)

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level initialisation: read credentials from config.yaml and
# create the shared tushare "pro" API client used by every helper below.
config = load_config()
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
|
|
|
|
|
|
|
|
|
|
|
|
def get_trans_data(stock_code, start_date, end_date, data_type='daily'):
    """Download one kind of market data for a stock and cache it as CSV.

    Args:
        stock_code: tushare ts_code, e.g. '600970.SH'.
        start_date: range start; any pandas-parsable date, normalised to YYYYMMDD.
        end_date: range end; same format handling as start_date.
        data_type: one of 'daily', 'weekly', 'monthly', 'money_flow',
            'daily_basic'.

    Returns:
        bool: True if the file already existed or was downloaded and saved;
        False for an unsupported data_type or an empty API response.
    """
    # Normalise dates to the YYYYMMDD form the tushare API expects.
    start_date = pd.to_datetime(start_date).strftime('%Y%m%d')
    end_date = pd.to_datetime(end_date).strftime('%Y%m%d')

    # One cache directory per stock under ./data.
    save_dir = os.path.join('data', stock_code)
    os.makedirs(save_dir, exist_ok=True)

    file_name = f"{data_type}_{start_date}_{end_date}.csv"
    file_path = os.path.join(save_dir, file_name)

    # Cached copy already present: skip the network round-trip.
    if os.path.exists(file_path):
        print(f"{file_path} 已存在,跳过下载")
        return True

    # Dispatch table (data_type -> pro API method name) instead of an
    # if/elif chain; the method is looked up lazily, matching the
    # original's behaviour of touching `pro` only for the chosen type.
    api_names = {
        'daily': 'daily',
        'weekly': 'weekly',
        'monthly': 'monthly',
        'money_flow': 'moneyflow',
        'daily_basic': 'daily_basic',
    }
    if data_type not in api_names:
        print(f"不支持的数据类型: {data_type}")
        return False
    df = getattr(pro, api_names[data_type])(
        ts_code=stock_code, start_date=start_date, end_date=end_date)

    if df.empty:
        print(f"没有找到 {stock_code} 从 {start_date} 到 {end_date} 的 {data_type} 数据")
        return False

    df.to_csv(file_path, index=False)
    print(f"数据已保存到 {file_path}")
    return True
|
|
|
|
|
|
|
|
def get_stock_basic():
    """Fetch the full listed-stock reference table and save it as CSV."""
    # Date-stamp the snapshot so repeated runs keep daily history.
    snapshot_date = datetime.now().strftime('%Y%m%d')

    out_dir = 'data'
    os.makedirs(out_dir, exist_ok=True)

    out_path = os.path.join(out_dir, f"stock_basic_{snapshot_date}.csv")

    # Request every available column for currently listed ('L') stocks.
    fields = 'ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type'
    df = pro.stock_basic(exchange='', list_status='L', fields=fields)

    # Nothing returned by the API: report and bail out.
    if df.empty:
        print("没有找到股票基本信息数据")
        return

    df.to_csv(out_path, index=False)
    print(f"股票基本信息数据已保存到 {out_path}")
|
|
|
|
|
|
|
|
def load_share_data(stock_code, data_type='daily', start_date=None, end_date=None):
    """Load cached CSV market data for a stock as a backtrader-ready DataFrame.

    Args:
        stock_code: tushare ts_code, e.g. '600970.SH'.
        data_type: file prefix matching get_trans_data output, e.g. 'daily'.
        start_date: optional range start; when both dates are given the exact
            cached file is loaded (downloading it on demand).
        end_date: optional range end; when either date is None, the first
            cached file matching the data_type is used instead.

    Returns:
        pandas.DataFrame indexed by 'datetime' with columns
        open/high/low/close/volume, sorted ascending by trade date.

    Raises:
        ValueError: no cached file exists and no date range was given, or a
            required OHLCV column is missing from the file.
        FileNotFoundError: the requested file could not be downloaded.
    """
    data_dir = os.path.join('data', stock_code)

    if start_date is None or end_date is None:
        # No explicit range: pick a cached file for this data_type.
        # Guard against a missing directory so the error is meaningful
        # (os.listdir would otherwise raise a raw FileNotFoundError),
        # and sort so the choice is deterministic across platforms.
        files = []
        if os.path.isdir(data_dir):
            files = sorted(f for f in os.listdir(data_dir)
                           if f.startswith(f"{data_type}_"))
        if not files:
            raise ValueError(f"找不到{stock_code}的{data_type}数据文件")
        file_path = os.path.join(data_dir, files[0])
    else:
        # Normalise dates to YYYYMMDD to reconstruct the cache file name.
        start_date = pd.to_datetime(start_date).strftime('%Y%m%d')
        end_date = pd.to_datetime(end_date).strftime('%Y%m%d')
        file_name = f"{data_type}_{start_date}_{end_date}.csv"
        file_path = os.path.join(data_dir, file_name)

    if not os.path.exists(file_path):
        print(f"找不到{file_path}数据文件,执行下载")
        get_trans_data(stock_code, start_date, end_date, data_type)
        # Fail fast with a clear error if the download produced no file
        # (previously pd.read_csv raised a confusing FileNotFoundError).
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"下载失败,无法获取 {file_path}")

    print(f"加载数据: {file_path}")

    df = pd.read_csv(file_path)

    # Parse YYYYMMDD integers into datetimes and sort oldest-first.
    df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
    df = df.sort_values('trade_date')

    # Rename to the column names backtrader expects; open/high/low/close
    # already match, so only the two differing names are mapped.
    df = df.rename(columns={'trade_date': 'datetime', 'vol': 'volume'})

    # Ensure every column backtrader needs is present.
    required_columns = ['datetime', 'open', 'high', 'low', 'close', 'volume']
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(f"缺少必需的列: {col}")

    # Index by datetime as backtrader's PandasData feed expects.
    df = df.set_index('datetime')

    return df
|
|
|
|
|
2025-04-05 22:48:31 +08:00
|
|
|
def create_engine_from_config(config):
    """Build a SQLAlchemy engine from the 'mysql' section of *config*."""
    db = config['mysql']
    # Assemble the mysql+pymysql connection URL piece by piece.
    url = (
        "mysql+pymysql://"
        f"{db['user']}:{db['password']}"
        f"@{db['host']}:{db['port']}"
        f"/{db['database']}"
        f"?charset={db['charset']}&use_unicode=1"
    )
    return create_engine(url)
|
|
|
|
|
2025-04-04 15:24:02 +08:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example watch-list and a manually chosen date range.
    stock_codes = ['002112.SZ', '601801.SH', '002737.SZ', '600970.SH']
    start_date = '20230101'
    end_date = '20250403'
    data_types = ['daily', 'weekly', 'monthly', 'money_flow', 'daily_basic']

    # Refresh the listed-stock reference table first.
    get_stock_basic()

    # Then download every requested data type for every stock.
    for code in stock_codes:
        for kind in data_types:
            get_trans_data(code, start_date, end_date, kind)
|