Example of how to use batch-norm?

TL;DR: What exact size should I pass to the batch_norm layer if I want to apply it to a CNN? The size of the output? In what format?

I have a two-fold question:

  1. So far I have only found this link here that shows how to use batch-norm. My first question is: is this the proper way to use it? For example:

bn1 = nn.BatchNorm2d(what_size_here_exactly?, eps=1e-05, momentum=0.1, affine=True)
x1= bn1(nn.Conv2d(blah blah blah))

Is this the correct intended usage? Could someone give an example of the syntax for its usage with a CNN?

  2. I know that there are sometimes caveats with using batch-norm at training time versus inference time (for example, the original paper computes running averages and variances of the training data AFTER the net has fully trained, and then uses those in the inference equation). However, I am guessing that PyTorch's batch-norm already does this under the hood, so can I call forward_prop at test time the same way I would call it at train time?
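For reference, here is the per-channel computation as I read the PyTorch BatchNorm documentation (a summary, not the paper's exact procedure). In training mode the layer normalizes with the statistics of the current mini-batch $B$:

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$

Instead of estimating population statistics after training, PyTorch keeps exponential moving averages during training, updated on every training-mode forward pass with momentum $m$ (default $0.1$):

$$\hat{\mu} \leftarrow (1 - m)\,\hat{\mu} + m\,\mu_B, \qquad \hat{\sigma}^2 \leftarrow (1 - m)\,\hat{\sigma}^2 + m\,\sigma_B^2$$

In eval mode, $\hat{\mu}$ and $\hat{\sigma}^2$ (`running_mean`, `running_var`) replace $\mu_B$ and $\sigma_B^2$ in the first equation, so each sample's output no longer depends on the rest of the batch.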



Hi, as for the first question:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_bn = nn.BatchNorm2d(20)  # num_features = conv2's out_channels
        self.dense1 = nn.Linear(in_features=320, out_features=50)
        self.dense1_bn = nn.BatchNorm1d(50)  # num_features = dense1's out_features
        self.dense2 = nn.Linear(50, 10)

    def forward(self, x):
        # assumes (N, 1, 28, 28) input, so the flattened size below is 20 * 4 * 4 = 320
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_bn(self.conv2(x)), 2))
        x = x.view(-1, 320)  # reshape
        x = F.relu(self.dense1_bn(self.dense1(x)))
        x = self.dense2(x)  # raw logits: no ReLU before log_softmax
        return F.log_softmax(x, dim=1)
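To make the sizing rule concrete: the first argument of `BatchNorm2d` (`num_features`) is the channel dimension C of the (N, C, H, W) tensor it normalizes, i.e. the `out_channels` of the convolution feeding it, and the layer is applied to that convolution's *output tensor*, not to the `nn.Conv2d` module itself. A minimal sketch, assuming random single-channel inputs of arbitrary size:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
# num_features must equal the channel count C of the (N, C, H, W) tensor,
# which is the out_channels of the conv that produces it.
bn = nn.BatchNorm2d(num_features=conv.out_channels)

x = torch.randn(8, 1, 28, 28)   # batch of 8 single-channel 28x28 inputs
y = bn(conv(x))                 # BN is applied to the conv's output tensor

print(y.shape)                  # torch.Size([8, 10, 24, 24])
print(bn.weight.shape, bn.running_mean.shape)  # one entry per channel: torch.Size([10]) torch.Size([10])

# Batch size and spatial size can change freely; only C is fixed.
y2 = bn(conv(torch.randn(3, 1, 40, 40)))
print(y2.shape)                 # torch.Size([3, 10, 36, 36])
```

Because the module starts in training mode with weight 1 and bias 0, each channel of `y` has (numerically) zero mean over the batch and spatial dimensions.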



@moskomule Thanks! I will try this.

On an unrelated note, one thing I noticed is that you are applying ReLUs AFTER your max_pool, whereas the canonical order is usually the reverse (ReLU, then max-pool).


Thanks for your advice; in fact, I just modified PyTorch’s example.

  1. As @moskomule pointed out, you have to specify how many feature channels your input will have, because that determines the size of the BatchNorm parameters (one weight, bias, running mean and running variance per channel). Batch and spatial dimensions don’t matter.

  2. BatchNorm only updates the running averages in train mode; in eval mode it normalizes with the stored running averages instead of the batch statistics. So if you want the model to keep updating them at test time, you will have to keep the BatchNorm modules in training mode. See the C implementation for details (it should be readable).

  3. About ReLU and MaxPool - if you think about it for a moment, ReLU + MaxPool and MaxPool + ReLU are equivalent operations, and the second option does 37.5% less element-wise work (numel + numel operations in the first case versus numel + numel/4 in the second, with the 2x2 pooling used here, where numel is the number of elements in the tensor). That’s why the example uses a different order.
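A quick check of point 3, sketched on random activations (the tensor shape is an arbitrary assumption). Because ReLU is monotonically non-decreasing, max(relu(a), relu(b)) equals relu(max(a, b)), so both orders give identical results, and the operation count matches the numel + numel vs. numel + numel/4 arithmetic above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 20, 8, 8)  # arbitrary (N, C, H, W) activations

a = F.relu(F.max_pool2d(x, 2))   # pool first, then ReLU (order used in the example)
b = F.max_pool2d(F.relu(x), 2)   # ReLU first, then pool (the "canonical" order)

# ReLU commutes with max, so both orders produce exactly the same tensor.
print(torch.equal(a, b))  # True

# Element-wise work with a 2x2 window and stride 2 (pooled output has numel / 4 elements):
numel = x.numel()
relu_then_pool = numel + numel        # ReLU over all elements, then pool over all elements
pool_then_relu = numel + numel // 4   # pool over all elements, ReLU over the pooled output only
print(1 - pool_then_relu / relu_then_pool)  # 0.375
```

As pointed out below, this relies on the activation being non-decreasing; for an activation that is not, the two orders can differ.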


@apaszke Thanks for the batchnorm info!

For the ReLUs/Max-pool, I agree that in this case they are equivalent. However, in the more general case of an arbitrary activation function they will not necessarily be: consider a “relu” activation function that is flipped and lives in the second and fourth quadrants instead of just the first quadrant and the border between the second and third. In that case the order of the two operations would matter. (Granted, such functions are not that popular; I just wanted to point out the subtlety :wink: )


Sure, it’s not a general thing, of course; it only leverages a property of the max operator (it commutes with non-decreasing functions such as ReLU).

What do you do at test time? How do you set up the forward prop so that it does not update the weights of the batch_norm module? Via eval()?

@Kalamaya do you want to freeze the running averages or the weights?

At test time, I would like to freeze both the weights (gamma and beta) and the running averages that it has computed (ostensibly because it already has good estimates of those from training).

So I basically expect that I would want all 4 of those values frozen.

Yeah, in that case, if you keep the BatchNorm modules in evaluation mode and don't pass their parameters to the optimizer (best to also set their requires_grad to False), they will be completely frozen.
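A minimal sketch of freezing all four values on a standalone `BatchNorm2d` (the layer size and input are arbitrary assumptions): `requires_grad = False` covers the learnable gamma and beta (`bn.weight`, `bn.bias`), and `eval()` stops the running-statistics updates:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)

# Freeze the learnable affine parameters (gamma = bn.weight, beta = bn.bias)...
for p in bn.parameters():
    p.requires_grad = False
# ...and stop running_mean / running_var updates by switching to eval mode.
bn.eval()

before_mean = bn.running_mean.clone()
before_var = bn.running_var.clone()

out = bn(torch.randn(16, 4, 5, 5) * 3 + 2)  # input with non-trivial statistics

print(torch.equal(bn.running_mean, before_mean), torch.equal(bn.running_var, before_var))  # True True
print(bn.training, bn.weight.requires_grad)  # False False
```

When building an optimizer for a larger model, one way to leave such parameters out is to pass only `p for p in model.parameters() if p.requires_grad`.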


@apaszke Would it be possible to provide an example of the training time vs. test time usage of BatchNorm?
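A minimal sketch of the usual pattern, assuming a hypothetical tiny model and random stand-in data (not a real dataset): `model.train()` while optimizing, so BN uses batch statistics and updates its running averages, and `model.eval()` plus `torch.no_grad()` for testing, so BN uses the stored averages and leaves them untouched:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny model, for illustration only.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Random stand-in data.
x = torch.randn(32, 1, 6, 6)
target = torch.randint(0, 2, (32,))
bn = model[1]

# Training: BN normalizes with batch statistics and updates running_mean / running_var.
model.train()
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(x), target)
    loss.backward()
    optimizer.step()
print(int(bn.num_batches_tracked))  # 5

# Testing: BN normalizes with the stored running statistics and does not update them.
model.eval()
with torch.no_grad():
    out_full = model(x)
    out_single = model(x[:1])  # one sample's output no longer depends on the rest of the batch
print(int(bn.num_batches_tracked))  # 5
print(torch.allclose(out_full[:1], out_single, atol=1e-5))  # True
```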



@apaszke How can I keep the BatchNorm modules in evaluation mode during training?
Using bn.train(False) in __init__() doesn't seem to work.
Thanks in advance.

You may want to freeze the BatchNorm params by setting their requires_grad to False.

Setting requires_grad to False only freezes the parameters, not the moving averages.
I need to freeze them all.

You need to set the BatchNorm layers to eval() mode.
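One way to do this, sketched with a hypothetical helper `freeze_bn` and a toy model: `model.train()` recursively switches *every* submodule back to training mode, which is why a one-time `bn.train(False)` in `__init__` gets undone. Re-applying eval mode to the BatchNorm layers after each `model.train()` call keeps their running averages fixed while the rest of the model trains:

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def freeze_bn(model: nn.Module) -> None:
    """Hypothetical helper: put every BatchNorm layer in eval mode and freeze its weight/bias."""
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            m.eval()  # normalize with, and stop updating, running_mean / running_var
            for p in m.parameters():
                p.requires_grad = False

model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU())

# model.train() resets all submodules to training mode, so call freeze_bn() after it.
model.train()
freeze_bn(model)

bn = model[1]
stats = bn.running_mean.clone()
model(torch.randn(2, 3, 8, 8))
print(model.training, bn.training)          # True False
print(torch.equal(bn.running_mean, stats))  # True
```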

I have a pretrained model whose parameters are available as CSV files. This model has batch norm layers with weight, bias, mean and variance parameters. I want to copy these parameters to the layers of a similar model I have created in PyTorch, but the BatchNorm layer in PyTorch has only two parameters, namely weight and bias. How do I deal with the mean and variance so that all four of them are used during eval?
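A sketch of one answer, assuming the four CSV arrays have already been read into tensors (the values below are made up): in PyTorch, the mean and variance are stored as *buffers* (`running_mean`, `running_var`), not parameters, so they are absent from `bn.parameters()` but present in `bn.state_dict()` and can be copied in directly. This assumes the CSV stores variance (not standard deviation) and that the original `eps` matches the layer's:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)

# Made-up values standing in for the arrays read from the CSV files.
gamma = torch.tensor([1.0, 0.5, 2.0])
beta = torch.tensor([0.0, 0.1, -0.2])
mean = torch.tensor([0.3, -1.0, 2.5])
var = torch.tensor([1.5, 0.2, 4.0])

# weight/bias are Parameters; running_mean/running_var are buffers.
with torch.no_grad():
    bn.weight.copy_(gamma)
    bn.bias.copy_(beta)
    bn.running_mean.copy_(mean)
    bn.running_var.copy_(var)

bn.eval()  # eval mode normalizes with running_mean / running_var
x = torch.randn(4, 3, 2, 2)
y = bn(x)

# Manual eval-mode formula, broadcasting the per-channel values over (N, C, H, W).
s = (1, -1, 1, 1)
expected = (x - mean.view(s)) / torch.sqrt(var.view(s) + bn.eps) * gamma.view(s) + beta.view(s)
print(torch.allclose(y, expected, atol=1e-5))  # True
print(sorted(bn.state_dict().keys()))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
```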

How to set the BatchNorm layers to eval() mode?


I have recently been using BN to train autoencoders of different architectures on a large number of image patches (50K per image). There is indeed a gotcha when BN is used with this dataset, as follows.

After training for a long time (70 epochs or more, with 4K batches each), the validation loss suddenly increases significantly and never comes back down, while the training loss remains stable. Decreasing the learning rate only postpones the phenomenon. At that point the trained model is unusable if model.eval() is called as it is supposed to be, but if the output is normalized to the regular pixel range, the results seem alright. After several trials, I think the cause is likely the default epsilon (1e-5), which may be too small for long-term stability. However, increasing epsilon leads to a slightly higher validation loss. Alternatively, the running mean and var computed by PyTorch under the hood may have something to do with it, since keeping BN in training mode also alleviates the issue at inference time.
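For anyone wanting to try the two workarounds described above, both correspond to real `BatchNorm2d` constructor arguments: `eps` for a larger epsilon, and `track_running_stats=False`, which keeps no running estimates and always normalizes with the current batch's statistics, in `train()` and `eval()` alike (roughly the "keep BN in training mode" behavior). A minimal sketch with an arbitrary channel count; whether either setting cures the drift described above is not established here:

```python
import torch
import torch.nn as nn

C = 16
# Larger eps than the 1e-5 default: more headroom when a channel's variance gets tiny.
bn_eps = nn.BatchNorm2d(C, eps=1e-3)
print(bn_eps.eps)  # 0.001

# No running estimates: the layer uses batch statistics even in eval mode.
bn_batch = nn.BatchNorm2d(C, track_running_stats=False)
print(bn_batch.running_mean is None)  # True

bn_batch.eval()
x = torch.randn(8, C, 4, 4) * 5 + 3
y = bn_batch(x)
# Per-channel output mean is ~0, showing batch statistics were used despite eval().
print(torch.allclose(y.mean(dim=(0, 2, 3)), torch.zeros(C), atol=1e-5))  # True
```

The trade-off of `track_running_stats=False` is that test-time outputs depend on which samples share a batch.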

Given BN's popularity, I am a bit surprised that everyone seems happy with it. Is that really the case?