Aligning nn.BatchNormNd with nn.Linear

The nn.Linear layer transforms shapes of the form (N, *, in_features) -> (N, *, out_features). I have found myself multiple times trying to apply batch normalization after a linear layer. However, because the default nn.BatchNormNd layers only normalize over dimension 1 (which corresponds to the channel dimension of convolutional layers), I can only compose nn.Linear and nn.BatchNormNd directly when the input to nn.Linear has no extra dimensions, i.e. shape (N, in_features). If there are more dimensions, I have to transpose the last dimension into the channel dimension before the batch normalization and transpose it back afterwards. Is there a better way to do this?
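For reference, here is a minimal sketch of that transpose workaround, assuming a recent PyTorch version and made-up shapes:

import torch
import torch.nn as nn

B, T, in_features, out_features = 4, 10, 32, 64
x = torch.randn(B, T, in_features)

linear = nn.Linear(in_features, out_features)
bn = nn.BatchNorm1d(out_features)

h = linear(x)               # (B, T, out_features)
h = bn(h.transpose(1, 2))   # BatchNorm1d expects (B, C, L) with C = out_features
h = h.transpose(1, 2)       # back to (B, T, out_features)
print(h.shape)              # torch.Size([4, 10, 64])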


It’s interesting that you want to normalize over each slice in the last dimension, especially considering that it comes from an nn.Linear output. May I ask why?

Upon further investigation I realized my conception of batch normalization was incorrect, and that it should be applied sample-wise following dense layers. However, PyTorch does not help this confusion by naming the constructor argument num_features, which led me to believe batch normalization is applied feature-wise for both convolutional and dense layers.

Yes the name is definitely a bit confusing.

Are you saying that you want to normalize over all but the batch dimension, i.e. each sample is normalized with the mean and std computed from that sample? If so, you should use Layer Normalization, which will become available in the next release. To achieve the same effect in 0.3.1, you can feed input.unsqueeze(0) into a BN layer.
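For readers on a current PyTorch release, this is roughly what per-sample normalization with nn.LayerNorm looks like (the shapes here are made up for illustration):

import torch
import torch.nn as nn

B, T, F = 4, 10, 16
x = torch.randn(B, T, F)

# Normalize over every dimension except the batch dimension:
# each sample x[i] uses its own mean and std.
ln = nn.LayerNorm(normalized_shape=x.shape[1:])
out = ln(x)

print(out.shape)             # torch.Size([4, 10, 16])
print(out[0].mean().item())  # approximately 0 for each sample
print(out[0].std().item())   # approximately 1 for each sample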

Jeez, this is what I get for jumping in and trying to implement white-paper models without putting in the time to fully understand the requisite components. I think what you suggested is correct, but to clarify: is it the case that statistics are computed over slices of the channel dimension and the mini-batch dimension for convolutional layers, but only over slices of the mini-batch dimension for dense layers?

Here is what BN does:

Assumes the input has shape [B x C x *], where
  B is the minibatch size,
  C is the size of the second dimension (the channel dim for image
    data, but in general just the second dim),
  * is an arbitrary number of dimensions of arbitrary sizes.

It normalizes each slice input[:, i, :] for i = 0, ..., C - 1, i.e. each
slice is normalized using the mean and std computed from that slice.

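A quick sanity check of that slice-wise behavior, written against a current PyTorch API (the shapes and affine=False are chosen only for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, L = 4, 3, 5
x = torch.randn(B, C, L)

bn = nn.BatchNorm1d(C, affine=False)
bn.train()
out = bn(x)

# Normalize the slice x[:, 0, :] by hand with its own mean and (biased) std.
s = x[:, 0, :]
manual = (s - s.mean()) / torch.sqrt(s.var(unbiased=False) + bn.eps)
print(torch.allclose(out[:, 0, :], manual, atol=1e-6))  # True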
So if you have an input of shape [B, *] and do bn_layer(input.unsqueeze(0)), you are effectively doing:

input_new = input.unsqueeze(0) of shape [1, B, *]
output = bn(input_new) normalizes each input_new[:, i] = input[i] slice
  for i = 0, ..., B - 1.
output = output.squeeze(0) gives what you want
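And here is a concrete check of the unsqueeze trick on a current PyTorch version (shapes made up; the affine parameters are disabled only so the two outputs can be compared directly):

import torch
import torch.nn as nn

torch.manual_seed(0)
B, F = 4, 8
x = torch.randn(B, F)

# BN sees shape [1, B, F], so each "channel" is one sample and is
# normalized with its own mean/std; num_features must equal B here.
bn = nn.BatchNorm1d(num_features=B, affine=False)
bn.train()
out_bn = bn(x.unsqueeze(0)).squeeze(0)

# In current PyTorch the same per-sample normalization is nn.LayerNorm.
ln = nn.LayerNorm(F, elementwise_affine=False)
out_ln = ln(x)

print(torch.allclose(out_bn, out_ln, atol=1e-5))  # True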