BatchNorm in eval mode without running var and mean buffers

Hey guys:

I want to find a way to run batch norm in eval mode for inference without using the running mean and var computed during training.

The pretrained weights of the model I am currently working with contain unstable batch norm statistics that basically break the model (it outputs completely wrong results), and I can’t retrain the model at the moment. See this thread for more detail about the issue: Model.eval() gives incorrect loss for model with batchnorm layers - #3 by smth.

So my question is: how do I make the batch norm forget the running var/mean in eval mode, while still being able to process batches of size one (for inference)?

import torch


class EvilBatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dropping the buffers forces the layer to use batch statistics
        self.running_var = None
        self.running_mean = None

Right now I am using this custom batch norm, but setting both running var and mean to None makes the forward work as in training mode, so it needs a batch of at least two samples. I therefore added the trick below, but I am not satisfied with it.

import torch
from torch import Tensor


class EvilBatchNorm2d(torch.nn.BatchNorm2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.running_var = None
        self.running_mean = None

    def forward(self, input: Tensor) -> Tensor:
        if input.shape[0] == 1:
            # Duplicate the single sample so the batch-stat computation
            # accepts it; the statistics are unchanged, and we keep only
            # the first copy so the output batch size still matches.
            out = super().forward(input.expand(2, -1, -1, -1))
            return out[:1]

        return super().forward(input)
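For context, here is what I have pieced together so far about what the forward computes when it uses batch statistics — this is my own sketch, not the actual C++ implementation, which is why I would like to check it against the real code:

```python
import torch


def batch_stats_norm(x, weight, bias, eps=1e-5):
    # Normalize with the statistics of the current batch only, the way
    # batch norm does in training mode: per-channel mean and *biased*
    # variance, reduced over the N, H and W dimensions.
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return x_hat * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)


# Sanity check against nn.BatchNorm2d in training mode, which also
# normalizes with the batch statistics.
bn = torch.nn.BatchNorm2d(3).train()
x = torch.randn(2, 3, 4, 4)
ref = bn(x)
out = batch_stats_norm(x, bn.weight, bn.bias)
print(torch.allclose(out, ref, atol=1e-5))  # True
```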

What is the best way to do this? Alternatively, can someone point me to the C++ code for the batch norm forward function so that I can reproduce the correct behavior myself?

Thanks in advance

I’m unsure what the desired behavior would be in your use case.
If you don’t want to use the running stats at all, you could use track_running_stats=False so that the batch stats will be used to normalize the input activations during both training and evaluation.
However, based on your explanation it also seems that you are using a single sample and thus the stats calculation would fail.
In this case you could reset the running stats to zeros for the mean and ones for the variance, but I would expect the achieved performance to be bad, since you would then not be normalizing the inputs at all and could thus be changing the activation range completely.
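A minimal sketch of the first option (the channel count and input shapes here are illustrative):

```python
import torch

# With track_running_stats=False the layer keeps no running buffers,
# so it normalizes with the batch statistics in train() and eval() alike.
bn = torch.nn.BatchNorm2d(3, track_running_stats=False)
bn.eval()

x = torch.randn(4, 3, 8, 8)
out = bn(x)
# Each channel is normalized with the current batch's statistics.
print(out.mean(dim=(0, 2, 3)))  # all close to zero

# A single sample also works for BatchNorm2d, because the statistics
# are reduced over H and W as well, not just over the batch dimension.
single = bn(torch.randn(1, 3, 8, 8))
print(single.shape)  # torch.Size([1, 3, 8, 8])
```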