"""Train and evaluate a time-series classifier (ResNet / ConvNet) on a UCR dataset.

Data is read from ``data/<run_tag>/<run_tag>_TRAIN.txt`` and ``..._TEST.txt``;
checkpoints are written to ``<checkpoints_folder>/<run_tag>/pre_<model><epoch>epoch.pth``.

Usage:
    python this_script.py                      # train, then evaluate checkpoint --e
    python this_script.py --test               # only evaluate checkpoint --e
    python this_script.py --query_one --idx 5  # class probabilities of one test sample
"""
import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data

from models import ResNet, ConvNet
from utils import UcrDataset, UCR_dataloader, AdvDataset


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or '0'.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True, so any
    non-empty string would enable the option.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--test', action='store_true', help='only evaluate checkpoint --e on the TEST split')
parser.add_argument('--query_one', action='store_true', help='query the probability of target idx sample')
parser.add_argument('--idx', type=int, help='the index of test sample ')
parser.add_argument('--gpu', type=str, default='0', help='value for CUDA_VISIBLE_DEVICES (which GPU(s) to expose)')
parser.add_argument('--channel_last', type=_str2bool, default=True, help='the channel of data is last or not')
parser.add_argument('--n_class', type=int, default=3, help='the class number of dataset')
parser.add_argument('--epochs', type=int, default=1500, help='number of epochs to train for')
parser.add_argument('--e', type=int, default=None,
                    help='epoch of the checkpoint to evaluate/query (default: the last trained epoch, epochs - 1)')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--cuda', action='store_false', help='disables cuda (cuda is used by default when available)')
parser.add_argument('--checkpoints_folder', default='model_checkpoints', help='folder to save checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--run_tag', default='StarLightCurves', help='tags for the current run')
parser.add_argument('--model', default='f', help="the model type: 'r' (ResNet) or 'f' (ConvNet/FCN)")
parser.add_argument('--normalize', action='store_true', help='normalize the TEST split when evaluating')
parser.add_argument('--checkpoint_every', type=int, default=5,
                    help='a checkpoint is saved every checkpoint_every * 10 epochs (and at the last epoch)')
opt = parser.parse_args()
# The checkpoint to evaluate defaults to the last epoch train() saves, so a
# non-default --epochs does not make te() look for a checkpoint that was never written.
if opt.e is None:
    opt.e = opt.epochs - 1
print(opt)

# configure cuda
# Expose the requested GPU(s) before anything queries CUDA, so the setting takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
if torch.cuda.is_available() and not opt.cuda:
    print("You have a cuda device, so you might want to run without the --cuda flag (which disables it)")
# Fall back to CPU when CUDA is requested but unavailable instead of crashing later.
device = torch.device("cuda:0" if opt.cuda and torch.cuda.is_available() else "cpu")

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)


def _batch_size(n_samples):
    """Return the batch size for a split: a tenth of its size, capped at 16, never below 1."""
    return max(1, int(min(n_samples / 10, 16)))


def _checkpoint_path(epoch):
    """Return the checkpoint path train() uses for ``epoch``."""
    return os.path.join(opt.checkpoints_folder, opt.run_tag, 'pre_%s%depoch.pth' % (opt.model, epoch))


def _load_model():
    """Load the whole pickled model saved for epoch ``opt.e`` onto ``device``."""
    model_path = _checkpoint_path(opt.e)
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects whole pickled modules.
    try:
        return torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:  # PyTorch < 1.13 has no weights_only argument
        return torch.load(model_path, map_location=device)


def train(l=None, e=None):
    """Train ``opt.model`` on the TRAIN split, saving periodic checkpoints.

    Args:
        l: list that receives the last batch loss every 10 epochs (created if None).
        e: list that receives the matching epoch numbers (created if None).

    Returns:
        The ``(l, e)`` lists, e.g. for ``plot1``.

    Raises:
        ValueError: if ``opt.model`` is neither 'r' nor 'f'.
    """
    l = [] if l is None else l
    e = [] if e is None else e
    os.makedirs(os.path.join(opt.checkpoints_folder, opt.run_tag), exist_ok=True)

    # Training data (the TRAIN split is always normalized).
    dataset_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TRAIN.txt'
    dataset = UcrDataset(dataset_path, channel_last=opt.channel_last, normalize=True)
    # Adversarial training hook: attacked time series produced under final_result/.
    # attacked_data_path = 'final_result/' + opt.run_tag + '/' + 'attack_time_series.txt'
    # attacked_dataset = AdvDataset(txt_file=attacked_data_path)

    batch_size = _batch_size(len(dataset))
    print('dataset length: ', len(dataset))
    print('batch size:', batch_size)
    dataloader = UCR_dataloader(dataset, batch_size)

    seq_len = dataset.get_seq_len()
    n_class = opt.n_class
    print('sequence len:', seq_len)
    if opt.model == 'r':
        net = ResNet(n_in=seq_len, n_classes=n_class).to(device)
    elif opt.model == 'f':
        net = ConvNet(n_in=seq_len, n_classes=n_class).to(device)
    else:
        raise ValueError("unknown --model %r: expected 'r' (ResNet) or 'f' (ConvNet)" % opt.model)
    net.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    save_every = opt.checkpoint_every * 10

    print('#############!! Start to Train!! ###############')
    for epoch in range(opt.epochs):
        loss = None
        for i, (data, label) in enumerate(dataloader):
            # Stop at the trailing partial batch so every step sees exactly batch_size samples.
            if data.size(0) != batch_size:
                break
            data = data.float().to(device)
            label = label.long().to(device)
            optimizer.zero_grad()
            output = net(data)
            loss = criterion(output, label.view(label.size(0)))
            loss.backward()
            optimizer.step()
            print('#######!!!Train!!!########')
            print('[%d/%d][%d/%d] Loss: %.4f ' % (epoch, opt.epochs, i + 1, len(dataloader), loss.item()))

        if (save_every > 0 and epoch % save_every == 0) or (epoch == (opt.epochs - 1)):
            print('Saving the %dth epoch model.....' % epoch)
            # Save the whole model after this epoch.
            torch.save(net, _checkpoint_path(epoch))
        # Record the loss curve; skip epochs that produced no full batch.
        if epoch % 10 == 0 and loss is not None:
            l.append(loss.item())
            e.append(epoch)
    return l, e


def te():
    """Evaluate checkpoint ``opt.e`` on the TEST split and print its accuracy.

    Returns:
        The accuracy in percent, or None when the TEST split is empty.
    """
    data_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TEST.txt'
    # NOTE(review): train() always normalizes the TRAIN split, but the TEST split is
    # normalized only with --normalize -- confirm this asymmetry is intended.
    dataset = UcrDataset(data_path, channel_last=opt.channel_last, normalize=opt.normalize)
    batch_size = _batch_size(len(dataset))
    print('dataset length: ', len(dataset))
    print('batch_size:', batch_size)
    dataloader = UCR_dataloader(dataset, batch_size)

    model = _load_model()
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for data, label in dataloader:
            data = data.float().to(device)
            label = label.long().to(device)
            label = label.view(label.size(0))
            total += label.size(0)
            out = model(data)
            # Softmax is monotonic, so the argmax of the logits is the predicted class.
            pred_label = torch.argmax(out, dim=1)
            correct += (pred_label == label).sum().item()
    if total == 0:
        print('The TEST split %s is empty; no accuracy to report.' % data_path)
        return None
    accuracy = correct / total * 100
    print('The TEST Accuracy of %s is : %.2f %%' % (data_path, accuracy))
    return accuracy


def query_one(idx):
    """Print the class probabilities of checkpoint ``opt.e`` for TEST sample ``idx``.

    Each row of the TEST file is read as ``label, x_1, ..., x_T``. The label is
    shifted down by one, and a negative result is mapped to the last class
    (``opt.n_class - 1``).

    Returns:
        The predicted probability of the ground-truth class.
    """
    data_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TEST.txt'
    # ndmin=2 keeps row indexing valid even for a single-row file.
    test_data = torch.from_numpy(np.loadtxt(data_path, ndmin=2))
    test_one = test_data[idx]
    X = test_one[1:].float().to(device)
    y = int(test_one[0].item()) - 1
    if y < 0:
        y = opt.n_class - 1
    print('ground truth:', y)

    # Load the model onto the same device as X (loading to CPU while X sat on cuda crashed).
    model = _load_model()
    model.eval()
    with torch.no_grad():
        prob_vector = nn.Softmax(dim=-1)(model(X))
    print('prob vector:', prob_vector)
    prob = prob_vector.view(opt.n_class)[y].item()
    print('Confidence in true class of the %d sample is %.4f ' % (idx, prob))
    return prob


def plot1(model, loss, epoch, save=False):
    """Plot the recorded training loss against epoch and show the figure.

    Args:
        model: model tag, used in the saved file name ``loss<model>.png``.
        loss: loss values, e.g. the ``l`` list filled by ``train``.
        epoch: epoch numbers matching ``loss``.
        save: when True, also write the figure to ``loss<model>.png``.
    """
    # Create the figure first: calling plt.title() before plt.figure() titled a
    # different (implicit) figure, leaving this one untitled.
    plt.figure(figsize=(6, 4))
    plt.title('The loss of training ' + opt.run_tag + ' model', fontstyle='italic')
    plt.plot(epoch, loss, color='b', label='loss')
    plt.xlabel('epoch', fontsize=12)
    plt.ylabel('loss', fontsize=12)
    plt.legend(loc='upper right', fontsize=8)
    if save:
        plt.savefig('loss' + model + '.png')
    plt.show()


if __name__ == '__main__':
    if opt.test:
        te()
    elif opt.query_one:
        if opt.idx is None:
            parser.error('--query_one requires --idx')
        query_one(opt.idx)
    else:
        # Default run: train, then evaluate the checkpoint selected by --e.
        losses, epochs = train()
        # plot1(opt.model, losses, epochs)
        te()