Thanks for the reply. I think this is the way. I now decided to make a class out of it that keeps two copies of the model, which are updated every time the loss is called. If you are interested, this is how I did it.
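In case a summary helps, this is roughly what the class computes (notation follows the paper linked in the code, https://arxiv.org/abs/2106.05009; sigma = initial_std, zeta = mismatch_level, K = n_attack_steps, beta = beta_robustness). The adversarial parameters theta* are found by a sign-based PGA attack on a copy of the model, and the robustness gradient is assembled by hand via the chain rule:

\theta^*_0 = \theta + |\theta|\,\sigma\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
\theta^*_{k+1} = \theta^*_k + \frac{\zeta\,|\theta|}{K}\,\mathrm{sign}\!\big(\nabla_{\theta^*} L_{rob}\big)
L = L_{nat}\big(f(X,\theta), y\big) + \beta\, L_{rob}\big(f(X,\theta), f(X,\theta^*)\big), \qquad
\frac{d L_{rob}}{d\theta} = \frac{\partial L_{rob}}{\partial \theta^*}\,\frac{\partial \theta^*}{\partial \theta} + \frac{\partial L_{rob}}{\partial \theta}

Because theta* depends on theta only elementwise, the Jacobian d theta*/d theta is diagonal and can be stored with the same shape as the parameters (J_diag in the code below).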
Paste the following into torch_loss.py:
import torch
from copy import deepcopy
class AdversarialLoss:
    def __init__(
        self,
        model,
        natural_loss,
        robustness_loss,
        device,
        n_attack_steps,
        mismatch_level,
        initial_std,
        beta_robustness
    ):
        # - Make copies of the model
        self.model_theta = deepcopy(model)
        self.model_theta_star = deepcopy(model)
        self.natural_loss = natural_loss
        self.robustness_loss = robustness_loss
        self.device = device
        self.n_attack_steps = n_attack_steps
        self.mismatch_level = mismatch_level
        self.initial_std = initial_std
        self.beta_robustness = beta_robustness

    def L_rob(
        self,
        output_theta,
        output_theta_star
    ):
        return self.robustness_loss(
            torch.nn.functional.softmax(output_theta_star, dim=1).log(),
            torch.nn.functional.softmax(output_theta, dim=1)
        )
    def _adversarial_loss(
        self,
        model,
        X
    ):
        # - Update the parameters of the "healthy" model
        self.model_theta.load_state_dict(model.state_dict())
        # - Initialize theta* with small gaussian noise around theta
        with torch.no_grad():
            # - f(X, theta)
            output_theta = self.model_theta(X)
            # - Accumulate the signed gradients for the gradient calculation
            sum_signed_grads = {}
            # - Compute theta*
            theta_star = {}
            # - Step size is scaled to each parameter and determines how much the adversary can affect the parameter
            step_size = {}
            # - Store random vals for the gradient computation
            random_val_dict = {}
            for name, v in self.model_theta.named_parameters():
                sum_signed_grads[name] = torch.zeros_like(v, device=self.device)
                random_val = torch.randn(size=v.shape, device=self.device)
                random_val_dict[name] = random_val
                theta_star[name] = v + v.abs() * self.initial_std * random_val
                step_size[name] = (self.mismatch_level * v.abs()) / self.n_attack_steps
        # - PGA attack
        for _ in range(self.n_attack_steps):
            # - Load the current theta_star into the adversarial model
            self.model_theta_star.load_state_dict(theta_star)
            # - Pass input through net with adv. parameters and compute grad of robustness loss
            output_theta_star = self.model_theta_star(X)
            step_loss = self.L_rob(output_theta=output_theta, output_theta_star=output_theta_star)
            step_loss.backward()
            # - Update the sum of the signed gradients
            for name, v in self.model_theta_star.named_parameters():
                sum_signed_grads[name] += v.grad.sign()
            # - Update theta*
            for name, v in self.model_theta_star.named_parameters():
                theta_star[name] = theta_star[name] + step_size[name] * v.grad.sign()
                v.grad = None  # - Ensure gradients don't accumulate

        # - After the last update, load the final theta_star into the network
        self.model_theta_star.load_state_dict(theta_star)

        # - Calculate d L_rob / d theta* for computing the final gradient
        output_theta_star = self.model_theta_star(X)
        loss_rob = self.L_rob(output_theta=output_theta, output_theta_star=output_theta_star)
        loss_rob.backward()
        grad_L_theta_star = {}
        for name, v in self.model_theta_star.named_parameters():
            grad_L_theta_star[name] = v.grad  # - Store the gradients
            v.grad = None

        # - The final gradient can be computed using: d L / d theta* * d theta* / d theta + d L / d theta
        grad_L_theta = {}
        with torch.no_grad():
            # - Forward pass through theta* without tracking gradients, so the backward only reaches theta
            output_theta_star = self.model_theta_star(X)
        loss_rob = self.L_rob(output_theta=self.model_theta(X), output_theta_star=output_theta_star)
        loss_rob.backward()
        for name, v in self.model_theta.named_parameters():
            grad_L_theta[name] = v.grad
            v.grad = None

        # - Compute d theta* / d theta, which is the Jacobian. J is diagonal, so we can just keep the shape.
        # - See https://arxiv.org/abs/2106.05009
        J_diag = {
            name: (1.0 + v.sign() * (self.initial_std * random_val_dict[name]
                   + self.mismatch_level / self.n_attack_steps * sum_signed_grads[name])).detach()
            for name, v in self.model_theta.named_parameters()
        }
        # - Final gradient
        final_grad = {name: grad_L_theta_star[name] * J_diag[name] + grad_L_theta[name] for name in J_diag}
        return loss_rob.detach(), final_grad
    def compute_gradient_and_backward(
        self,
        model,
        X,
        y
    ):
        if self.beta_robustness != 0.0:
            # - Get the adversarial loss (note: beta_robustness is not applied yet)
            adv_loss, adv_loss_gradients = self._adversarial_loss(
                model,
                X
            )
            # - Compute the natural loss and backprop
            nat_loss = self.natural_loss(model(X), y)
            nat_loss.backward()
            # - Combine the autodiff gradients with the manually assembled adversarial gradients
            for name, v in model.named_parameters():
                v.grad.data += self.beta_robustness * adv_loss_gradients[name]
            return nat_loss.detach() + self.beta_robustness * adv_loss
        else:
            nat_loss = self.natural_loss(model(X), y)
            nat_loss.backward()
            return nat_loss.detach()
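Note that compute_gradient_and_backward replaces the usual loss.backward() call: it fills the .grad fields itself (natural loss via autograd plus beta_robustness times the adversarial gradients), so the training loop only needs optimizer.step() and optimizer.zero_grad() afterwards. Roughly like this (a minimal sketch assuming net, optimizer, train_dataloader and device are already set up; the hyperparameters are just the ones I use in the full script below):

loss_fn = AdversarialLoss(
    model=net,
    natural_loss=torch.nn.CrossEntropyLoss(reduction="mean"),
    robustness_loss=torch.nn.KLDivLoss(reduction="batchmean"),
    device=device, n_attack_steps=10, mismatch_level=0.025,
    initial_std=1e-3, beta_robustness=0.25
)
for X, y in train_dataloader:
    X, y = X.to(device), y.to(device)
    # - Fills the .grad fields, so no loss.backward() here
    loss = loss_fn.compute_gradient_and_backward(model=net, X=X, y=y)
    optimizer.step()
    optimizer.zero_grad()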
And this is a script that uses it:
# - Example training script: deterministic training of a CNN on Fashion-MNIST with the adversarial loss
from copy import deepcopy
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
# - Import the adversarial loss
from torch_loss import AdversarialLoss
def eval_test_set(
    test_dataloader,
    net,
):
    net.eval()
    N_correct = 0
    N = 0
    for (X, y) in test_dataloader:
        X, y = X.to(device), y.to(device)
        y_hat = torch.argmax(net(X), dim=1)
        N += len(y)
        N_correct += (y_hat == y).int().sum()
    net.train()
    return N_correct / N
def eval_test_set_mismatch(
    test_dataloader,
    net,
    mismatch,
    n_reps,
    device
):
    net_theta_star = deepcopy(net)
    test_acc_no_noise = eval_test_set(test_dataloader, net)
    test_accs = []
    for idx in range(n_reps):
        print("Test eval. mismatch rob. %d/%d" % (idx, n_reps))
        theta_star = {}
        for name, v in net.named_parameters():
            theta_star[name] = v + v.abs() * mismatch * torch.randn(size=v.shape, device=device)
        net_theta_star.load_state_dict(theta_star)
        test_accs.append(eval_test_set(test_dataloader, net_theta_star))
    return float(test_acc_no_noise), float(sum(test_accs) / len(test_accs))
def init_weights(lyr):
    if isinstance(lyr, (torch.nn.Linear, torch.nn.Conv2d)):
        torch.nn.init.xavier_uniform_(lyr.weight)
        lyr.bias.data.fill_(0.01)
class TorchCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=(4, 4), stride=(1, 1), padding="same")
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(4, 4), stride=(1, 1), padding="valid")
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.linear1 = torch.nn.Linear(in_features=1600, out_features=256)
        self.linear2 = torch.nn.Linear(in_features=256, out_features=64)
        self.linear3 = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 1600)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
if __name__ == '__main__':
    torch.manual_seed(0)
    # - Avoid reprod. issues caused by GPU
    torch.use_deterministic_algorithms(True)
    # - Select device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # - If two GPUs are available, use the second one
    if torch.cuda.device_count() == 2:
        device = "cuda:1"
    # - Fixed parameters
    BATCH_SIZE_TRAIN = 100
    BATCH_SIZE_TEST = 500
    N_EPOCHS = 5
    LR = 1e-4
    base_dir = os.path.dirname(os.path.abspath(__file__))
    download_path = os.path.join(base_dir, "fmnist")
    train_set = torchvision.datasets.FashionMNIST(
        download_path,
        download=True,
        train=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.25])])
    )
    test_set = torchvision.datasets.FashionMNIST(
        download_path,
        download=True,
        train=False,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.25])])
    )
    train_dataloader = DataLoader(
        dataset=train_set,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        num_workers=4
    )
    test_dataloader = DataLoader(
        dataset=test_set,
        batch_size=BATCH_SIZE_TEST,
        shuffle=False,
        num_workers=4
    )
    # - Create Torch network
    cnn = TorchCNN().to(device)
    cnn.apply(init_weights)
    # - Create Adam instance for torch
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
    # - Adversarial loss
    adv_loss = AdversarialLoss(
        model=cnn,
        natural_loss=torch.nn.CrossEntropyLoss(reduction="mean"),
        robustness_loss=torch.nn.KLDivLoss(reduction="batchmean"),
        device=device,
        n_attack_steps=10,
        mismatch_level=0.025,
        initial_std=1e-3,
        beta_robustness=0.25
    )
    for epoch_id in range(N_EPOCHS):
        for idx, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)
            # - Fills the .grad fields, so backward() does not need to be called here
            total_loss = adv_loss.compute_gradient_and_backward(
                model=cnn,
                X=X,
                y=y
            )
            # - Update the weights
            optimizer.step()
            # - Zero out the grads of the optimizer
            optimizer.zero_grad()
            if idx % 100 == 0:
                test_acc_no_noise, mean_noisy_test_acc = eval_test_set_mismatch(
                    test_dataloader,
                    cnn,
                    mismatch=0.2,
                    n_reps=5,
                    device=device
                )
                print("\n\nTest acc %.5f Mean noisy test acc %.5f" % (test_acc_no_noise, mean_noisy_test_acc))
            print("Epoch %d Batch %d/%d Loss %.5f" % (epoch_id, idx, len(train_dataloader), total_loss))