backtrader/utils.py

166 lines
5.6 KiB
Python
Raw Normal View History

2025-04-04 15:24:02 +08:00
import os
import yaml
import tushare as ts
import pandas as pd
from datetime import datetime, timedelta
def load_config():
config_path = 'config.yaml'
if not os.path.exists(config_path):
config = {
'tushare_token': 'your_token_here',
'mysql': {
'user': 'user',
'password': 'password',
'host': '127.0.0.1',
'port': 3306,
'database': 'tushare',
'charset': 'utf8'
}
}
2025-04-04 15:24:02 +08:00
with open(config_path, 'w') as f:
yaml.dump(config, f)
print(f"请在 {config_path} 中填入您的 tushare token 和 MySQL 连接信息")
2025-04-04 15:24:02 +08:00
exit(1)
2025-04-04 15:24:02 +08:00
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
return config
config = load_config()
ts.set_token(config['tushare_token'])
pro = ts.pro_api()
def get_trans_data(stock_code, start_date, end_date, data_type='daily'):
# 确保日期格式正确
start_date = pd.to_datetime(start_date).strftime('%Y%m%d')
end_date = pd.to_datetime(end_date).strftime('%Y%m%d')
# 创建保存目录
save_dir = os.path.join('data', stock_code)
os.makedirs(save_dir, exist_ok=True)
# 生成文件名
file_name = f"{data_type}_{start_date}_{end_date}.csv"
file_path = os.path.join(save_dir, file_name)
# 检查文件是否已存在
if os.path.exists(file_path):
print(f"{file_path} 已存在,跳过下载")
return True
# 根据数据类型选择相应的API
if data_type == 'daily':
df = pro.daily(ts_code=stock_code, start_date=start_date, end_date=end_date)
elif data_type == 'weekly':
df = pro.weekly(ts_code=stock_code, start_date=start_date, end_date=end_date)
elif data_type == 'monthly':
df = pro.monthly(ts_code=stock_code, start_date=start_date, end_date=end_date)
elif data_type == 'money_flow':
df = pro.moneyflow(ts_code=stock_code, start_date=start_date, end_date=end_date)
elif data_type == 'daily_basic':
df = pro.daily_basic(ts_code=stock_code, start_date=start_date, end_date=end_date)
else:
print(f"不支持的数据类型: {data_type}")
return False
# 如果数据为空,返回
if df.empty:
print(f"没有找到 {stock_code}{start_date}{end_date}{data_type} 数据")
return False
# 保存数据
df.to_csv(file_path, index=False)
print(f"数据已保存到 {file_path}")
return True
def get_stock_basic():
# 获取当前时间作为文件名的一部分
current_time = datetime.now().strftime('%Y%m%d')
# 创建保存目录
save_dir = 'data'
os.makedirs(save_dir, exist_ok=True)
# 生成文件名
file_name = f"stock_basic_{current_time}.csv"
file_path = os.path.join(save_dir, file_name)
# 下载股票基本信息,包含所有可用字段
fields = 'ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type'
df = pro.stock_basic(exchange='', list_status='L', fields=fields)
# 如果数据为空,返回
if df.empty:
print("没有找到股票基本信息数据")
return
# 保存数据
df.to_csv(file_path, index=False)
print(f"股票基本信息数据已保存到 {file_path}")
def load_share_data(stock_code, data_type='daily', start_date=None, end_date=None):
data_dir = os.path.join('data', stock_code)
# 如果未指定日期,尝试加载任何可用数据
if start_date is None or end_date is None:
files = [f for f in os.listdir(data_dir) if f.startswith(f"{data_type}_")]
if not files:
raise ValueError(f"找不到{stock_code}{data_type}数据文件")
# 使用找到的第一个文件
file_path = os.path.join(data_dir, files[0])
else:
# 格式化日期
start_date = pd.to_datetime(start_date).strftime('%Y%m%d')
end_date = pd.to_datetime(end_date).strftime('%Y%m%d')
file_name = f"{data_type}_{start_date}_{end_date}.csv"
file_path = os.path.join(data_dir, file_name)
if not os.path.exists(file_path):
print(f"找不到{file_path}数据文件,执行下载")
get_trans_data(stock_code, start_date, end_date, data_type)
print(f"加载数据: {file_path}")
# 读取CSV文件
df = pd.read_csv(file_path)
# 确保日期列是datetime类型并按日期升序排序
df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
df = df.sort_values('trade_date')
# 重命名列以符合backtrader期望的格式
df = df.rename(columns={
'trade_date': 'datetime',
'open': 'open',
'high': 'high',
'low': 'low',
'close': 'close',
'vol': 'volume'
})
# 确保所有必需的列都存在
required_columns = ['datetime', 'open', 'high', 'low', 'close', 'volume']
for col in required_columns:
if col not in df.columns:
raise ValueError(f"缺少必需的列: {col}")
# 将datetime设为索引
df = df.set_index('datetime')
return df
if __name__ == "__main__":
stock_codes = ['002112.SZ', '601801.SH', '002737.SZ', '600970.SH'] # 示例股票代码
start_date = '20230101' # 手动输入的开始日期
end_date = '20250403' # 手动输入的结束日期
data_types = ['daily', 'weekly', 'monthly', 'money_flow', 'daily_basic'] # 可选数据类型
# 下载股票基本信息
get_stock_basic()
# 下载其他数据
for stock_code in stock_codes:
for data_type in data_types:
get_trans_data(stock_code, start_date, end_date, data_type)