I have this code for adversarial training on CIFAR-10 with a ResNet-18 architecture. While the training accuracy is high, the natural and adversarial test accuracies remain quite low.
def _pgd_whitebox(model,
                  X,
                  y,
                  epsilon=args.epsilon,
                  num_steps=args.num_steps,
                  step_size=args.step_size,
                  sum_dir=None):
    """Craft L-infinity PGD adversarial examples for one batch.

    Starting from ``X`` (optionally perturbed by uniform noise when
    ``args.random`` is truthy), repeatedly steps along the sign of the input
    gradient of the module-level ``criterion``, projects back into the
    ``epsilon`` ball around ``X`` and clamps the result to ``[0, 1]``.

    Args:
        model: Callable mapping an input batch to logits. It is used in
            whatever train/eval mode the caller left it in; this function
            does not switch modes.
        X: Clean input batch.
        y: Targets passed to ``criterion`` together with the model output.
        epsilon: Maximum per-element perturbation. The default is read from
            ``args`` once, when the function is defined.
        num_steps: Number of gradient-sign steps (default read from ``args``).
        step_size: Size of each step (default read from ``args``).
        sum_dir: Unused; accepted only for interface compatibility.

    Returns:
        A detached tensor shaped like ``X`` holding the adversarial batch.
    """
    x_nat = X.detach()
    x_adv = x_nat.clone()

    if args.random:
        # Random start, sampled directly with X's device and dtype.
        x_adv = x_adv + torch.empty_like(x_adv).uniform_(-epsilon, epsilon)

    for _ in range(num_steps):
        x_adv.requires_grad_(True)
        with torch.enable_grad():
            loss = criterion(model(x_adv), y)
        # Differentiate w.r.t. the input only, so the attack neither computes
        # nor accumulates gradients in the model's parameters.
        grad, = torch.autograd.grad(loss, x_adv)

        x_adv = x_adv.detach() + step_size * grad.sign()
        # Project back into the L-infinity ball of radius epsilon around X.
        x_adv = x_nat + torch.clamp(x_adv - x_nat, -epsilon, epsilon)
        # NOTE(review): assumes inputs are un-normalized images in [0, 1];
        # confirm against the dataset transforms.
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    return x_adv.detach()
def test_nominal(epoch):
    """Evaluate the network on clean (unperturbed) test data.

    Puts ``net`` in eval mode, runs every batch of ``testloader`` through it
    without gradient tracking, and reports the running loss and accuracy via
    ``progress_bar``.

    Args:
        epoch: Current epoch index. Not used by the evaluation itself.

    Returns:
        Clean test accuracy in percent, or ``0.0`` if the loader yielded no
        samples.
    """
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Hand the final accuracy back to the caller instead of discarding it.
    return 100. * correct / total if total else 0.0
# Training
def train(epoch):
    """Run one epoch of PGD adversarial training.

    For every batch of ``trainloader`` an adversarial batch is crafted with
    ``_pgd_whitebox`` and the network is updated on it. Accuracy on the
    adversarial batch (from the training forward pass) and on the clean batch
    (from a separate monitoring pass) is accumulated and shown through
    ``progress_bar``.

    Args:
        epoch: Epoch index, printed at the start of the epoch.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    nominal_correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        inputs_pgd = _pgd_whitebox(net, inputs, targets)

        optimizer.zero_grad()
        outputs = net(inputs_pgd)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Monitoring only: run the clean batch in eval mode so this extra
        # forward pass does not alter any train-mode layer state (e.g.
        # normalization running statistics) or get affected by dropout.
        net.eval()
        with torch.no_grad():
            outputs_nominal = net(inputs)
        net.train()
        _, predicted_nominal = outputs_nominal.max(1)
        nominal_correct += predicted_nominal.eq(targets).sum().item()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader),
                     'Loss: %.3f | Adv Acc: %.3f%% (%d/%d) | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total,
                        100.*nominal_correct/total, nominal_correct, total))
def test(epoch):
    """Evaluate adversarial and clean accuracy, then save a checkpoint.

    Puts ``net`` in eval mode, attacks every batch of ``testloader`` with
    ``_pgd_whitebox`` and measures accuracy on both the adversarial and the
    clean inputs. Afterwards the model weights, the adversarial accuracy and
    the epoch index are written to ``checkpoint_file.format(epoch)``.

    Args:
        epoch: Epoch index; stored in the checkpoint and used to format the
            checkpoint file name.
    """
    net.eval()
    test_loss = 0
    correct = 0
    nominal_correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        # The attack needs input gradients, so it runs outside no_grad().
        inputs_pgd = _pgd_whitebox(net, inputs, targets)

        with torch.no_grad():
            outputs = net(inputs_pgd)
            loss = criterion(outputs, targets)
            outputs_nominal = net(inputs)

        _, predicted_nominal = outputs_nominal.max(1)
        nominal_correct += predicted_nominal.eq(targets).sum().item()

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(testloader),
                     'Loss: %.3f | Adv Acc: %.3f%% (%d/%d) | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total,
                        100.*nominal_correct/total, nominal_correct, total))

    # Save checkpoint. 'acc' is the accuracy on the adversarial inputs.
    acc = 100. * correct / total if total else 0.0
    print('Saving..')
    state = {
        'net': net.state_dict(),
        'acc': acc,
        'epoch': epoch,
    }
    # Creates missing parent directories and tolerates an existing directory.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state, checkpoint_file.format(epoch))
# Alternate one adversarial-training pass and one evaluation/checkpoint pass
# per epoch, for 200 epochs beginning at start_epoch.
stop_epoch = start_epoch + 200
for epoch in range(start_epoch, stop_epoch):
    train(epoch)
    test(epoch)