Hey guys:
I want to find a way to run batch norm in eval mode for inference without using the running mean and var computed during training.
The problem is that in the model I am currently working with, the pretrained weights contain unstable batch norm statistics that break the model: it outputs completely wrong results, and I can't retrain the model at the moment. See here for more detail about the issue: Model.eval() gives incorrect loss for model with batchnorm layers - #3 by smth.
So my question is: how do I make batch norm ignore the running var/mean in eval mode while still being able to process a batch of size one (for inference)?
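For reference, here is a minimal sketch of the two behaviours being contrasted, assuming a plain `torch.nn.BatchNorm2d` and made-up "unstable" running statistics (the values 5.0 and 1e-4 are purely illustrative). In eval mode the layer normalizes with `running_mean`/`running_var`; the goal is instead to normalize each input with its own statistics, as in training, without touching the stored buffers. `torch.nn.functional.batch_norm` with `None` for both running buffers and `training=True` computes exactly that.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(3).eval()

# Hypothetical "unstable" statistics standing in for the pretrained ones.
bn.running_mean.fill_(5.0)
bn.running_var.fill_(1e-4)

x = torch.randn(1, 3, 4, 4)  # a single image, as at inference time

with torch.no_grad():
    # Standard eval mode: normalizes with running_mean / running_var.
    y_running = bn(x)
    # Desired behaviour: normalize with the statistics of x itself.
    # With None running buffers there is nothing to read or update.
    y_batch = F.batch_norm(x, None, None, bn.weight, bn.bias, training=True, eps=bn.eps)

print(y_running.abs().max())        # very large, driven by the bad running stats
print(y_batch.mean(dim=(0, 2, 3)))  # approximately zero per channel
```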
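```python
import torch


class EvilBatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Drop the running statistics so forward() falls back to batch statistics.
        self.running_var = None
        self.running_mean = None
```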
Right now I am using this custom batch norm. However, setting both running var and mean to None makes the forward pass behave like in training, so it needs a batch of at least two. So for now I added the following trick, but I am not satisfied with it.
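```python
import torch
from torch import Tensor


class EvilBatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Without running stats, forward() normalizes with the input's own statistics.
        self.running_var = None
        self.running_mean = None

    def forward(self, input: Tensor) -> Tensor:
        if input.shape[0] == 1:
            # Duplicate the single sample so the batch holds two identical entries,
            # then keep only the first output so the batch size stays 1.
            return super().forward(input.expand(2, -1, -1, -1))[:1]
        return super().forward(input)
```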
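A quick sanity check of why the duplication trick is harmless, as a sketch with a made-up input shape: stacking the same sample twice leaves the per-channel mean and (biased) variance unchanged, so normalizing `[x, x]` and keeping the first row gives the same result as normalizing `x` alone with batch statistics. As far as I can tell from `torch.nn.functional.batch_norm`, the batch-statistics path only raises a `ValueError` when there is a single value per channel (N * H * W == 1, e.g. a batch of one with 1x1 feature maps), so for larger feature maps a batch of size one may already be accepted without the trick.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 4, 4)   # one sample, 3 channels, 4x4 feature map
x2 = x.expand(2, -1, -1, -1)  # the duplication done in EvilBatchNorm2d.forward

# Batch norm normalizes with per-channel statistics over (N, H, W), using the biased variance.
dims = (0, 2, 3)
same_mean = torch.allclose(x.mean(dims), x2.mean(dims), atol=1e-6)
same_var = torch.allclose(x.var(dims, unbiased=False), x2.var(dims, unbiased=False), atol=1e-6)

# So normalizing [x, x] and keeping the first row matches normalizing x on its own.
y_dup = F.batch_norm(x2, None, None, training=True)[:1]
y_single = F.batch_norm(x, None, None, training=True)
same_output = torch.allclose(y_dup, y_single, atol=1e-5)

# The batch-statistics path only refuses inputs with a single value per channel.
try:
    F.batch_norm(torch.randn(1, 3, 1, 1), None, None, training=True)
    single_value_error = False
except ValueError:
    single_value_error = True

print(same_mean, same_var, same_output, single_value_error)
```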
What is the best way to do it? Alternatively, can someone point me to the C++ code for the batch norm forward function so that I can reproduce the correct behavior myself?
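Instead of porting the C++ kernels, the batch-statistics forward pass is short enough to write directly in Python. Below is a minimal sketch, assuming 4-D `(N, C, H, W)` inputs; `BatchStatsBatchNorm2d` is a hypothetical name, not a PyTorch class. It keeps the pretrained `weight`/`bias`, leaves the running buffers registered (so the state dict keys stay the same as `BatchNorm2d`) but never reads them, and handles a batch of size one without any duplication.

```python
import torch
from torch import Tensor


class BatchStatsBatchNorm2d(torch.nn.BatchNorm2d):
    """Hypothetical BatchNorm2d variant that always uses the current input's statistics.

    The running buffers stay registered but are ignored, in both train and eval mode,
    and are never updated.
    """

    def forward(self, input: Tensor) -> Tensor:
        # Per-channel mean and biased variance over (N, H, W), as batch norm uses in training.
        mean = input.mean(dim=(0, 2, 3), keepdim=True)
        var = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        out = (input - mean) / torch.sqrt(var + self.eps)
        if self.affine:
            # Apply the learned per-channel scale and shift.
            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        return out


bn = BatchStatsBatchNorm2d(3).eval()
print(bn(torch.randn(1, 3, 4, 4)).shape)  # torch.Size([1, 3, 4, 4])
```

With a single value per channel (a batch of one and 1x1 feature maps) the variance is 0, so the normalized term is 0 and the layer outputs `bias`; the duplication trick produces the same result in that case, since two identical values also have zero variance.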
Thanks in advance