I am trying to implement the paper "Adversarial Examples Improve Image Recognition". A main component is branching: some samples go through one branch and others go through another.
What I have so far is something like this. Starting from an existing model, I first replace the BN layers as follows:
import copy
import torch
import torch.nn as nn

class MyBatchNorm2d(nn.Module):
    def __init__(self, init_bn):
        super(MyBatchNorm2d, self).__init__()
        self.bn = copy.deepcopy(init_bn)
        # Aux BN: same configuration, but parameters and running stats are reset
        self.bn_aux = copy.deepcopy(init_bn)
        self.bn_aux.reset_parameters()
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, input):
        if self.training:
            self.num_batches_tracked = self.num_batches_tracked + 1
        # Odd runs are AUX input. The counter is incremented before this check,
        # so the 1st, 3rd, ... forward calls in training mode use bn_aux.
        if self.training and self.num_batches_tracked % 2 != 0:
            return self.bn_aux(input)
        return self.bn(input)

def convert_bn_to_bnaux(model):
    # Recursively replace every nn.BatchNorm2d child with a MyBatchNorm2d wrapper
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(model, child_name, MyBatchNorm2d(copy.deepcopy(child)))
        else:
            convert_bn_to_bnaux(child)

Then while training I do the following:
for i, sample_batched in enumerate(train_loader):
    # image, adv_image and labels come from sample_batched (unpacking not shown)
    # Switch to train mode
    model.train()
    optimizer.zero_grad()
    # Predict
    output = model(image)
    output_aux = model(adv_image)  # some adversarial image
    # Compute the loss
    loss = l1_criterion(output, labels)
    # AUX
    loss_aux = l1_criterion(output_aux, labels)
    loss_both = loss + loss_aux
    # Update step
    loss_both.backward()
    optimizer.step()

However, I am not sure if the backpropagation is correct in this case. Do I need to branch on the backward pass?
I tried to diagnose it by flipping the two training steps to try and mess up the results, but I still get the same test error.
The backward pass will be executed according to what Autograd collected during the forward pass, so the condition should be respected.
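To illustrate that point, here is a minimal sketch, not taken from the code above. `TwoBranch` and its `use_aux` flag are hypothetical names made up for this example. It shows that a branch skipped in the forward pass receives no gradient, and that summing the losses from two forward passes sends gradients to both branches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class TwoBranch(nn.Module):
    """Hypothetical toy module: routes the input through one of two layers."""
    def __init__(self):
        super().__init__()
        self.main = nn.Linear(4, 2)
        self.aux = nn.Linear(4, 2)

    def forward(self, x, use_aux):
        return self.aux(x) if use_aux else self.main(x)

model = TwoBranch()
x = torch.randn(8, 4)

# Only the aux branch takes part in this forward pass.
model(x, use_aux=True).sum().backward()
main_grad_after_aux = model.main.weight.grad        # stays None: main was never used
aux_grad_after_aux = model.aux.weight.grad.clone()  # populated by backward()

# Summing the losses of two forward passes sends gradients to both branches.
model.zero_grad()
loss = model(x, use_aux=False).sum() + model(x, use_aux=True).sum()
loss.backward()
both_have_grads = (model.main.weight.grad is not None
                   and model.aux.weight.grad is not None)
```

So as long as the forward pass picks the intended branch, no extra branching should be needed in the backward pass.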
I’m not sure I understand your code completely, but it seems you are initializing self.bn and self.bn_aux with the same values? If so, switching the order wouldn’t make a difference?
I think it is working now. I tested it by running one branch on the original image and the other on the negative (-image) and the stats are now negative of each other after many iterations.
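A check along those lines could look roughly like the sketch below. It is an illustration under assumptions rather than the code used in this thread: `DualBN` and its explicit `use_aux` flag are hypothetical stand-ins for the counter-based module, and the input is random data shifted to be strictly positive so the sign of the running mean is unambiguous.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class DualBN(nn.Module):
    """Hypothetical stand-in that picks the branch with an explicit flag."""
    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features)
        self.bn_aux = nn.BatchNorm2d(num_features)

    def forward(self, x, use_aux=False):
        return self.bn_aux(x) if use_aux else self.bn(x)

layer = DualBN(3).train()
for _ in range(200):
    image = torch.rand(16, 3, 8, 8) + 1.0  # values in [1, 2)
    layer(image, use_aux=False)            # main branch sees the image
    layer(-image, use_aux=True)            # aux branch sees its negative

main_mean, aux_mean = layer.bn.running_mean, layer.bn_aux.running_mean
main_var, aux_var = layer.bn.running_var, layer.bn_aux.running_var

# If the routing is correct, the running means mirror each other
# while the running variances match.
means_are_negated = torch.allclose(main_mean, -aux_mean, atol=1e-5)
vars_match = torch.allclose(main_var, aux_var, atol=1e-5)
```

If both branches had seen the same mix of inputs, the two running means would drift toward the same value instead of mirroring each other.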
@Sara_Rojas_Martinez I ended up with this code that comes with helper debug functions. I am not sure if it's correct, as my experiments ended up learning the two distributions equally well and did not show the compounded advantage reported in the paper.