CCPP/adv_training.py
2025-04-20 20:55:06 +08:00

186 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from models import ResNet, ConvNet
import torch.nn as nn
import argparse
from utils import UcrDataset, UCR_dataloader, AdvDataset
import torch.optim as optim
import torch.utils.data
import os
import random
import numpy as np
import matplotlib.pyplot as plt
def _str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to bool.

    ``argparse`` with ``type=bool`` calls ``bool('False')``, which is True for
    every non-empty string, so ``--channel_last False`` used to be ignored.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--test', action='store_true', help='only evaluate a saved checkpoint on the TEST split')
parser.add_argument('--query_one', action='store_true', help='query the probability of target idx sample')
parser.add_argument('--idx', type=int, help='the index of test sample ')
parser.add_argument('--gpu', type=str, default='0', help='value for CUDA_VISIBLE_DEVICES when cuda is used')
parser.add_argument('--channel_last', type=_str2bool, default=True, help='the channel of data is last or not')
parser.add_argument('--n_class', type=int, default=3, help='the class number of dataset')
parser.add_argument('--epochs', type=int, default=1500, help='number of epochs to train for')
# Left untyped on purpose: the value is only ever used via str(opt.e) to build
# the checkpoint file name.
parser.add_argument('--e', default=1499, help='epoch number of the checkpoint to load for testing/querying')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--cuda', action='store_false', help='disable cuda (cuda is used by default when available)')
parser.add_argument('--checkpoints_folder', default='model_checkpoints', help='folder to save checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--run_tag', default='StarLightCurves', help='tags for the current run')
parser.add_argument('--model', default='f', help="the model type ('r': ResNet, 'f': ConvNet)")
parser.add_argument('--normalize', action='store_true', help='normalize the TEST data in te()')
parser.add_argument('--checkpoint_every', type=int, default=5,
                    help='a checkpoint is saved every (checkpoint_every * 10) epochs')
opt = parser.parse_args()
print(opt)

# Configure CUDA. --cuda is a store_false flag: opt.cuda is True by default and
# passing --cuda turns GPU use off.
if opt.cuda:
    # Must be set before the CUDA runtime is initialised (the
    # torch.cuda.is_available() call below may initialise it), otherwise the
    # device restriction is silently ignored.
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
cuda_available = torch.cuda.is_available()
if cuda_available and not opt.cuda:
    print("You have a cuda device, so you might want to run without --cuda")
if opt.cuda and not cuda_available:
    print("CUDA is not available; falling back to CPU")
use_cuda = opt.cuda and cuda_available
device = torch.device("cuda:0" if use_cuda else "cpu")

# Seed the Python, NumPy and torch RNGs so runs are reproducible.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
def _build_model(seq_len, n_class):
    """Instantiate the classifier selected by ``opt.model`` and move it to ``device``.

    'r' selects ResNet and 'f' selects ConvNet.

    Raises:
        ValueError: for any other ``opt.model`` value (previously this
            surfaced as an UnboundLocalError on ``net``).
    """
    if opt.model == 'r':
        return ResNet(n_in=seq_len, n_classes=n_class).to(device)
    if opt.model == 'f':
        return ConvNet(n_in=seq_len, n_classes=n_class).to(device)
    raise ValueError("unknown --model %r; expected 'r' (ResNet) or 'f' (ConvNet)" % opt.model)


def train(l=None, e=None):
    """Train a classifier on the TRAIN split and periodically save checkpoints.

    Checkpoints are written with ``torch.save(net, ...)`` to
    ``<checkpoints_folder>/<run_tag>/pre_<model><epoch>epoch.pth`` every
    ``checkpoint_every * 10`` epochs and after the final epoch.

    Args:
        l: list that receives, every 10 epochs, the loss of the last batch
            processed in that epoch (appended in place). A new list is used
            when None.
        e: list that receives the matching epoch numbers (appended in place).
            A new list is used when None.

    Returns:
        The ``(l, e)`` history lists.
    """
    if l is None:
        l = []
    if e is None:
        e = []
    os.makedirs(opt.checkpoints_folder, exist_ok=True)
    os.makedirs('%s/%s' % (opt.checkpoints_folder, opt.run_tag), exist_ok=True)

    # Training data. Normalisation is always enabled here.
    # NOTE(review): te() normalises only when --normalize is passed, so by
    # default evaluation uses un-normalised data -- confirm this is intended.
    dataset_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TRAIN.txt'
    dataset = UcrDataset(dataset_path, channel_last=opt.channel_last, normalize=True)
    # Adversarial training on 'final_result/<run_tag>/attack_time_series.txt'
    # (loaded via AdvDataset) is currently disabled.

    # About 10 batches per epoch, capped at 16 samples per batch; at least 1 so
    # datasets with fewer than 10 samples do not yield batch_size == 0.
    batch_size = max(1, int(min(len(dataset) / 10, 16)))
    print('dataset length: ', len(dataset))
    print('batch size', batch_size)
    dataloader = UCR_dataloader(dataset, batch_size)

    seq_len = dataset.get_seq_len()
    n_class = opt.n_class
    print('sequence len:', seq_len)
    net = _build_model(seq_len, n_class)
    net.train()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(net.parameters(), lr=opt.lr)
    # int() guards against a string value when --checkpoint_every is untyped.
    save_every = int(opt.checkpoint_every) * 10

    print('############# Start to Train ###############')
    for epoch in range(opt.epochs):
        loss = None  # stays None if the epoch yields no full batch
        for i, (data, label) in enumerate(dataloader):
            # Skip the trailing incomplete batch.
            if data.size(0) != batch_size:
                break
            data = data.float().to(device)
            label = label.long().to(device)
            optimizer.zero_grad()
            output = net(data)
            loss = criterion(output, label.view(label.size(0)))
            loss.backward()
            optimizer.step()
            print('#######Train########')
            print('[%d/%d][%d/%d] Loss: %.4f ' % (epoch, opt.epochs, i + 1, len(dataloader), loss.item()))
        if (epoch % save_every == 0) or (epoch == (opt.epochs - 1)):
            print('Saving the %dth epoch model.....' % epoch)
            torch.save(net, '%s/%s/pre_%s%depoch.pth' % (opt.checkpoints_folder, opt.run_tag, opt.model, epoch))
        if epoch % 10 == 0 and loss is not None:
            l.append(loss.item())
            e.append(epoch)
    return l, e
def te():
    """Evaluate a saved checkpoint on the TEST split and print its accuracy.

    Loads ``<checkpoints_folder>/<run_tag>/pre_<model><e>epoch.pth``, the same
    naming scheme train() uses for saving, onto ``device``.
    """
    data_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TEST.txt'
    dataset = UcrDataset(data_path, channel_last=opt.channel_last, normalize=opt.normalize)
    # At least 1 so datasets with fewer than 10 samples do not yield batch_size == 0.
    batch_size = max(1, int(min(len(dataset) / 10, 16)))
    print('dataset length: ', len(dataset))
    print('batch_size:', batch_size)
    dataloader = UCR_dataloader(dataset, batch_size)

    model_type = opt.model
    model_path = opt.checkpoints_folder + '/' + opt.run_tag + '/pre_' + model_type + str(opt.e) + 'epoch.pth'
    # train() saves the whole nn.Module with torch.save(net, ...). torch >= 2.6
    # refuses to load such files under its default weights_only=True.
    # This unpickles arbitrary objects: only load trusted checkpoints.
    try:
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:  # torch < 1.13 has no ``weights_only`` parameter
        model = torch.load(model_path, map_location=device)
    model.eval()

    total = 0
    correct = 0
    with torch.no_grad():
        for data, label in dataloader:
            data = data.float().to(device)
            label = label.long().to(device)
            label = label.view(label.size(0))
            total += label.size(0)
            out = model(data)
            # argmax over logits gives the same class as argmax over softmax(logits).
            pred_label = torch.argmax(out, dim=1)
            correct += (pred_label == label).sum().item()
    if total == 0:
        print('No TEST samples found in %s' % data_path)
        return
    print('The TEST Accuracy of %s is : %.2f %%' % (data_path, correct / total * 100))
def query_one(idx):
    """Print the model's class probabilities for a single TEST sample.

    Each row of ``<run_tag>_TEST.txt`` is read as ``[label, x_1, ..., x_T]``.
    The stored label is shifted down by one; if that makes it negative it is
    mapped to the last class (``n_class - 1``).

    Args:
        idx: row index of the sample in the TEST file.

    Raises:
        ValueError: if ``idx`` is None (``--idx`` not supplied).
    """
    if idx is None:
        raise ValueError('--idx is required with --query_one')
    data_path = 'data/' + opt.run_tag + '/' + opt.run_tag + '_TEST.txt'
    test_data = torch.from_numpy(np.loadtxt(data_path))
    test_one = test_data[idx]
    # NOTE(review): X is passed without a batch dimension; presumably the model
    # reshapes it internally -- verify against the model's forward().
    X = test_one[1:].float().to(device)
    y = int(test_one[0].item()) - 1
    if y < 0:
        y = opt.n_class - 1
    print('ground truth', y)

    model_path = opt.checkpoints_folder + '/' + opt.run_tag + '/pre_' + opt.model + str(opt.e) + 'epoch.pth'
    # Load onto the same device as X. Previously the model was mapped to 'cpu'
    # while X was on `device`, which fails whenever CUDA is in use.
    # This unpickles arbitrary objects: only load trusted checkpoints.
    try:
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:  # torch < 1.13 has no ``weights_only`` parameter
        model = torch.load(model_path, map_location=device)
    model.eval()
    with torch.no_grad():
        out = model(X)
    prob_vector = torch.softmax(out, dim=-1)
    print('prob vector', prob_vector)
    prob = prob_vector.view(opt.n_class)[y].item()
    print('Confidence in true class of the %d sample is %.4f ' % (idx, prob))
def plot1(model, loss, epoch):
    """Plot recorded training loss against epoch and display the figure.

    Args:
        model: model tag (e.g. ``opt.model``). Only used by the disabled
            save-to-file line below.
        loss: loss values (y-axis).
        epoch: epoch numbers, same length as ``loss`` (x-axis).
    """
    # Create the figure first. Calling plt.title() before plt.figure() puts the
    # title on a different (implicit) figure.
    plt.figure(figsize=(6, 4))
    plt.title('The loss of training ' + opt.run_tag + ' model', fontstyle='italic')
    plt.plot(epoch, loss, color='b', label='loss')
    plt.xlabel('epoch', fontsize=12)
    plt.ylabel('loss', fontsize=12)
    plt.legend(loc='upper right', fontsize=8)
    # plt.savefig('loss' + model + '.png')
    plt.show()
if __name__ == '__main__':
    # Dispatch on the CLI flags:
    #   --test      : only evaluate the saved checkpoint;
    #   --query_one : inspect one TEST sample;
    #   otherwise   : train, then evaluate.
    if opt.test:
        te()
    elif opt.query_one:
        query_one(opt.idx)
    else:
        l = []
        e = []
        train(l, e)
        # plot1(opt.model, l, e)  # blocks on plt.show(); enable to view the loss curve
        # NOTE: te() loads the checkpoint for epoch --e (default 1499), which is
        # the final checkpoint only when --epochs is 1500.
        te()