🗑️ delete(test.py): 删除策略回测代码文件
🧹 cleanup(utils.py): 删除无用的获取股票数据的函数
This commit is contained in:
parent
0b736e6db7
commit
995d724781
234
test.py
234
test.py
@ -1,234 +0,0 @@
|
|||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import backtrader as bt
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
from utils import load_share_data
|
|
||||||
|
|
||||||
|
|
||||||
# 自定义多头趋势策略
|
|
||||||
class BullTrendStrategy(bt.Strategy):
|
|
||||||
params = (
|
|
||||||
('ma5_period', 5), # 短周期移动平均线
|
|
||||||
('ma10_period', 10), # 中周期移动平均线
|
|
||||||
('ma20_period', 20), # 长周期移动平均线
|
|
||||||
('min_lot', 100), # 最小交易手数(一手)
|
|
||||||
('stop_loss_pct', 5.0), # 止损百分比,5%
|
|
||||||
('max_gain_pct', 5.0), # 最大涨幅限制
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# 初始化移动平均线指标
|
|
||||||
self.ma5 = bt.indicators.SMA(self.data.close, period=self.params.ma5_period)
|
|
||||||
self.ma10 = bt.indicators.SMA(self.data.close, period=self.params.ma10_period)
|
|
||||||
self.ma20 = bt.indicators.SMA(self.data.close, period=self.params.ma20_period)
|
|
||||||
|
|
||||||
# 记录多头趋势信号 (MA5 > MA10 > MA20)
|
|
||||||
self.bull_trend = bt.indicators.And(self.ma5 > self.ma10, self.ma10 > self.ma20)
|
|
||||||
|
|
||||||
# 记录MA10 > MA20的条件
|
|
||||||
self.ma10_gt_ma20 = self.ma10 > self.ma20
|
|
||||||
|
|
||||||
# 订单和买入价格
|
|
||||||
self.order = None
|
|
||||||
self.buy_price = None
|
|
||||||
|
|
||||||
# 添加止损相关变量
|
|
||||||
self.signal_low = None # 出现多头趋势信号时的最低价
|
|
||||||
self.stop_loss_price = None # 止损价格
|
|
||||||
|
|
||||||
# 日志
|
|
||||||
self.log(
|
|
||||||
f"策略初始化: MA5={self.params.ma5_period}, MA10={self.params.ma10_period}, MA20={self.params.ma20_period}, 止损={self.params.stop_loss_pct}%, 涨幅限制={self.params.max_gain_pct}%")
|
|
||||||
|
|
||||||
def log(self, txt, dt=None):
|
|
||||||
"""记录策略日志"""
|
|
||||||
dt = dt or self.datas[0].datetime.date(0)
|
|
||||||
print(f'{dt.isoformat()}, {txt}')
|
|
||||||
|
|
||||||
def notify_order(self, order):
|
|
||||||
if order.status in [order.Submitted, order.Accepted]:
|
|
||||||
# 订单已提交/已接受,无需操作
|
|
||||||
return
|
|
||||||
|
|
||||||
# 检查订单是否已完成
|
|
||||||
if order.status in [order.Completed]:
|
|
||||||
if order.isbuy():
|
|
||||||
self.buy_price = order.executed.price
|
|
||||||
self.log(
|
|
||||||
f'买入执行: 价格={order.executed.price:.2f}, 数量={order.executed.size}, 成本={order.executed.value:.2f}, 手续费={order.executed.comm:.2f}')
|
|
||||||
else: # 卖出
|
|
||||||
profit = (order.executed.price - self.buy_price) * order.executed.size
|
|
||||||
profit_pct = (order.executed.price - self.buy_price) / self.buy_price * 100
|
|
||||||
self.log(
|
|
||||||
f'卖出执行: 价格={order.executed.price:.2f}, 数量={order.executed.size}, 盈亏={profit:.2f}元 ({profit_pct:.2f}%), 手续费={order.executed.comm:.2f}')
|
|
||||||
# 清除止损价格和信号最低价
|
|
||||||
self.stop_loss_price = None
|
|
||||||
self.signal_low = None
|
|
||||||
self.buy_price = None
|
|
||||||
|
|
||||||
elif order.status in [order.Canceled, order.Margin, order.Rejected]:
|
|
||||||
self.log(f'订单取消/拒绝/保证金不足: {order.status}')
|
|
||||||
|
|
||||||
# 重置订单
|
|
||||||
self.order = None
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
# 如果有未完成的订单,不进行操作
|
|
||||||
if self.order:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 检查是否已持仓
|
|
||||||
if not self.position:
|
|
||||||
# 没有持仓 - 检查是否符合买入条件
|
|
||||||
# 条件1:T日是多头趋势,且T-1日是MA10>MA20但MA5不满足的状态
|
|
||||||
if (self.bull_trend[0] and # T日是多头趋势
|
|
||||||
self.ma10_gt_ma20[-1] and # T-1日MA10>MA20
|
|
||||||
not self.bull_trend[-1]): # T-1日不是完全多头趋势(即MA5不满足)
|
|
||||||
|
|
||||||
# 条件2:T日为阳线(收盘价>开盘价)
|
|
||||||
is_bullish = self.data.close[0] > self.data.open[0]
|
|
||||||
|
|
||||||
# 条件3:T日涨幅不超过5%
|
|
||||||
daily_gain_pct = (self.data.close[0] - self.data.close[-1]) / self.data.close[-1] * 100
|
|
||||||
valid_gain = daily_gain_pct <= self.params.max_gain_pct
|
|
||||||
|
|
||||||
# 所有条件满足才进行买入
|
|
||||||
if is_bullish and valid_gain:
|
|
||||||
# 记录出现多头趋势信号时的最低价
|
|
||||||
self.signal_low = self.data.low[0]
|
|
||||||
# 计算止损价格(比信号出现时最低价低5%)
|
|
||||||
self.stop_loss_price = self.signal_low * (1 - self.params.stop_loss_pct / 100)
|
|
||||||
|
|
||||||
self.log(f'买入信号! T日多头趋势,T-1日MA10>MA20但MA5不满足')
|
|
||||||
self.log(f'当前MA5={self.ma5[0]:.2f}, MA10={self.ma10[0]:.2f}, MA20={self.ma20[0]:.2f}')
|
|
||||||
self.log(f'昨日MA5={self.ma5[-1]:.2f}, MA10={self.ma10[-1]:.2f}, MA20={self.ma20[-1]:.2f}')
|
|
||||||
self.log(f'K线为阳线,当日涨幅={daily_gain_pct:.2f}%,在限制{self.params.max_gain_pct}%以内')
|
|
||||||
self.log(
|
|
||||||
f'设置止损价格: {self.stop_loss_price:.2f} (信号最低价{self.signal_low:.2f}的{self.params.stop_loss_pct}%)')
|
|
||||||
|
|
||||||
# 计算可以买入的股数(必须是100的整数倍)
|
|
||||||
available_cash = self.broker.getcash() * 0.95 # 保留5%现金
|
|
||||||
price = self.data.close[0]
|
|
||||||
size = int(available_cash / price / self.params.min_lot) * self.params.min_lot
|
|
||||||
|
|
||||||
if size >= self.params.min_lot:
|
|
||||||
self.log(f'设置购买订单,下一个开盘价,数量={size}股')
|
|
||||||
# 下一个bar开盘价买入
|
|
||||||
self.order = self.buy(size=size, exectype=bt.Order.Market)
|
|
||||||
else:
|
|
||||||
self.log(
|
|
||||||
f'资金不足,无法购买最小手数: 需要{self.params.min_lot}股,当前资金只能买{int(available_cash / price)}股')
|
|
||||||
elif self.bull_trend[0]:
|
|
||||||
# 显示为什么不买入的原因
|
|
||||||
if not is_bullish:
|
|
||||||
self.log(
|
|
||||||
f'符合多头趋势但不买入: K线不是阳线 (开盘={self.data.open[0]:.2f}, 收盘={self.data.close[0]:.2f})')
|
|
||||||
elif not valid_gain:
|
|
||||||
self.log(
|
|
||||||
f'符合多头趋势但不买入: 当日涨幅={daily_gain_pct:.2f}%,超过限制{self.params.max_gain_pct}%')
|
|
||||||
else:
|
|
||||||
# 已经持仓 - 检查是否应该卖出
|
|
||||||
# 检查止损条件 - 如果当前最低价低于止损价格
|
|
||||||
if self.data.low[0] <= self.stop_loss_price:
|
|
||||||
self.log(f'触发止损! 当前最低价={self.data.low[0]:.2f} 低于止损价={self.stop_loss_price:.2f}')
|
|
||||||
self.log(f'设置止损卖出订单,下一个开盘价')
|
|
||||||
# 下一个bar开盘价卖出全部持仓
|
|
||||||
self.order = self.sell(size=self.position.size, exectype=bt.Order.Market)
|
|
||||||
|
|
||||||
# 当T日不是多头趋势时,T+1日开盘卖出
|
|
||||||
elif not self.bull_trend[0]:
|
|
||||||
self.log(f'卖出信号! 当前非多头趋势')
|
|
||||||
self.log(f'MA5={self.ma5[0]:.2f}, MA10={self.ma10[0]:.2f}, MA20={self.ma20[0]:.2f}')
|
|
||||||
self.log(f'设置卖出订单,下一个开盘价')
|
|
||||||
# 下一个bar开盘价卖出全部持仓
|
|
||||||
self.order = self.sell(size=self.position.size, exectype=bt.Order.Market)
|
|
||||||
|
|
||||||
def run_backtest(stock_code, start_date=None, end_date=None, ma5_period=5, ma10_period=10, ma20_period=20,
|
|
||||||
initial_cash=100000):
|
|
||||||
# 创建cerebro引擎
|
|
||||||
cerebro = bt.Cerebro()
|
|
||||||
|
|
||||||
# 添加策略
|
|
||||||
cerebro.addstrategy(BullTrendStrategy,
|
|
||||||
ma5_period=ma5_period,
|
|
||||||
ma10_period=ma10_period,
|
|
||||||
ma20_period=ma20_period)
|
|
||||||
|
|
||||||
# 加载数据
|
|
||||||
df = load_share_data(stock_code, 'daily', start_date, end_date)
|
|
||||||
|
|
||||||
# 创建数据源
|
|
||||||
data = bt.feeds.PandasData(dataname=df)
|
|
||||||
|
|
||||||
# 添加数据到cerebro
|
|
||||||
cerebro.adddata(data)
|
|
||||||
|
|
||||||
# 设置初始资金
|
|
||||||
cerebro.broker.setcash(initial_cash)
|
|
||||||
|
|
||||||
# 设置手续费 (0.1%)
|
|
||||||
cerebro.broker.setcommission(commission=0.001)
|
|
||||||
|
|
||||||
# 设置滑点 (0.1%)
|
|
||||||
cerebro.broker.set_slippage_perc(0.001)
|
|
||||||
|
|
||||||
# 添加分析器
|
|
||||||
cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe')
|
|
||||||
cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
|
|
||||||
cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
|
|
||||||
cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
|
|
||||||
|
|
||||||
# 显示起始资金
|
|
||||||
print(f'起始资金: {cerebro.broker.getvalue():.2f}')
|
|
||||||
|
|
||||||
# 运行回测
|
|
||||||
results = cerebro.run()
|
|
||||||
strategy = results[0]
|
|
||||||
|
|
||||||
# 显示结束资金
|
|
||||||
final_value = cerebro.broker.getvalue()
|
|
||||||
print(f'结束资金: {final_value:.2f}')
|
|
||||||
print(f'总收益率: {(final_value - initial_cash) / initial_cash * 100:.2f}%')
|
|
||||||
|
|
||||||
# 显示分析结果
|
|
||||||
print(f'夏普比率: {strategy.analyzers.sharpe.get_analysis()["sharperatio"]:.3f}')
|
|
||||||
print(f'最大回撤: {strategy.analyzers.drawdown.get_analysis()["max"]["drawdown"]:.2f}%')
|
|
||||||
print(f'年化收益率: {strategy.analyzers.returns.get_analysis()["rnorm100"]:.2f}%')
|
|
||||||
|
|
||||||
# 交易统计
|
|
||||||
trade_analysis = strategy.analyzers.trades.get_analysis()
|
|
||||||
if trade_analysis.get('total', {}).get('total', 0) > 0:
|
|
||||||
print(f'总交易次数: {trade_analysis["total"]["total"]}')
|
|
||||||
print(f'盈利交易: {trade_analysis.get("won", {}).get("total", 0)}')
|
|
||||||
print(f'亏损交易: {trade_analysis.get("lost", {}).get("total", 0)}')
|
|
||||||
if trade_analysis.get("won", {}).get("total", 0) > 0:
|
|
||||||
print(f'平均盈利: {trade_analysis["won"]["pnl"]["average"]:.2f}')
|
|
||||||
if trade_analysis.get("lost", {}).get("total", 0) > 0:
|
|
||||||
print(f'平均亏损: {trade_analysis["lost"]["pnl"]["average"]:.2f}')
|
|
||||||
|
|
||||||
# 绘制结果
|
|
||||||
cerebro.plot(style='candle', figsize=(20, 10), barup='red', bardown='green')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 回测参数
|
|
||||||
stock_code = '002737.SZ' # 指定股票代码
|
|
||||||
start_date = '20200101' # 开始日期
|
|
||||||
end_date = '20250403' # 结束日期 (请注意使用实际下载数据的日期范围)
|
|
||||||
ma5_period = 5 # MA5周期
|
|
||||||
ma10_period = 10 # MA10周期
|
|
||||||
ma20_period = 20 # MA20周期
|
|
||||||
initial_cash = 100000 # 初始资金
|
|
||||||
|
|
||||||
# 运行回测
|
|
||||||
run_backtest(
|
|
||||||
stock_code=stock_code,
|
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
ma5_period=ma5_period,
|
|
||||||
ma10_period=ma10_period,
|
|
||||||
ma20_period=ma20_period,
|
|
||||||
initial_cash=initial_cash
|
|
||||||
)
|
|
136
utils.py
136
utils.py
@ -31,143 +31,7 @@ def load_config():
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
config = load_config()
|
|
||||||
ts.set_token(config['tushare_token'])
|
|
||||||
pro = ts.pro_api()
|
|
||||||
|
|
||||||
|
|
||||||
def get_trans_data(stock_code, start_date, end_date, data_type='daily'):
|
|
||||||
# 确保日期格式正确
|
|
||||||
start_date = pd.to_datetime(start_date).strftime('%Y%m%d')
|
|
||||||
end_date = pd.to_datetime(end_date).strftime('%Y%m%d')
|
|
||||||
# 创建保存目录
|
|
||||||
save_dir = os.path.join('data', stock_code)
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
# 生成文件名
|
|
||||||
file_name = f"{data_type}_{start_date}_{end_date}.csv"
|
|
||||||
file_path = os.path.join(save_dir, file_name)
|
|
||||||
# 检查文件是否已存在
|
|
||||||
if os.path.exists(file_path):
|
|
||||||
print(f"{file_path} 已存在,跳过下载")
|
|
||||||
return True
|
|
||||||
# 根据数据类型选择相应的API
|
|
||||||
if data_type == 'daily':
|
|
||||||
df = pro.daily(ts_code=stock_code, start_date=start_date, end_date=end_date)
|
|
||||||
elif data_type == 'weekly':
|
|
||||||
df = pro.weekly(ts_code=stock_code, start_date=start_date, end_date=end_date)
|
|
||||||
elif data_type == 'monthly':
|
|
||||||
df = pro.monthly(ts_code=stock_code, start_date=start_date, end_date=end_date)
|
|
||||||
elif data_type == 'money_flow':
|
|
||||||
df = pro.moneyflow(ts_code=stock_code, start_date=start_date, end_date=end_date)
|
|
||||||
elif data_type == 'daily_basic':
|
|
||||||
df = pro.daily_basic(ts_code=stock_code, start_date=start_date, end_date=end_date)
|
|
||||||
else:
|
|
||||||
print(f"不支持的数据类型: {data_type}")
|
|
||||||
return False
|
|
||||||
# 如果数据为空,返回
|
|
||||||
if df.empty:
|
|
||||||
print(f"没有找到 {stock_code} 从 {start_date} 到 {end_date} 的 {data_type} 数据")
|
|
||||||
return False
|
|
||||||
# 保存数据
|
|
||||||
df.to_csv(file_path, index=False)
|
|
||||||
print(f"数据已保存到 {file_path}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_stock_basic():
|
|
||||||
# 获取当前时间作为文件名的一部分
|
|
||||||
current_time = datetime.now().strftime('%Y%m%d')
|
|
||||||
|
|
||||||
# 创建保存目录
|
|
||||||
save_dir = 'data'
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 生成文件名
|
|
||||||
file_name = f"stock_basic_{current_time}.csv"
|
|
||||||
file_path = os.path.join(save_dir, file_name)
|
|
||||||
|
|
||||||
# 下载股票基本信息,包含所有可用字段
|
|
||||||
fields = 'ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type'
|
|
||||||
df = pro.stock_basic(exchange='', list_status='L', fields=fields)
|
|
||||||
|
|
||||||
# 如果数据为空,返回
|
|
||||||
if df.empty:
|
|
||||||
print("没有找到股票基本信息数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 保存数据
|
|
||||||
df.to_csv(file_path, index=False)
|
|
||||||
print(f"股票基本信息数据已保存到 {file_path}")
|
|
||||||
|
|
||||||
def load_share_data(stock_code, data_type='daily', start_date=None, end_date=None):
|
|
||||||
|
|
||||||
data_dir = os.path.join('data', stock_code)
|
|
||||||
|
|
||||||
# 如果未指定日期,尝试加载任何可用数据
|
|
||||||
if start_date is None or end_date is None:
|
|
||||||
files = [f for f in os.listdir(data_dir) if f.startswith(f"{data_type}_")]
|
|
||||||
if not files:
|
|
||||||
raise ValueError(f"找不到{stock_code}的{data_type}数据文件")
|
|
||||||
|
|
||||||
# 使用找到的第一个文件
|
|
||||||
file_path = os.path.join(data_dir, files[0])
|
|
||||||
else:
|
|
||||||
# 格式化日期
|
|
||||||
start_date = pd.to_datetime(start_date).strftime('%Y%m%d')
|
|
||||||
end_date = pd.to_datetime(end_date).strftime('%Y%m%d')
|
|
||||||
file_name = f"{data_type}_{start_date}_{end_date}.csv"
|
|
||||||
file_path = os.path.join(data_dir, file_name)
|
|
||||||
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
print(f"找不到{file_path}数据文件,执行下载")
|
|
||||||
get_trans_data(stock_code, start_date, end_date, data_type)
|
|
||||||
|
|
||||||
print(f"加载数据: {file_path}")
|
|
||||||
|
|
||||||
# 读取CSV文件
|
|
||||||
df = pd.read_csv(file_path)
|
|
||||||
|
|
||||||
# 确保日期列是datetime类型并按日期升序排序
|
|
||||||
df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d')
|
|
||||||
df = df.sort_values('trade_date')
|
|
||||||
|
|
||||||
# 重命名列以符合backtrader期望的格式
|
|
||||||
df = df.rename(columns={
|
|
||||||
'trade_date': 'datetime',
|
|
||||||
'open': 'open',
|
|
||||||
'high': 'high',
|
|
||||||
'low': 'low',
|
|
||||||
'close': 'close',
|
|
||||||
'vol': 'volume'
|
|
||||||
})
|
|
||||||
|
|
||||||
# 确保所有必需的列都存在
|
|
||||||
required_columns = ['datetime', 'open', 'high', 'low', 'close', 'volume']
|
|
||||||
for col in required_columns:
|
|
||||||
if col not in df.columns:
|
|
||||||
raise ValueError(f"缺少必需的列: {col}")
|
|
||||||
|
|
||||||
# 将datetime设为索引
|
|
||||||
df = df.set_index('datetime')
|
|
||||||
|
|
||||||
return df
|
|
||||||
|
|
||||||
def create_engine_from_config(config):
|
def create_engine_from_config(config):
|
||||||
mysql = config['mysql']
|
mysql = config['mysql']
|
||||||
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
|
connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1"
|
||||||
return create_engine(connection_string)
|
return create_engine(connection_string)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
stock_codes = ['002112.SZ', '601801.SH', '002737.SZ', '600970.SH'] # 示例股票代码
|
|
||||||
start_date = '20230101' # 手动输入的开始日期
|
|
||||||
end_date = '20250403' # 手动输入的结束日期
|
|
||||||
data_types = ['daily', 'weekly', 'monthly', 'money_flow', 'daily_basic'] # 可选数据类型
|
|
||||||
|
|
||||||
# 下载股票基本信息
|
|
||||||
get_stock_basic()
|
|
||||||
|
|
||||||
# 下载其他数据
|
|
||||||
for stock_code in stock_codes:
|
|
||||||
for data_type in data_types:
|
|
||||||
get_trans_data(stock_code, start_date, end_date, data_type)
|
|
Loading…
Reference in New Issue
Block a user