Using the same BatchNorm for several layers?

Hi everyone,

I have several CNN layers operating on tensors of the same size. Can I use a single BatchNorm for all of them, or do I need a separate BatchNorm per layer?

In other words, in my __init__() can I do:

self.batchNorm = nn.BatchNorm1d(128, affine=True)
self.conv1 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
self.conv3 = nn.Conv1d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

and in my forward pass can I do:

x = self.conv1(x)
x = self.batchNorm(x)
x = F.relu(x)

x = self.conv2(x)
x = self.batchNorm(x)
x = F.relu(x)

x = self.conv3(x)
x = self.batchNorm(x)
x = F.relu(x)

Or do I need to declare batchNorm1, batchNorm2, and batchNorm3, one for each of these conv layers? I don't know how BatchNorm is implemented, so I'm not sure whether reusing the same module across different layers would mess up the statistics it uses for normalization.

Thanks for your help!

Also, some follow up questions:

The BatchNorm usually goes BEFORE the activation function (ReLU), correct?

Also, is the argument to BatchNorm the size of the input, the number of channels, or the batch size? Thanks!

Hi,

BatchNorm actually has learnable parameters, so if you re-use the same module, those learned parameters will be shared across the different uses.
It also keeps running statistics that are used in evaluation mode, so if you re-use it, those stats will be shared as well.
That might or might not be what you want, depending on your use case.
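Here's a quick sketch (my own toy example, not from the docs) showing the difference between re-using one BatchNorm and using separate ones:

import torch
import torch.nn as nn

x = torch.randn(4, 128, 32)  # (batch, channels, length)

# One shared module: every call updates the same buffers and trains the same weight/bias.
shared = nn.BatchNorm1d(128)
shared.train()
_ = shared(x)
_ = shared(x)
print(shared.num_batches_tracked)  # tensor(2): both calls hit the same running stats

# Two separate modules: independent parameters and statistics.
bn1, bn2 = nn.BatchNorm1d(128), nn.BatchNorm1d(128)
_ = bn1(x)
_ = bn2(x)
print(bn1.num_batches_tracked, bn2.num_batches_tracked)  # tensor(1) tensor(1)
print(bn1.weight is bn2.weight)  # False: each layer learns its own scale/shift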

You can check the docs for details on how to use each argument: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d

Thanks for the response! This answers my question.

So I guess I need a separate BatchNorm() for each of the conv layers, for two reasons: 1) there are learnable parameters γ and β that may differ from layer to layer, and 2) BatchNorm stores the batch mean and variance somewhere so they can be used at inference time. Where exactly is that stored, though? When I save the model using torch.save(net.state_dict(), filename), is it dumped out along with the network weights?

Also, I can't understand why I only need to pass the number of channels and not the length of the signal (in the 1D case) or W, H (in the 2D case). Any ideas?

Hi,

Yes, they are saved along with the weights, as buffers.
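For example (a small sketch), you can see the buffers next to the learnable parameters in the state_dict of a standalone BatchNorm:

import torch.nn as nn

bn = nn.BatchNorm1d(128)
print(list(bn.state_dict().keys()))
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
# 'weight'/'bias' are the learnable gamma/beta; the rest are the running-stat buffers,
# so torch.save(net.state_dict(), filename) serializes them together with the weights.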

You don't need to provide the signal length because the model doesn't need to know it: the statistics are computed per channel, so the batch and length (or spatial) dimensions are reduced away.
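To see this, in training mode you can reproduce what BatchNorm1d does by reducing over the batch and length dimensions yourself (a sketch, using affine=False to skip γ/β):

import torch
import torch.nn as nn

x = torch.randn(4, 128, 32)  # (batch, channels, length)

bn = nn.BatchNorm1d(128, affine=False)
out = bn(x)  # train mode: normalizes with the current batch statistics

# Reduce over batch (dim 0) and length (dim 2): one mean/var per channel.
mean = x.mean(dim=(0, 2), keepdim=True)                # shape (1, 128, 1)
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)  # biased, as BatchNorm uses
manual = (x - mean) / torch.sqrt(var + bn.eps)

print(torch.allclose(out, manual, atol=1e-5))  # True: only the channel count matters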