from __future__ import print_function
import argparse
import numpy as np
import os
import csv
import math
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as Data
from torch.optim import SGD, lr_scheduler
from collections import OrderedDict
# Checkpoint related: first epoch index; bump this when resuming training.
START_EPOCH = 0
# Device configuration: prefer the GPU when one is visible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
'''
Usage:
    python <this_script>.py

Adversarially trains SmallCNN model(s) on MNIST with a PGD attack
(eps=0.3, 40 steps). Expects a pretrained checkpoint file named
"checkpoint.pt.best" in the working directory and writes MNIST data
under ../data.
'''
class SmallCNN(nn.Module):
    """Small four-conv CNN for 10-class grayscale (MNIST-sized) input.

    Architecture: [conv3x3 -> ReLU] x2 -> maxpool -> [conv3x3 -> ReLU] x2
    -> maxpool, then a 3-layer MLP head with dropout. Conv layers get
    Kaiming-normal init; the final fc3 layer is zero-initialized, so an
    untrained network emits all-zero logits.
    """

    def __init__(self, drop=0.5):
        super(SmallCNN, self).__init__()
        self.num_channels = 1
        self.num_labels = 10

        # Layer names below are load-bearing: checkpoints are keyed by them.
        self.feature_extractor = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(self.num_channels, 32, 3)),
            ('relu1', nn.ReLU(True)),
            ('conv2', nn.Conv2d(32, 32, 3)),
            ('relu2', nn.ReLU(True)),
            ('maxpool1', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(32, 64, 3)),
            ('relu3', nn.ReLU(True)),
            ('conv4', nn.Conv2d(64, 64, 3)),
            ('relu4', nn.ReLU(True)),
            ('maxpool2', nn.MaxPool2d(2, 2)),
        ]))

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(64 * 4 * 4, 200)),
            ('relu1', nn.ReLU(True)),
            ('drop', nn.Dropout(drop)),
            ('fc2', nn.Linear(200, 200)),
            ('relu2', nn.ReLU(True)),
            ('fc3', nn.Linear(200, self.num_labels)),
        ]))

        # Kaiming init for convs (BatchNorm branch kept for parity even
        # though this architecture contains none).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero the output layer so initial logits are all zero.
        nn.init.constant_(self.classifier.fc3.weight, 0)
        nn.init.constant_(self.classifier.fc3.bias, 0)

    def forward(self, input, with_latent=False, fake_relu=False, no_relu=False):
        """Return class logits for `input`.

        The extra flags are accepted for caller API compatibility and
        are ignored here.
        """
        feats = self.feature_extractor(input)
        flat = feats.view(-1, 64 * 4 * 4)
        return self.classifier(flat)
class AttackPGD(nn.Module):
    """Adversarial-training wrapper running an L-inf PGD attack (Madry et al.).

    Wraps a classifier `model`; when `forward(..., make_adv=True)` is
    called, the inputs are first perturbed by projected gradient descent
    against the wrapped model, then classified.

    NOTE: relies on a module-level `normalize` callable being defined by
    the enclosing script (identity for MNIST here).

    config keys:
        random_start (bool): start from a uniform point in the eps-ball.
        step_size (float):   per-iteration L-inf step.
        epsilon (float):     L-inf radius of the attack ball.
        num_steps (int):     number of PGD iterations.
    """

    def __init__(self, model, config):
        super(AttackPGD, self).__init__()
        self.model = model
        self.rand = config['random_start']
        self.step_size = config['step_size']
        self.epsilon = config['epsilon']
        self.num_steps = config['num_steps']

    def forward(self, inputs, target, make_adv=False):
        """Return (logits, x) where x is the (possibly adversarial) batch.

        Args:
            inputs: batch of images, assumed to lie in [0, 1].
            target: ground-truth labels used for the attack loss.
            make_adv: if True, run the PGD attack; otherwise classify
                `inputs` unchanged.
        """
        # Bind before the branch so the clean (make_adv=False) path cannot
        # hit a NameError when restoring the mode below.
        prev_training = bool(self.training)
        x = inputs.detach()
        if make_adv:
            # Local loss fn (mean cross-entropy) instead of relying on a
            # module-level `criterion` global.
            loss_fn = nn.CrossEntropyLoss()
            if self.rand:
                x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon)
                # Project the random start back into the valid image range.
                x = torch.clamp(x, 0, 1)
            # Attack the deterministic (eval-mode) model: dropout off.
            self.eval()
            for _ in range(self.num_steps):
                x = x.clone().detach().requires_grad_(True)
                logits = self.model(normalize(x))
                loss = torch.mean(loss_fn(logits, target))
                grad, = torch.autograd.grad(loss, [x])
                with torch.no_grad():
                    # Ascend the loss, then project into the eps-ball
                    # around `inputs` and back into [0, 1].
                    diff = x + torch.sign(grad) * self.step_size - inputs
                    diff = torch.clamp(diff, -self.epsilon, self.epsilon)
                    x = torch.clamp(diff + inputs, 0, 1)
        # Final forward happens before the mode restore (as in training:
        # logits for the loss are computed with the attack-time eval state).
        output = self.model(normalize(x.clone().detach()))
        if prev_training:
            self.train()
        return output, x
def train_glist(epoch):
    """Run one adversarial-training epoch over every network in `netlist`.

    Uses the module-level `netlist`, `optimizerlist`, `schedulelist`,
    `train_loader`, and `device`.

    Returns:
        (mean losses, accuracies in %) — two lists, one entry per network.
    """
    loss_fn = nn.CrossEntropyLoss()
    n_models = len(netlist)
    running_loss = [0] * n_models
    n_correct = [0] * n_models
    n_seen = [0] * n_models

    for idx, net in enumerate(netlist):
        print('\nEpoch: %d' % epoch)
        net.train()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            # Attack-then-train: the wrapper crafts PGD examples on the fly.
            outputs, final_inp = net(inputs, targets, make_adv=True)
            loss = loss_fn(outputs, targets).mean()
            if len(loss.shape) > 0:
                loss = loss.mean()
            # Optimize this network only.
            optimizerlist[idx].zero_grad()
            loss.backward()
            optimizerlist[idx].step()
            # Bookkeeping (no grad needed).
            with torch.no_grad():
                running_loss[idx] += loss.item()
                predictions = outputs.data.max(1)[1]
                n_seen[idx] += targets.size(0)
                n_correct[idx] += predictions.eq(targets.data).cpu().sum().float()
                print(batch_idx, len(train_loader),
                      'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                      % (running_loss[idx] / (batch_idx + 1),
                         100. * n_correct[idx] / n_seen[idx],
                         n_correct[idx], n_seen[idx]))
        schedulelist[idx].step()

    # `batch_idx` is the last batch index of the final loop pass.
    mean_losses = [1. / (batch_idx + 1) * t for t in running_loss]
    accuracies = [100. / n_seen[i] * n_correct[i] for i in range(n_models)]
    return mean_losses, accuracies
def pgdattack(model, inputs, targets, epsilon=8 / 255., step_size=2.0 / 255, num_steps=7, rand = False):
    """Run an L-inf PGD attack against `model.model` and classify the result.

    Args:
        model: wrapper module exposing the underlying classifier as
            `model.model` (e.g. an AttackPGD instance).
        inputs: batch of images assumed to lie in [0, 1].
        targets: labels used to compute the attack loss.
        epsilon: L-inf radius of the attack ball.
        step_size: per-iteration L-inf step.
        num_steps: number of PGD iterations.
        rand: start from a uniform random point in the eps-ball.
            Bug fix: this flag was previously accepted but ignored.

    Returns:
        (logits on the adversarial batch, adversarial batch).

    NOTE: relies on a module-level `normalize` callable being defined by
    the enclosing script.
    """
    # Local mean cross-entropy instead of a module-level `criterion` global.
    loss_fn = nn.CrossEntropyLoss()
    x = inputs.detach()
    if rand:
        x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)
        x = torch.clamp(x, 0, 1)  # keep the random start a valid image
    prev_training = bool(model.training)
    model.eval()  # attack the deterministic model (dropout off)
    for _ in range(num_steps):
        x = x.clone().detach().requires_grad_(True)
        outputs = model.model(normalize(x))
        loss = torch.mean(loss_fn(outputs, targets))
        grad, = torch.autograd.grad(loss, [x])
        with torch.no_grad():
            # Ascend the loss, project into the eps-ball, clamp to [0, 1].
            diff = x + torch.sign(grad) * step_size - inputs
            diff = torch.clamp(diff, -epsilon, epsilon)
            x = torch.clamp(diff + inputs, 0, 1)
    output = model.model(normalize(x.clone().detach()))
    if prev_training:
        model.train()
    return output, x
def testlist(epoch, idn):
    """Evaluate `netlist[idn]` under a PGD attack over `test_loader`.

    Uses the module-level `netlist`, `test_loader`, and `device`.

    Returns:
        (mean adversarial loss over all batches, adversarial accuracy %).
    """
    criterion = nn.CrossEntropyLoss()
    netlist[idn].eval()
    test_loss = 0
    correct = 0
    total = 0
    # No torch.no_grad() wrapper: pgdattack needs input gradients.
    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs, final_inp = pgdattack(netlist[idn], inputs, targets)
        loss = criterion(outputs, targets)
        test_loss += loss.item()
        _, pred_idx = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += pred_idx.eq(targets.data).cpu().sum().float()
        print(batch_idx, len(test_loader),
              'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    # Bug fix: divide by the number of batches (batch_idx + 1), not
    # batch_idx — the old form over-reported the loss and crashed on a
    # single-batch loader. Now consistent with the printed running average.
    return test_loss / (batch_idx + 1), 100. * correct / total
def load_model(arg_model, model_path, arg_dict=1):
    """Instantiate architecture `arg_model` (looked up by name in module
    globals) and load weights from the checkpoint at `model_path`.

    Handles several checkpoint layouts: {'model': module-or-state-dict},
    {'state_dict': ...}, {'net': ...}, or a bare state dict. When
    arg_dict is falsy, the pickled model object itself is returned
    instead of loading weights into a fresh instance.
    """
    model = globals()[arg_model]().to(device)
    # NOTE(review): `dill` is a third-party pickle replacement — needed for
    # checkpoints that pickled whole model objects.
    import dill
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'), pickle_module=dill)
    print(checkpoint.keys())
    #print(checkpoint['model'].keys())
    # NOTE(review): `state_dict_path` is computed but never used below.
    state_dict_path = 'model'
    if not ('model' in checkpoint):
        state_dict_path = 'state_dict'
    # Normalize the 'net' layout to the 'model' layout.
    if ('net' in checkpoint):
        checkpoint['model'] = checkpoint['net']
        del checkpoint['net']
    if arg_dict:
        # Extract a plain state dict `sd` from whichever layout we got.
        if 'model' in checkpoint:
            if hasattr(checkpoint['model'], 'state_dict'):
                # 'model' holds a full module object, not a state dict.
                print("Hi ehsan*******************")
                sd = checkpoint['model'].state_dict()
            else:
                sd = checkpoint['model']
        elif 'state_dict' in checkpoint:
            sd = checkpoint['state_dict']
            print ('epoch', checkpoint['epoch'],
                   'arch', checkpoint['arch'],
                   'nat_prec1', checkpoint['nat_prec1'],
                   'adv_prec1', checkpoint['adv_prec1'])
        else:
            # Bare state dict saved directly.
            # NOTE(review): if only some keys are present, `sd` may be
            # unbound on other paths — confirm checkpoint layouts covered.
            sd = checkpoint
        print(sd.keys())
        # Strip wrapper prefixes (DataParallel / attacker wrappers) so keys
        # match this architecture's names.
        # NOTE(review): the final .replace('model.','') is aggressive — it
        # would also mangle any key merely containing "model." — verify.
        sd = {k.replace('module.attacker.model.', '').replace('module.model.','').replace('module.','').replace('model.',''):v for k,v in sd.items()}
        keys = model.state_dict().keys()
        # Keep only keys this model actually has; print the leftovers.
        new_state = {}
        for k in sd.keys():
            if k in keys:
                new_state[k] = sd[k]
            else:
                print(k)
        model.load_state_dict(new_state)
    else:
        # Use the pickled model object as-is.
        model = checkpoint['model']
        checkpoint = None
        sd = None
    model.eval().to(device)
    return model
if __name__ == '__main__':
    # ---- Data -----------------------------------------------------------
    print('=====> Preparing data...')
    transform_test = transforms.Compose([transforms.ToTensor(),])
    trainset = torchvision.datasets.MNIST(root='../data', train=True, download=True, transform=transform_test)
    testset = torchvision.datasets.MNIST(root='../data', train=False, download=True, transform=transform_test)
    # Bug fix: the training loader must reshuffle each epoch for SGD —
    # it was previously built with shuffle=False.
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
    # MNIST tensors are already in [0, 1]; no dataset normalization used,
    # so both mappings are the identity.
    unnormalize = lambda x: x
    normalize = lambda x: x

    # ---- PGD attack configuration (standard MNIST eps=0.3) --------------
    config = {
        'epsilon': 0.3,
        'step_size': 0.01,
        'random_start': False,
        'loss_func': 'xent',
        'num_steps': 40
    }

    # ---- Models ---------------------------------------------------------
    print('=====> Building model...')
    model_path = ['checkpoint.pt.best']
    source_models = ['SmallCNN']
    netlist = []
    for i, arg_model in enumerate(source_models):
        net = load_model(source_models[i], model_path[i])
        model0 = AttackPGD(net, config)  # wrap for on-the-fly PGD examples
        model0 = model0.to(device)
        netlist += [model0]
    if torch.cuda.device_count() > 1:
        print("=====> Use", torch.cuda.device_count(), "GPUs")
        for idn in range(len(netlist)):
            netlist[idn] = nn.DataParallel(netlist[idn])

    # ---- Loss and optimizers -------------------------------------------
    criterion = nn.CrossEntropyLoss()
    # Hard-coded base hyper-parameters.
    train_args = {
        "lr": 0.1,
        "weight_decay": 1e-06,
        "momentum": 0.9,
        "step_lr": 50,
        "epoch": 200}
    optimizerlist, schedulelist = [], []
    for idn in range(len(netlist)):
        # One SGD optimizer + step-decay schedule per network.
        optimizer = SGD(netlist[idn].parameters(), train_args["lr"], train_args["momentum"],
                        weight_decay=train_args["weight_decay"])
        schedule = lr_scheduler.StepLR(optimizer, step_size=train_args["step_lr"])
        optimizerlist += [optimizer]
        schedulelist += [schedule]

    # ---- Train / evaluate ----------------------------------------------
    for epoch in range(START_EPOCH, train_args["epoch"]):
        print("Epoch: ", epoch)
        train_loss, train_acc = train_glist(epoch)
        # NOTE(review): only netlist[0] is evaluated each epoch.
        test_loss, test_acc = testlist(epoch, 0)