BatchNorm stats update leads to huge drop in accuracy during eval

I’m currently trying to implement the Virtual Adversarial Training (VAT) method (https://arxiv.org/abs/1704.03976) and test it on CIFAR100.

I came across some weird behavior while working on this project. During training, an adversarial perturbation is generated. Both the official TensorFlow and Chainer implementations make sure that the batchnorm stats are not updated while the perturbation is being generated (which makes sense, since statistics computed on perturbed inputs can deviate significantly from those of the true data distribution). In PyTorch, I thought the best way to accomplish this was the following:

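@staticmethod
def set_bn_eval(m):
    # Put BatchNorm layers in eval mode so running stats are not updated
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.eval()

@staticmethod
def set_bn_train(m):
    # Put BatchNorm layers back in train mode
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.train()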

I then run

self.trainer.model.apply(self.set_bn_eval) # disable batch stats update

and

self.trainer.model.apply(self.set_bn_train) # re-enable batch stats update

respectively before and after the adversarial perturbation is generated.
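To make the intent of the two calls above concrete, here is a minimal self-contained sketch of the same toggle written as a context manager. The name frozen_bn_stats and the toy model are hypothetical and not taken from my repo; this only illustrates the idea under those assumptions.

```python
from contextlib import contextmanager

import torch
from torch import nn


@contextmanager
def frozen_bn_stats(model):
    """Hypothetical helper: put every BatchNorm layer in eval mode, then restore it."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    # Remember each layer's mode so it is restored exactly as it was
    previous_modes = [m.training for m in bn_layers]
    for m in bn_layers:
        m.eval()
    try:
        yield model
    finally:
        for m, was_training in zip(bn_layers, previous_modes):
            m.train(was_training)


# Toy model standing in for the real network (assumption for illustration)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
x = torch.randn(4, 3, 16, 16)

before = model[1].running_mean.clone()
with frozen_bn_stats(model):
    model(x)  # forward pass with BN layers in eval mode
unchanged = torch.equal(before, model[1].running_mean)
```

One thing worth keeping in mind: eval mode does more than skip the update. A BatchNorm layer in eval mode also normalizes with its running statistics instead of the batch statistics, so the forward pass used for the perturbation is not the same as the one used in train mode.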

Training goes well and I obtain around 30% accuracy after the first epoch. However, when I switch to eval mode and process the validation set, the accuracy drops massively (down to 1-2%). Over further epochs, training accuracy keeps increasing while validation accuracy remains extremely low.

First, I did not have this discrepancy between train and eval when not using VAT. I checked that simply keeping the model in train mode during validation solves the issue, which confirms it is related to a difference between train and eval behavior. I also checked that switching to eval mode but forcing the BN layers to use batch statistics instead of the running ones (by setting running_mean and running_var to None and track_running_stats to False) brings the performance back in line with training.
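For reference, the second check looked roughly like the sketch below. The helper name use_batch_stats_in_eval and the toy model are hypothetical. Note that this throws away the running statistics, so it is only meant as a diagnostic.

```python
import torch
from torch import nn


def use_batch_stats_in_eval(m):
    """Hypothetical helper: make a BatchNorm layer normalize with batch stats in eval mode."""
    if isinstance(m, nn.modules.batchnorm._BatchNorm):
        m.track_running_stats = False
        # With both buffers set to None the layer falls back to batch statistics
        m.running_mean = None
        m.running_var = None


# Toy model standing in for the real network (assumption for illustration)
model = nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5))
x = torch.randn(32, 10) * 3 + 2

model.train()
train_out = model(x)  # normalized with batch statistics

model.eval()
model.apply(use_batch_stats_in_eval)
eval_out = model(x)  # eval mode, but still normalized with batch statistics

matches = torch.allclose(train_out, eval_out, atol=1e-6)
```

With the running buffers removed, the eval output on a given batch matches the train output on that same batch, which is consistent with what I observed on the validation set.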

Hence my only conclusion is that the BatchNorm layers are indeed updating their running stats during the adversarial perturbation generation. This is extremely unsettling, because I checked by hand that set_bn_eval and set_bn_train were working as expected and that the running stats were not updated once set_bn_eval had been applied.
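The check by hand was along these lines: snapshot the running stats, run a forward pass, and compare. The helpers snapshot_bn_stats and changed_layers and the toy model are hypothetical names used only for this sketch.

```python
import torch
from torch import nn

BN = nn.modules.batchnorm._BatchNorm


def snapshot_bn_stats(model):
    """Copy running_mean / running_var of every BatchNorm layer that tracks them."""
    return {
        name: (m.running_mean.clone(), m.running_var.clone())
        for name, m in model.named_modules()
        if isinstance(m, BN) and m.running_mean is not None
    }


def changed_layers(before, after):
    """Names of BatchNorm layers whose running stats differ between two snapshots."""
    return [
        name for name in before
        if not (torch.equal(before[name][0], after[name][0])
                and torch.equal(before[name][1], after[name][1]))
    ]


def set_bn_eval(m):
    if isinstance(m, BN):
        m.eval()


def set_bn_train(m):
    if isinstance(m, BN):
        m.train()


# Toy model standing in for the real network (assumption for illustration)
model = nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5))
model.train()
x = torch.randn(32, 10)

# Forward pass with BN layers in eval mode: stats should stay put
model.apply(set_bn_eval)
before = snapshot_bn_stats(model)
model(x)
changed_in_eval = changed_layers(before, snapshot_bn_stats(model))

# Forward pass with BN layers back in train mode: stats should move
model.apply(set_bn_train)
before = snapshot_bn_stats(model)
model(x)
changed_in_train = changed_layers(before, snapshot_bn_stats(model))
```

On a toy model like this one, no layer is reported as changed after the eval-mode pass and the BatchNorm layer is reported after the train-mode pass, which is the behavior I expected from the real model as well.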

The link to my repo is: https://github.com/laurent3577/Image-Classification

I’m really puzzled as to what exactly is going on there and would appreciate any help.