[Module fixing] Batch Normalisation for batch size 1 and feature size 1 should be fixed for training

Hi all,

I found that PyTorch currently does not handle batch normalisation with a batch size of 1 and a feature size of 1 during training. Example that raises an error:

    import torch
    import torch.nn as nn

    batch = 1
    channels = 3
    feature_size = 1
    in_ = torch.ones((batch, channels, feature_size, feature_size))
    norm = nn.BatchNorm2d(channels)
    y = norm(in_)  # raises an error in training mode

An error will be raised (which I know makes sense, as a batch of 1 with a 1x1 feature map has nothing to be normalised against). However, there are many situations in training where you will load a single image as the last batch. It seems PyTorch already allows a batch size of 1 as long as the feature size is larger, for example:

    batch = 1
    channels = 3
    feature_size = 2
    in_ = torch.ones((batch, channels, feature_size, feature_size))
    norm = nn.BatchNorm2d(channels)
    y = norm(in_)  # works: 4 spatial values per channel to normalise over

No error will be raised. Why isn't the same allowed when feature_size=1?

If you call norm.eval(), there is no issue. However, I am asking mainly about training.
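
For reference, in eval mode the layer uses its running statistics instead of batch statistics, so the same 1x1 input goes through without an error:

    import torch
    import torch.nn as nn

    norm = nn.BatchNorm2d(3)
    norm.eval()  # use running statistics instead of batch statistics

    y = norm(torch.ones((1, 3, 1, 1)))  # no error in eval mode
    print(y.shape)  # torch.Size([1, 3, 1, 1])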

Another way around it is to set drop_last=True on the training DataLoader, but users with different GPU memory sizes will choose different batch sizes and may not realise what error is hitting them, especially when they are just users.

Edit: *i.e. users of another open-source repository; this is evident in the open issues of a few PyTorch-based open-source projects.
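
For completeness, the drop_last workaround would look roughly like this (the dataset here is a made-up stand-in):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # 9 samples with batch_size=4 would leave a final batch of size 1
    dataset = TensorDataset(torch.randn(9, 3, 32, 32), torch.randint(0, 10, (9,)))
    loader = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)

    for images, targets in loader:
        print(images.shape)  # always torch.Size([4, 3, 32, 32]); the incomplete last batch is dropped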

Because a larger feature size would allow the batchnorm layer to calculate the stddev from these values, while a batch size and feature dimension of 1 would have an undefined (or invalid) stddev.

How would you calculate the stddev in the case of a single value? Would you rather return NaNs in the forward pass?
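
To make that concrete, this is what the standard deviation of a single value gives you in PyTorch:

    import torch

    x = torch.tensor([5.0])
    print(x.std())                 # tensor(nan) -- unbiased std of a single value is undefined
    print(x.std(unbiased=False))   # tensor(0.) -- biased std is 0, so normalising would divide by ~0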

Could the internal batch norm function add an if clause for feature size of 1, so that the forward pass behaves like nn.Identity()?

I would be a bit skeptical about this workaround. While it would avoid the runtime error, as a user I would expect to see a hard runtime error if I tried to use a layer which cannot perform its mathematical operation internally because the input shape doesn't allow it.

The overall use case in a batchnorm layer is:

out = (input - mean) / stddev * weight + bias

Using your suggestion you would also skip the affine parameters, which would thus also not be trained.
On the other hand, you could also skip the mean subtraction and division by stddev, which would then however yield a different activation range and could potentially diverge your model training.
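
A minimal sketch of the formula above applied manually in training mode (per-channel statistics taken over the batch and spatial dimensions), checked against nn.BatchNorm2d:

    import torch
    import torch.nn as nn

    x = torch.randn(4, 3, 2, 2)
    bn = nn.BatchNorm2d(3)
    bn.train()

    # per-channel mean and (biased) variance over batch and spatial dimensions
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    out_manual = (x - mean) / torch.sqrt(var + bn.eps)
    out_manual = out_manual * bn.weight.view(1, -1, 1, 1) + bn.bias.view(1, -1, 1, 1)

    out_bn = bn(x)
    print(torch.allclose(out_manual, out_bn, atol=1e-5))  # True (up to floating point noise)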

Hi @ptrblck ,

I am not very familiar with what would hold the PyTorch codebase back from literally having batch normalisation do nothing (no trainable parameters) when batch_size=1 and feature_map_size=1 during training.
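
To make the suggestion concrete, I mean something roughly like this hypothetical wrapper (not anything that exists in PyTorch):

    import torch
    import torch.nn as nn

    class LenientBatchNorm2d(nn.BatchNorm2d):
        """Hypothetical: behaves like nn.Identity when there is only one value per channel."""
        def forward(self, x):
            if self.training and x.size(0) == 1 and x.size(2) == 1 and x.size(3) == 1:
                return x  # nothing to normalise against, pass the input through unchanged
            return super().forward(x)

    bn = LenientBatchNorm2d(3)
    y = bn(torch.ones((1, 3, 1, 1)))  # passes through instead of raising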

I think users will be unaware of the problem, as training will work most of the time and then suddenly fail (e.g. when the DataLoader loads one final image). This is also painful for users training on a big dataset, as they have to go through many iterations before they get hit by the batch normalisation error caused by the final batch having size 1 (assuming the network contains a module that pools the feature maps down to a size of 1).
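
As a concrete illustration, assuming a network that ends with a global pooling layer such as nn.AdaptiveAvgPool2d(1) followed by a batchnorm layer:

    import torch
    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.AdaptiveAvgPool2d(1),  # pools every feature map down to 1x1
        nn.BatchNorm2d(8),
    )
    model.train()

    out = model(torch.randn(4, 3, 32, 32))  # fine: 4 values per channel to normalise over
    out = model(torch.randn(1, 3, 32, 32))  # fails: only a single value per channel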

I would claim a valid runtime error raised by an invalid mathematical operation is way less painful than an unexpected behavior by the framework, which would change the behavior of the model and could potentially diverge your training.
Debugging such a divergence issue can take many weeks, far longer than fixing a shape issue.
As described before, I would see the workaround using the nn.Identity module as a silent numerical error, since the output would not be normalized at all and could have a different range and thus cause divergence.

Here is a small example:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3)

    x = torch.randn(2, 3, 1) * 1000.
    out = bn(x)
    print(out.min(), out.max(), out.mean())
    # > tensor(-1., grad_fn=<MinBackward1>) tensor(1., grad_fn=<MaxBackward1>) tensor(0., grad_fn=<MeanBackward0>)

    # workaround with Identity
    bn = nn.Identity()
    out = bn(x)
    print(out.min(), out.max(), out.mean())
    # > tensor(-1368.4515) tensor(1302.6108) tensor(75.8487)

x is the incoming activation with a large value range (randn * 1000.). In the working case the batchnorm layer normalizes it, and the following layer (e.g. a linear layer) would expect to see these normalized values. If you now skip the norm, the following layer gets the unnormalized values with an entirely different range.

You could see similar effects by using a pretrained model (e.g. use torchvision.models.resnet18) and by comparing the achieved accuracy using normalized input images (expected) vs. unnormalized inputs with values in the range [0, 255]. In the latter case the model would not perform well even though it’s a proper pretrained model.
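
A rough sketch of that comparison; the random tensors below only stand in for real validation images, so this shows the setup rather than actual accuracy numbers (the weights argument assumes a recent torchvision, older versions use pretrained=True):

    import torch
    from torchvision import models

    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    model.eval()

    # stand-in batch in [0, 1]; a real comparison would use ImageNet validation images
    imgs = torch.rand(4, 3, 224, 224)

    # ImageNet normalization stats the pretrained model expects
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    with torch.no_grad():
        preds_norm = model((imgs - mean) / std).argmax(dim=1)  # normalized inputs
        preds_raw = model(imgs * 255.).argmax(dim=1)           # unnormalized inputs in [0, 255]
    print(preds_norm, preds_raw)  # compare how the predictions differ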