diff --git a/test.py b/test.py deleted file mode 100644 index 713d7f5..0000000 --- a/test.py +++ /dev/null @@ -1,234 +0,0 @@ -import os -import pandas as pd -import backtrader as bt -import matplotlib.pyplot as plt -from datetime import datetime, timedelta - -from utils import load_share_data - - -# 自定义多头趋势策略 -class BullTrendStrategy(bt.Strategy): - params = ( - ('ma5_period', 5), # 短周期移动平均线 - ('ma10_period', 10), # 中周期移动平均线 - ('ma20_period', 20), # 长周期移动平均线 - ('min_lot', 100), # 最小交易手数(一手) - ('stop_loss_pct', 5.0), # 止损百分比,5% - ('max_gain_pct', 5.0), # 最大涨幅限制 - ) - - def __init__(self): - # 初始化移动平均线指标 - self.ma5 = bt.indicators.SMA(self.data.close, period=self.params.ma5_period) - self.ma10 = bt.indicators.SMA(self.data.close, period=self.params.ma10_period) - self.ma20 = bt.indicators.SMA(self.data.close, period=self.params.ma20_period) - - # 记录多头趋势信号 (MA5 > MA10 > MA20) - self.bull_trend = bt.indicators.And(self.ma5 > self.ma10, self.ma10 > self.ma20) - - # 记录MA10 > MA20的条件 - self.ma10_gt_ma20 = self.ma10 > self.ma20 - - # 订单和买入价格 - self.order = None - self.buy_price = None - - # 添加止损相关变量 - self.signal_low = None # 出现多头趋势信号时的最低价 - self.stop_loss_price = None # 止损价格 - - # 日志 - self.log( - f"策略初始化: MA5={self.params.ma5_period}, MA10={self.params.ma10_period}, MA20={self.params.ma20_period}, 止损={self.params.stop_loss_pct}%, 涨幅限制={self.params.max_gain_pct}%") - - def log(self, txt, dt=None): - """记录策略日志""" - dt = dt or self.datas[0].datetime.date(0) - print(f'{dt.isoformat()}, {txt}') - - def notify_order(self, order): - if order.status in [order.Submitted, order.Accepted]: - # 订单已提交/已接受,无需操作 - return - - # 检查订单是否已完成 - if order.status in [order.Completed]: - if order.isbuy(): - self.buy_price = order.executed.price - self.log( - f'买入执行: 价格={order.executed.price:.2f}, 数量={order.executed.size}, 成本={order.executed.value:.2f}, 手续费={order.executed.comm:.2f}') - else: # 卖出 - profit = (order.executed.price - self.buy_price) * order.executed.size - profit_pct = (order.executed.price - self.buy_price) / self.buy_price * 100 - self.log( - f'卖出执行: 价格={order.executed.price:.2f}, 数量={order.executed.size}, 盈亏={profit:.2f}元 ({profit_pct:.2f}%), 手续费={order.executed.comm:.2f}') - # 清除止损价格和信号最低价 - self.stop_loss_price = None - self.signal_low = None - self.buy_price = None - - elif order.status in [order.Canceled, order.Margin, order.Rejected]: - self.log(f'订单取消/拒绝/保证金不足: {order.status}') - - # 重置订单 - self.order = None - - def next(self): - # 如果有未完成的订单,不进行操作 - if self.order: - return - - # 检查是否已持仓 - if not self.position: - # 没有持仓 - 检查是否符合买入条件 - # 条件1:T日是多头趋势,且T-1日是MA10>MA20但MA5不满足的状态 - if (self.bull_trend[0] and # T日是多头趋势 - self.ma10_gt_ma20[-1] and # T-1日MA10>MA20 - not self.bull_trend[-1]): # T-1日不是完全多头趋势(即MA5不满足) - - # 条件2:T日为阳线(收盘价>开盘价) - is_bullish = self.data.close[0] > self.data.open[0] - - # 条件3:T日涨幅不超过5% - daily_gain_pct = (self.data.close[0] - self.data.close[-1]) / self.data.close[-1] * 100 - valid_gain = daily_gain_pct <= self.params.max_gain_pct - - # 所有条件满足才进行买入 - if is_bullish and valid_gain: - # 记录出现多头趋势信号时的最低价 - self.signal_low = self.data.low[0] - # 计算止损价格(比信号出现时最低价低5%) - self.stop_loss_price = self.signal_low * (1 - self.params.stop_loss_pct / 100) - - self.log(f'买入信号! T日多头趋势,T-1日MA10>MA20但MA5不满足') - self.log(f'当前MA5={self.ma5[0]:.2f}, MA10={self.ma10[0]:.2f}, MA20={self.ma20[0]:.2f}') - self.log(f'昨日MA5={self.ma5[-1]:.2f}, MA10={self.ma10[-1]:.2f}, MA20={self.ma20[-1]:.2f}') - self.log(f'K线为阳线,当日涨幅={daily_gain_pct:.2f}%,在限制{self.params.max_gain_pct}%以内') - self.log( - f'设置止损价格: {self.stop_loss_price:.2f} (信号最低价{self.signal_low:.2f}的{self.params.stop_loss_pct}%)') - - # 计算可以买入的股数(必须是100的整数倍) - available_cash = self.broker.getcash() * 0.95 # 保留5%现金 - price = self.data.close[0] - size = int(available_cash / price / self.params.min_lot) * self.params.min_lot - - if size >= self.params.min_lot: - self.log(f'设置购买订单,下一个开盘价,数量={size}股') - # 下一个bar开盘价买入 - self.order = self.buy(size=size, exectype=bt.Order.Market) - else: - self.log( - f'资金不足,无法购买最小手数: 需要{self.params.min_lot}股,当前资金只能买{int(available_cash / price)}股') - elif self.bull_trend[0]: - # 显示为什么不买入的原因 - if not is_bullish: - self.log( - f'符合多头趋势但不买入: K线不是阳线 (开盘={self.data.open[0]:.2f}, 收盘={self.data.close[0]:.2f})') - elif not valid_gain: - self.log( - f'符合多头趋势但不买入: 当日涨幅={daily_gain_pct:.2f}%,超过限制{self.params.max_gain_pct}%') - else: - # 已经持仓 - 检查是否应该卖出 - # 检查止损条件 - 如果当前最低价低于止损价格 - if self.data.low[0] <= self.stop_loss_price: - self.log(f'触发止损! 当前最低价={self.data.low[0]:.2f} 低于止损价={self.stop_loss_price:.2f}') - self.log(f'设置止损卖出订单,下一个开盘价') - # 下一个bar开盘价卖出全部持仓 - self.order = self.sell(size=self.position.size, exectype=bt.Order.Market) - - # 当T日不是多头趋势时,T+1日开盘卖出 - elif not self.bull_trend[0]: - self.log(f'卖出信号! 当前非多头趋势') - self.log(f'MA5={self.ma5[0]:.2f}, MA10={self.ma10[0]:.2f}, MA20={self.ma20[0]:.2f}') - self.log(f'设置卖出订单,下一个开盘价') - # 下一个bar开盘价卖出全部持仓 - self.order = self.sell(size=self.position.size, exectype=bt.Order.Market) - -def run_backtest(stock_code, start_date=None, end_date=None, ma5_period=5, ma10_period=10, ma20_period=20, - initial_cash=100000): - # 创建cerebro引擎 - cerebro = bt.Cerebro() - - # 添加策略 - cerebro.addstrategy(BullTrendStrategy, - ma5_period=ma5_period, - ma10_period=ma10_period, - ma20_period=ma20_period) - - # 加载数据 - df = load_share_data(stock_code, 'daily', start_date, end_date) - - # 创建数据源 - data = bt.feeds.PandasData(dataname=df) - - # 添加数据到cerebro - cerebro.adddata(data) - - # 设置初始资金 - cerebro.broker.setcash(initial_cash) - - # 设置手续费 (0.1%) - cerebro.broker.setcommission(commission=0.001) - - # 设置滑点 (0.1%) - cerebro.broker.set_slippage_perc(0.001) - - # 添加分析器 - cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe') - cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown') - cerebro.addanalyzer(bt.analyzers.Returns, _name='returns') - cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades') - - # 显示起始资金 - print(f'起始资金: {cerebro.broker.getvalue():.2f}') - - # 运行回测 - results = cerebro.run() - strategy = results[0] - - # 显示结束资金 - final_value = cerebro.broker.getvalue() - print(f'结束资金: {final_value:.2f}') - print(f'总收益率: {(final_value - initial_cash) / initial_cash * 100:.2f}%') - - # 显示分析结果 - print(f'夏普比率: {strategy.analyzers.sharpe.get_analysis()["sharperatio"]:.3f}') - print(f'最大回撤: {strategy.analyzers.drawdown.get_analysis()["max"]["drawdown"]:.2f}%') - print(f'年化收益率: {strategy.analyzers.returns.get_analysis()["rnorm100"]:.2f}%') - - # 交易统计 - trade_analysis = strategy.analyzers.trades.get_analysis() - if trade_analysis.get('total', {}).get('total', 0) > 0: - print(f'总交易次数: {trade_analysis["total"]["total"]}') - print(f'盈利交易: {trade_analysis.get("won", {}).get("total", 0)}') - print(f'亏损交易: {trade_analysis.get("lost", {}).get("total", 0)}') - if trade_analysis.get("won", {}).get("total", 0) > 0: - print(f'平均盈利: {trade_analysis["won"]["pnl"]["average"]:.2f}') - if trade_analysis.get("lost", {}).get("total", 0) > 0: - print(f'平均亏损: {trade_analysis["lost"]["pnl"]["average"]:.2f}') - - # 绘制结果 - cerebro.plot(style='candle', figsize=(20, 10), barup='red', bardown='green') - - -if __name__ == "__main__": - # 回测参数 - stock_code = '002737.SZ' # 指定股票代码 - start_date = '20200101' # 开始日期 - end_date = '20250403' # 结束日期 (请注意使用实际下载数据的日期范围) - ma5_period = 5 # MA5周期 - ma10_period = 10 # MA10周期 - ma20_period = 20 # MA20周期 - initial_cash = 100000 # 初始资金 - - # 运行回测 - run_backtest( - stock_code=stock_code, - start_date=start_date, - end_date=end_date, - ma5_period=ma5_period, - ma10_period=ma10_period, - ma20_period=ma20_period, - initial_cash=initial_cash - ) \ No newline at end of file diff --git a/utils.py b/utils.py index fb8dd46..0c33b9d 100644 --- a/utils.py +++ b/utils.py @@ -31,143 +31,7 @@ def load_config(): return config -config = load_config() -ts.set_token(config['tushare_token']) -pro = ts.pro_api() - - -def get_trans_data(stock_code, start_date, end_date, data_type='daily'): - # 确保日期格式正确 - start_date = pd.to_datetime(start_date).strftime('%Y%m%d') - end_date = pd.to_datetime(end_date).strftime('%Y%m%d') - # 创建保存目录 - save_dir = os.path.join('data', stock_code) - os.makedirs(save_dir, exist_ok=True) - # 生成文件名 - file_name = f"{data_type}_{start_date}_{end_date}.csv" - file_path = os.path.join(save_dir, file_name) - # 检查文件是否已存在 - if os.path.exists(file_path): - print(f"{file_path} 已存在,跳过下载") - return True - # 根据数据类型选择相应的API - if data_type == 'daily': - df = pro.daily(ts_code=stock_code, start_date=start_date, end_date=end_date) - elif data_type == 'weekly': - df = pro.weekly(ts_code=stock_code, start_date=start_date, end_date=end_date) - elif data_type == 'monthly': - df = pro.monthly(ts_code=stock_code, start_date=start_date, end_date=end_date) - elif data_type == 'money_flow': - df = pro.moneyflow(ts_code=stock_code, start_date=start_date, end_date=end_date) - elif data_type == 'daily_basic': - df = pro.daily_basic(ts_code=stock_code, start_date=start_date, end_date=end_date) - else: - print(f"不支持的数据类型: {data_type}") - return False - # 如果数据为空,返回 - if df.empty: - print(f"没有找到 {stock_code} 从 {start_date} 到 {end_date} 的 {data_type} 数据") - return False - # 保存数据 - df.to_csv(file_path, index=False) - print(f"数据已保存到 {file_path}") - return True - -def get_stock_basic(): - # 获取当前时间作为文件名的一部分 - current_time = datetime.now().strftime('%Y%m%d') - - # 创建保存目录 - save_dir = 'data' - os.makedirs(save_dir, exist_ok=True) - - # 生成文件名 - file_name = f"stock_basic_{current_time}.csv" - file_path = os.path.join(save_dir, file_name) - - # 下载股票基本信息,包含所有可用字段 - fields = 'ts_code,symbol,name,area,industry,fullname,enname,cnspell,market,exchange,curr_type,list_status,list_date,delist_date,is_hs,act_name,act_ent_type' - df = pro.stock_basic(exchange='', list_status='L', fields=fields) - - # 如果数据为空,返回 - if df.empty: - print("没有找到股票基本信息数据") - return - - # 保存数据 - df.to_csv(file_path, index=False) - print(f"股票基本信息数据已保存到 {file_path}") - -def load_share_data(stock_code, data_type='daily', start_date=None, end_date=None): - - data_dir = os.path.join('data', stock_code) - - # 如果未指定日期,尝试加载任何可用数据 - if start_date is None or end_date is None: - files = [f for f in os.listdir(data_dir) if f.startswith(f"{data_type}_")] - if not files: - raise ValueError(f"找不到{stock_code}的{data_type}数据文件") - - # 使用找到的第一个文件 - file_path = os.path.join(data_dir, files[0]) - else: - # 格式化日期 - start_date = pd.to_datetime(start_date).strftime('%Y%m%d') - end_date = pd.to_datetime(end_date).strftime('%Y%m%d') - file_name = f"{data_type}_{start_date}_{end_date}.csv" - file_path = os.path.join(data_dir, file_name) - - if not os.path.exists(file_path): - print(f"找不到{file_path}数据文件,执行下载") - get_trans_data(stock_code, start_date, end_date, data_type) - - print(f"加载数据: {file_path}") - - # 读取CSV文件 - df = pd.read_csv(file_path) - - # 确保日期列是datetime类型并按日期升序排序 - df['trade_date'] = pd.to_datetime(df['trade_date'], format='%Y%m%d') - df = df.sort_values('trade_date') - - # 重命名列以符合backtrader期望的格式 - df = df.rename(columns={ - 'trade_date': 'datetime', - 'open': 'open', - 'high': 'high', - 'low': 'low', - 'close': 'close', - 'vol': 'volume' - }) - - # 确保所有必需的列都存在 - required_columns = ['datetime', 'open', 'high', 'low', 'close', 'volume'] - for col in required_columns: - if col not in df.columns: - raise ValueError(f"缺少必需的列: {col}") - - # 将datetime设为索引 - df = df.set_index('datetime') - - return df - def create_engine_from_config(config): mysql = config['mysql'] connection_string = f"mysql+pymysql://{mysql['user']}:{mysql['password']}@{mysql['host']}:{mysql['port']}/{mysql['database']}?charset={mysql['charset']}&use_unicode=1" return create_engine(connection_string) - - - -if __name__ == "__main__": - stock_codes = ['002112.SZ', '601801.SH', '002737.SZ', '600970.SH'] # 示例股票代码 - start_date = '20230101' # 手动输入的开始日期 - end_date = '20250403' # 手动输入的结束日期 - data_types = ['daily', 'weekly', 'monthly', 'money_flow', 'daily_basic'] # 可选数据类型 - - # 下载股票基本信息 - get_stock_basic() - - # 下载其他数据 - for stock_code in stock_codes: - for data_type in data_types: - get_trans_data(stock_code, start_date, end_date, data_type) \ No newline at end of file