Batch Norm Broadcasting

When working with vectorial data, I sometimes need to leave the batch x dimension format in favour of batch x samples x dimension.

The sampling dimension can originate from the latent space of importance weighted autoencoders, or just multiple measurements of the same instance during an experiment.

While linear layers easily broadcast along the additional dimensions, batch norm layers need special care here.

My workaround for the missing batch norm broadcasting is a double reshape (a wrapped-up sketch follows the illustration below):
reshape(batch * samples, dimension) => batch_norm(...) => reshape(batch, samples, dimension)

Is there a way to extend batch norm without the supposedly inefficient reshaping?

For illustration:

import torch
import torch.nn as nn

lin = nn.Linear(10, 10)

b, p, d = 16, 80, 10
x = torch.randn(b, p, d)
x_ = x.reshape(b * p, d)

o = lin(x)
o_ = lin(x_)

print(torch.allclose(o, o_.reshape(b, p, d)))
# >> True, yay :)


bn = nn.BatchNorm1d(10)
bn(x_)  # ok: (batch * samples, features) matches num_features=10
bn(x)   # fails
# >> nay :?
# RuntimeError: running_mean should contain 80 elements not 10
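The double reshape workaround from above, wrapped into a small module (just a sketch, the module name is my own):

class BroadcastBatchNorm1d(nn.Module):
    """Flatten all leading dims, apply BatchNorm1d over the last dim, reshape back."""
    def __init__(self, num_features):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features)

    def forward(self, x):
        shape = x.shape                          # e.g. (batch, samples, dimension)
        out = self.bn(x.reshape(-1, shape[-1]))  # flatten leading dims
        return out.reshape(shape)                # restore the original shape

bbn = BroadcastBatchNorm1d(10)
print(bbn(x).shape)
# >> torch.Size([16, 80, 10])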

Batchnorm layers use running stats internally and expect the number of input channels to match their num_features.
The x input uses 80 channels (dim1), which raises the error, and I'm unsure which workaround should be used instead.
Could you explain your idea a bit more?
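For reference, dim1 is treated as the channel dimension, so this variant would run (although it normalizes over the 80 "channels", which is most likely not what you want):

bn80 = nn.BatchNorm1d(80)   # num_features has to match dim1 of an (N, C, L) input
out = bn80(x)               # works: stats are computed per channel over dims (0, 2)
print(out.shape)
# >> torch.Size([16, 80, 10])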

@ptrblck thanks for pointing out the channel dimension, I revisited the docstring. I only had vectorial data (N x L input size) in mind, and did not think about the use case of convolution with (N x C x L input size).

The current behaviour of batch norm on (N x C x L) tensors is designed to match what 1D convolutional outputs need. That is, each channel gets its own mean and var, right? And all elements in a channel are standardised with the same (mean, var) pair.
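A quick sanity check of that understanding, comparing against a manual per-channel normalization (my own snippet, relying on the default weight=1, bias=0):

bn_c = nn.BatchNorm1d(p)                    # p = 80 channels for the (N, C, L) layout
out = bn_c(x)                               # training mode, so batch statistics are used

mean = x.mean(dim=(0, 2), keepdim=True)                  # one mean per channel
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)    # one (biased) var per channel
manual = (x - mean) / torch.sqrt(var + bn_c.eps)

print(torch.allclose(out, manual, atol=1e-6))
# >> True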

I can then reframe my issue:
The default behaviour of BatchNorm on tensors (instead of matrices) does not broadcast the matrix operation (element-wise batch norm), but switches to a different kind of operation entirely (channel-wise batch norm).

I appreciate that this may be a sound default for the many users working with CNNs (including me :) ). However, I see many cases where an actual broadcasting of the BatchNorm layer to higher-order tensors is desired. The broadcasting would then be identical to how linear layers handle higher-dimensional tensors (apply to the last dimension, broadcast over the rest).
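To make the distinction concrete, a sketch reusing the shapes from above: the broadcast version would compute per-feature statistics over all batch * samples elements, while the current behaviour computes per-channel statistics over batch * dimension elements:

broadcast = nn.BatchNorm1d(d)(x.reshape(-1, d)).reshape(b, p, d)  # per-feature stats
channelwise = nn.BatchNorm1d(p)(x)                                # per-channel stats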

Maybe this is too didactic, but it could be helpful if the library made the difference between applying an element-wise batch normalization and a channel-wise batch normalization explicit, e.g. via a flag with a sensible default.

Based on your description it seems you might be looking for LayerNorm?
From the docs:

Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the affine option, Layer Normalization applies per-element scale and bias with elementwise_affine.
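For illustration, a minimal example (note that LayerNorm computes its statistics per sample over the feature dimension, not across the batch):

ln = nn.LayerNorm(10)       # normalize over the last dimension only
o_ln = ln(x)                # broadcasts over (batch, samples) without any reshaping
print(o_ln.shape)
# >> torch.Size([16, 80, 10])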