# CCPP/adv_training.py

import argparse
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data

from models import ResNet, ConvNet
from utils import UcrDataset, UCR_dataloader, AdvDataset

parser = argparse.ArgumentParser()
parser.add_argument('--test', action='store_true', help='evaluate the trained model on the test set')
parser.add_argument('--query_one', action='store_true', help='query the probability of the target idx sample')
parser.add_argument('--idx', type=int, help='the index of the test sample to query')
parser.add_argument('--gpu', type=str, default='0', help='the index of the GPU to use')
parser.add_argument('--channel_last', type=bool, default=True, help='whether the channel dimension of the data is last')
parser.add_argument('--n_class', type=int, default=3, help='the number of classes in the dataset')
parser.add_argument('--epochs', type=int, default=1500, help='number of epochs to train for')
parser.add_argument('--e', type=int, default=1499, help='the epoch of the checkpoint to load for testing/querying')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--cuda', action='store_false', help='enables cuda')
parser.add_argument('--checkpoints_folder', default='model_checkpoints', help='folder to save checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--run_tag', default='StarLightCurves', help='tag (dataset name) for the current run')
parser.add_argument('--model', default='f', help="the model type: 'r' for ResNet, 'f' for ConvNet (FCN)")
parser.add_argument('--normalize', action='store_true', help='normalize the test data')
parser.add_argument('--checkpoint_every', type=int, default=5, help='number of epochs between checkpoint saves')
opt = parser.parse_args()
print(opt)
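
# Example invocations (illustrative only; they assume UCR-format text files laid out under
# data/<run_tag>/<run_tag>_TRAIN.txt and _TEST.txt, as constructed in train() and te() below).
# Note that __main__ currently always runs train() followed by te(); the --test and
# --query_one branches are commented out there.
#   python adv_training.py --run_tag StarLightCurves --model f --n_class 3
#   python adv_training.py --run_tag StarLightCurves --model r --epochs 1500 --lr 0.0002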
# configure cuda
os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
if torch.cuda.is_available() and not opt.cuda:
    print("You have a cuda device, so you might want to run with --cuda as option")
device = torch.device("cuda:0" if opt.cuda else "cpu")

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
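
# Purely optional (not in the original): for fuller reproducibility one could also seed
# NumPy and the CUDA RNGs, e.g.:
#   np.random.seed(opt.manualSeed)
#   torch.cuda.manual_seed_all(opt.manualSeed)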


def train(l, e):
    os.makedirs(opt.checkpoints_folder, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_folder, opt.run_tag), exist_ok=True)
    # Load the training data
    dataset_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TRAIN.txt'
    dataset = UcrDataset(dataset_path, channel_last=opt.channel_last, normalize=True)
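    # Assumption (inferred from query_one() below): each row of the UCR-format text file is
    # one sample, with the class label in the first column and the time-series values in the
    # remaining columns.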
    # Adversarial training on the attacked time series stored under final_result/
    # attacked_data_path = 'final_result/' + opt.run_tag + '/' + 'attack_time_series.txt'  # attacked time series
    # attacked_dataset = AdvDataset(txt_file=attacked_data_path)
    # The batch size is at most 16 (len(dataset) / 10; here that works out to 10)
    batch_size = int(min(len(dataset) / 10, 16))
    # batch_size = 64
    print('dataset length: ', len(dataset))
    # print('number of adv examples', len(attacked_dataset))
    print('batch size', batch_size)
    dataloader = UCR_dataloader(dataset, batch_size)
    # adv_dataloader = UCR_dataloader(attacked_dataset, batch_size)
    seq_len = dataset.get_seq_len()
    n_class = opt.n_class
    print('sequence len:', seq_len)
    if opt.model == 'r':
        net = ResNet(n_in=seq_len, n_classes=n_class).to(device)
    if opt.model == 'f':
        net = ConvNet(n_in=seq_len, n_classes=n_class).to(device)
    net.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    print('############# Start to Train ###############')
    # opt.epochs was set to 100 here (it was originally 1499; do not forget to change it back)
    for epoch in range(opt.epochs):
        for i, (data, label) in enumerate(dataloader):
            if data.size(0) != batch_size:
                break
            data = data.float()
            data = data.to(device)
            label = label.long()
            label = label.to(device)
            optimizer.zero_grad()
            output = net(data)
            loss = criterion(output, label.view(label.size(0)))
            loss.backward()
            optimizer.step()
            print('#######Train########')
            print('[%d/%d][%d/%d] Loss: %.4f ' % (epoch, opt.epochs, i + 1, len(dataloader), loss.item()))
        if (epoch % (opt.checkpoint_every * 10) == 0) or (epoch == (opt.epochs - 1)):
            print('Saving the %dth epoch model.....' % epoch)
            # Save the model checkpoint after this epoch
            torch.save(net, '%s/%s/pre_%s%depoch.pth' % (opt.checkpoints_folder, opt.run_tag, opt.model, epoch))
        if epoch % 10 == 0:
            l.append(loss.item())
            e.append(epoch)


def te():
    data_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TEST.txt'
    dataset = UcrDataset(data_path, channel_last=opt.channel_last, normalize=opt.normalize)
    batch_size = int(min(len(dataset) / 10, 16))
    print('dataset length: ', len(dataset))
    print('batch_size:', batch_size)
    dataloader = UCR_dataloader(dataset, batch_size)
    model_type = opt.model
    model_path = 'model_checkpoints/' + opt.run_tag + '/pre_' + model_type + str(opt.e) + 'epoch.pth'
    # model_path = 'model_checkpoints/' + opt.run_tag + '/pre_fTrained.pth'
    # map_location=device keeps this working on CPU-only machines as well
    model = torch.load(model_path, map_location=device)
    with torch.no_grad():
        model.eval()
        total = 0
        correct = 0
        for i, (data, label) in enumerate(dataloader):
            data = data.float()
            data = data.to(device)
            label = label.long()
            label = label.to(device)
            label = label.view(label.size(0))
            total += label.size(0)
            out = model(data)
            softmax = nn.Softmax(dim=-1)
            prob = softmax(out)
            pred_label = torch.argmax(prob, dim=1)
            correct += (pred_label == label).sum().item()
        print('The TEST Accuracy of %s is : %.2f %%' % (data_path, correct / total * 100))


def query_one(idx):
    data_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TEST.txt'
    test_data = np.loadtxt(data_path)
    test_data = torch.from_numpy(test_data)
    test_one = test_data[idx]
    X = test_one[1:].float()
    X = X.to(device)
    y = test_one[0].long() - 1
    y = y.to(device)
    if y < 0:
        y = opt.n_class - 1
    print('ground truth', y)
    model_type = opt.model
    model_path = 'model_checkpoints/' + opt.run_tag + '/pre_' + model_type + str(opt.e) + 'epoch.pth'
    # model_path = 'model_checkpoints/' + opt.run_tag + '/pre_fTrained.pth'
    # Load the checkpoint onto the same device as X to avoid a device mismatch
    model = torch.load(model_path, map_location=device)
    model.eval()
    out = model(X)
    softmax = nn.Softmax(dim=-1)
    prob_vector = softmax(out)
    print('prob vector', prob_vector)
    prob = prob_vector.view(opt.n_class)[y].item()
    print('Confidence in the true class of sample %d is %.4f' % (idx, prob))
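
# Note (assumption): query_one() feeds a 1-D tensor straight into the network; depending on
# how ResNet/ConvNet are defined in models.py, an explicit batch/channel dimension
# (e.g. X.view(1, 1, -1)) may be required before calling model(X).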


def plot1(model, loss, epoch):
    plt.figure(figsize=(6, 4))
    plt.title('The loss of training the ' + opt.run_tag + ' model', fontstyle='italic')
    plt.plot(epoch, loss, color='b', label='loss')
    plt.xlabel('epoch', fontsize=12)
    plt.legend(loc='upper right', fontsize=8)
    # plt.savefig('loss' + model + '.png')
    plt.show()
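
# Usage sketch for plot1() (mirrors the commented-out call in __main__ below; the two lists
# are filled by train(), which records the loss every 10 epochs):
#   losses, epochs = [], []
#   train(losses, epochs)
#   plot1(opt.model, losses, epochs)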


if __name__ == '__main__':
    # if opt.test:
    #     te()
    # elif opt.query_one:
    #     query_one(opt.idx)
    # else:
    #     train()
    l = []
    e = []
    train(l, e)
    # plot1(opt.model, l, e)
    te()