Auxiliary batch norm for AdvProp

I am trying to implement the paper Adversarial Examples Improve Image Recognition (AdvProp). A main component is a branching mechanism in the batch norm layers: clean samples go through the main batch norm branch while adversarial samples go through an auxiliary one.

What I have so far is something like this. Starting from an existing model, I first replace the BN layers like so:

import copy

import torch
import torch.nn as nn


class MyBatchNorm2d(nn.Module):
    def __init__(self, init_bn):
        super(MyBatchNorm2d, self).__init__()
        # Main BN keeps the original (e.g. pretrained) parameters and stats
        self.bn = copy.deepcopy(init_bn)

        # Aux BN starts from freshly reset parameters and stats
        self.bn_aux = copy.deepcopy(init_bn)
        self.bn_aux.reset_parameters()

        # Counts training-mode forward passes to decide which branch to use
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, input):
        if self.training:
            self.num_batches_tracked = self.num_batches_tracked + 1

        # Odd runs are AUX input
        if self.training and self.num_batches_tracked % 2 != 0:
            return self.bn_aux(input)

        return self.bn(input)

def convert_bn_to_bnaux(model):
    # Recursively replace every nn.BatchNorm2d with a MyBatchNorm2d wrapper
    for child_name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(model, child_name, MyBatchNorm2d(copy.deepcopy(child)))
        else:
            convert_bn_to_bnaux(child)
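
For reference, a minimal usage sketch of the conversion (torchvision's resnet18 is only an example here; any model containing nn.BatchNorm2d layers would do):

import torchvision

model = torchvision.models.resnet18()
convert_bn_to_bnaux(model)

# every original nn.BatchNorm2d should now be a MyBatchNorm2d
print(sum(isinstance(m, MyBatchNorm2d) for m in model.modules()))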

Then while training I do the following:

    for i, sample_batched in enumerate(train_loader):
        image, labels = sample_batched  # assuming the loader yields (image, labels) pairs

        # Switch to train mode
        model.train()

        optimizer.zero_grad()

        # Predict: the first forward pass per iteration pushes the BN counter
        # to an odd value, so it is routed through the aux BN branch
        output_aux = model(adv_image)  # adv_image: some adversarial version of the batch
        output = model(image)          # second (even) pass goes through the main BN

        # Compute the loss on the clean batch
        loss = l1_criterion(output, labels)

        # AUX: loss on the adversarial batch
        loss_aux = l1_criterion(output_aux, labels)

        loss_both = loss + loss_aux

        # Update step
        loss_both.backward()
        optimizer.step()
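
How adv_image is generated is not shown above. One standard option would be a PGD-style attack, roughly like the sketch below; the helper name and the epsilon/alpha/steps values are only illustrative, inputs are assumed to be in [0, 1], and running the attack in eval mode avoids incrementing the odd/even counter inside MyBatchNorm2d.

def pgd_attack(model, criterion, image, labels, epsilon=8/255, alpha=2/255, steps=5):
    # Illustrative L_inf PGD sketch, not necessarily what the paper uses
    was_training = model.training
    model.eval()  # keep the MyBatchNorm2d counter untouched during the attack
    adv = image.clone().detach()
    adv = (adv + torch.empty_like(adv).uniform_(-epsilon, epsilon)).clamp(0, 1)  # random start
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = criterion(model(adv), labels)
        grad = torch.autograd.grad(loss, adv)[0]
        with torch.no_grad():
            adv = adv + alpha * grad.sign()                       # ascend the loss
            adv = image + (adv - image).clamp(-epsilon, epsilon)  # project to the L_inf ball
            adv = adv.clamp(0, 1)
    if was_training:
        model.train()
    return adv.detach()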

However, I am not sure if the backpropagation is correct in this case. Do I need to branch on the backward pass?

I tried to diagnose it by flipping the order of the two forward passes to try to mess up the results, but I still get the same test error.

The backward pass will be executed according to what Autograd collected during the forward pass, so the condition should be respected.
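
A small self-contained illustration (the two-branch module below is made up): only the branch that was actually used in the forward pass receives gradients.

import torch
import torch.nn as nn

class TwoBranch(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.Linear(4, 4)
        self.use_b = False

    def forward(self, x):
        # Autograd only records the branch that actually runs
        return self.b(x) if self.use_b else self.a(x)

m = TwoBranch()
m(torch.randn(2, 4)).sum().backward()
print(m.a.weight.grad is not None)  # True: branch a was used in the forward pass
print(m.b.weight.grad is None)      # True: branch b never ran, so it gets no gradient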

I’m not sure I understand your code completely, but it seems you are initializing self.bn and self.bn_aux with the same values? If so, switching the order wouldn’t make a difference.

Thanks for the feedback @ptrblck!

For the initial values, I call this function on one of them. Doesn’t it reset the BN parameters?

       self.bn_aux.reset_parameters()

Yes, you are right; I had assumed you were passing an already reset bn layer into your custom implementation.
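
For reference, a quick check of what reset_parameters() restores on a BatchNorm2d layer (the defaults: weight=1, bias=0, running_mean=0, running_var=1, num_batches_tracked=0):

bn = nn.BatchNorm2d(3)
with torch.no_grad():
    bn.weight.fill_(5.)
    bn.running_mean.fill_(5.)

bn.reset_parameters()
print(bn.weight)        # back to ones
print(bn.running_mean)  # back to zeros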

Your MyBatchNorm2d implementation seems to work correctly:

class MyBatchNorm2d(nn.Module):
    def __init__(self, init_bn):
        super(MyBatchNorm2d, self).__init__()
        self.bn = copy.deepcopy(init_bn)
        
        # Aux BN
        self.bn_aux = copy.deepcopy(init_bn)
        self.bn_aux.reset_parameters()
        
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        
    def forward(self, input):            
        if self.training:
            self.num_batches_tracked = self.num_batches_tracked + 1            
        
        # Odd runs are AUX input
        if self.training and self.num_batches_tracked % 2 != 0:
            return self.bn_aux(input)
      
        return self.bn(input)


bn = nn.BatchNorm2d(3)

# Perturb the reference BN so it clearly differs from a freshly reset one
with torch.no_grad():
    bn.weight.add_(10)
    bn.bias.add_(10)
    bn.running_mean.add_(10)
    bn.running_var.add_(10)

my_bn = MyBatchNorm2d(copy.deepcopy(bn))

x = torch.randn(2, 3, 4, 4)
out1 = bn(x)     # uses the perturbed parameters
out2 = my_bn(x)  # first training-mode call -> odd count -> bn_aux with reset parameters
print(out1)
print(out2)

How are you comparing the results using the test error?

I think it is working now. I tested it by running one branch on the original image and the other on the negative (-image), and the running stats are now negatives of each other after many iterations.
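
A minimal sketch of that kind of check (the tensor shapes and the offset are only illustrative; the offset just makes the per-channel means clearly nonzero):

my_bn = MyBatchNorm2d(nn.BatchNorm2d(3))
my_bn.train()

x = torch.randn(16, 3, 8, 8) + 2.0
for _ in range(100):
    my_bn(-x)  # odd counts  -> bn_aux sees the negated input
    my_bn(x)   # even counts -> bn sees the original input

print(my_bn.bn.running_mean)      # roughly +2 per channel ...
print(my_bn.bn_aux.running_mean)  # ... and roughly -2 here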

Hello @ialhashim,
Did you manage to implement AdvProp? I’m also having issues. :frowning:

@Sara_Rojas_Martinez I ended up with the code below, which comes with some helper debug functions. I’m not sure it’s correct, as my experiments ended up learning the two distributions equally well rather than showing the compounded advantage reported in the paper.

class MyBatchNorm2d(nn.Module):
    def __init__(self, init_bn, isAuxBN):
        super(MyBatchNorm2d, self).__init__()
        self.bn = copy.deepcopy(init_bn)

        # Aux BN (here it starts from the same values as the main BN)
        self.bn_aux = copy.deepcopy(init_bn)

        # isAuxBN: whether this layer uses the odd/even routing at all
        # isForwardAuxBN: force the aux branch, e.g. during evaluation
        self.isAuxBN = isAuxBN
        self.isForwardAuxBN = False

        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, input):
        if self.training:
            self.num_batches_tracked = self.num_batches_tracked + 1

        if self.isAuxBN:
            # Odd runs are AUX input
            if self.training and self.num_batches_tracked % 2 != 0:
                return self.bn_aux(input)

            # Explicitly requested aux branch
            if self.isForwardAuxBN:
                return self.bn_aux(input)

        return self.bn(input)
    
    def run_stats(self):
        return '\n\nParent {} \t bn {} \t bn_aux {}\n\n'.format(self.num_batches_tracked, self.bn.num_batches_tracked, self.bn_aux.num_batches_tracked)
        
    def bn_stats(self):
        return 'BN mean\n {}, \n\n'.format(self.bn.running_mean)
    
    def bn_aux_stats(self):
        return 'BN-AUX mean\n {}, \n'.format(self.bn_aux.running_mean)
    
    def __repr__(self):
        parent_string = super().__repr__()
        txt = parent_string + self.run_stats()
        txt = txt + self.bn_stats()
        txt = txt + self.bn_aux_stats() 
        return txt + '(is forward using AUX {})'.format(self.isForwardAuxBN)
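
A small helper along these lines (the function name is mine, not from the paper) could flip isForwardAuxBN across the whole model, e.g. to evaluate with the auxiliary statistics:

def set_forward_aux(model, use_aux):
    # Hypothetical helper: route every MyBatchNorm2d through its aux BN (or back)
    for m in model.modules():
        if isinstance(m, MyBatchNorm2d):
            m.isForwardAuxBN = use_aux

set_forward_aux(model, True)  # evaluate with the auxiliary BN statistics
model.eval()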

Hi @ialhashim, how did you generate the adversarial examples?
Have you managed to apply PGD?