Batch normalisation and minibatching

How can I adjust BatchNorm2d layers, the rest of the model, or the optimizer when accumulating gradients over small minibatches?
For example, I have code that looks like this:

batchsize = 64
minibatchsize = 4
nminibatches = (batchsize + minibatchsize - 1) // minibatchsize  # ceil division

loss = 0.0
optimizer.zero_grad()
for b, (data, target) in enumerate(dataloader):
    out = net(data)
    loss += compute_loss(out, target)          # accumulate the loss over nminibatches minibatches
    if (b % nminibatches) == (nminibatches - 1):
        loss.backward()                        # backpropagate through the accumulated loss
        optimizer.step()
        optimizer.zero_grad()
        loss = 0.0                             # reset the accumulator for the next effective batch
        ....

There is a problem with batch normalisation layers and small minibatches. The result of forwarding 16 minibatches of size 4 is different from forwarding 1 batch of size 64, because in training mode each forward pass normalises with the statistics of that minibatch only and updates the running estimates separately. How can this be corrected?
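
For reference, a minimal standalone sketch (the shapes and seed are arbitrary, not from my real model) that shows the discrepancy on a single nn.BatchNorm2d layer in training mode:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(64, 3, 8, 8)

# Two identically initialised BatchNorm2d layers in training mode (the default)
bn_full = nn.BatchNorm2d(3)
bn_mini = nn.BatchNorm2d(3)

out_full = bn_full(x)                                              # one batch of 64
out_mini = torch.cat([bn_mini(x[i:i + 4]) for i in range(0, 64, 4)])  # 16 minibatches of 4

print(torch.allclose(out_full, out_mini))   # False: each minibatch is normalised with its own stats
print(bn_full.running_mean)
print(bn_mini.running_mean)                 # the running estimates also diverge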

You could try lowering the momentum of the BatchNorm layers so that the statistics from each small minibatch have less influence on the running estimates; see the sketch below.
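
A minimal sketch of that idea, reusing the net from the question; the value 0.01 is just an example (PyTorch's default momentum is 0.1):

import torch.nn as nn

for module in net.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.momentum = 0.01   # smaller momentum: each minibatch moves the running stats less

With momentum m, the running estimates are updated as running = (1 - m) * running + m * batch_stat, so a smaller m makes each noisy size-4 estimate count for less.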
Alternatively, you could use torch.utils.checkpoint to trade compute for memory, which may let you forward the full batch of 64 in one pass so the BatchNorm layers see the batch statistics you actually want; a sketch follows.
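
A hedged sketch of the checkpointing route using checkpoint_sequential; the small nn.Sequential model and the choice of two segments are placeholders, not your actual network:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 3, 3, padding=1),
)

data = torch.randn(64, 3, 32, 32, requires_grad=True)  # the full batch of 64 in one pass
out = checkpoint_sequential(net, 2, data)              # only segment boundaries keep activations
out.sum().backward()                                   # intermediate activations are recomputed here

One caveat worth verifying for your setup: checkpointed segments run their forward pass again during backward, so side effects such as BatchNorm running-stat updates can happen more than once per step.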

Maybe a useful feature would be to tell the model the intended total batch size at forward time, so that when it is fed inputs of a given minibatch size, the BatchNorm layers know to keep averaging until the total batch size has been reached. I think Swift for TensorFlow does something like this.