Gradient accumulation with BatchNorm2d layers

When I do gradient accumulation, the BatchNorm2d layers are not properly accumulated, right? I don’t fully understand what is going on, though. The running mean and standard deviation are reset on each consecutive batch, correct?

If the BatchNorm2d layers are not properly accumulated, does that also disrupt the gradient accumulation in the accompanying Conv2d layers?

From what I’ve found online, the “solution” is to use InstanceNorm2d. I tried InstanceNorm2d layers, but my model would not improve and its predictions went to NaN, even with the gradients clamped to (-1, 1).

I’m using these layers in a custom UNet model creating segmentation masks.

Does anyone have any insight into this?

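In training mode, all batchnorm layers normalize with the statistics of the current batch and update their running estimates with these per-batch statistics (using momentum, not by resetting them), so the stats are not accumulated over the accumulation steps. The affine parameters (weight and bias) of the batchnorm layers will still receive the accumulated gradients.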
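To make the per-batch update concrete, here is a minimal sketch with random data (toy values, not from the thread) that checks the update rule from the PyTorch docs, running = (1 - momentum) * running + momentum * batch_stat, where the running variance uses the unbiased batch variance. With gradient accumulation, every micro-batch's forward pass applies this update on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(num_features=3, momentum=0.1)  # a fresh layer is in training mode
x = torch.randn(16, 3, 8, 8) * 2.0 + 5.0           # one toy micro-batch of 16 samples

old_mean = bn.running_mean.clone()  # initialized to zeros
old_var = bn.running_var.clone()    # initialized to ones

bn(x)  # training-mode forward pass updates the running stats in place

# Per-channel statistics of this single micro-batch
batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3))  # default is the unbiased (n - 1) estimate, as used for running_var

m = bn.momentum
expected_mean = (1 - m) * old_mean + m * batch_mean
expected_var = (1 - m) * old_var + m * batch_var

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-5))
print(torch.allclose(bn.running_var, expected_var, atol=1e-4))
print(int(bn.num_batches_tracked))  # counts the micro-batches seen so far
```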

The gradient accumulation of the other layers (e.g. your Conv2d layers) will not be changed in any way during training, i.e. as long as you are not using the running batchnorm stats via model.eval().
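A small illustration with toy values (not from the thread) of the difference between the two modes: in training mode BatchNorm2d normalizes each batch with that batch's own statistics, whereas after model.eval() it uses running_mean and running_var, which after a single update are still far from the batch stats.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(num_features=3)        # default momentum=0.1, weight=1, bias=0
x = torch.randn(16, 3, 8, 8) * 3.0 + 2.0   # toy batch with mean ~2 and std ~3

bn.train()
y_train = bn(x)  # normalized with this batch's own mean and variance
train_means = y_train.mean(dim=(0, 2, 3)).detach()

bn.eval()
y_eval = bn(x)   # normalized with running_mean / running_var instead
eval_means = y_eval.mean(dim=(0, 2, 3)).detach()

print(train_means)  # close to zero per channel
print(eval_means)   # clearly non-zero: the running stats moved only 10% toward the batch stats
```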

You could try to change the momentum of all batchnorm layers to make the updates of the running stats smoother or, as you’ve suggested, replace this layer with another normalization layer that might be less sensitive to small batch sizes.
However, if the batch size in each accumulation step is not too small (e.g. larger than 16 samples), the approach might still work.
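As one possible alternative to InstanceNorm2d, here is a sketch that swaps every BatchNorm2d in an existing model for nn.GroupNorm, which normalizes over channel groups within each sample and therefore does not depend on the batch size. GroupNorm is my suggestion rather than something from this thread, and replace_batchnorm_with_groupnorm, its default of 8 groups and its fallback to a single group are hypothetical choices. The new layers are freshly initialized, so this should be done before training (and before moving the model to the GPU).

```python
import torch
import torch.nn as nn


def replace_batchnorm_with_groupnorm(module: nn.Module, num_groups: int = 8) -> nn.Module:
    """Hypothetical helper: recursively swap every BatchNorm2d for a GroupNorm.

    GroupNorm computes its statistics per sample, so it keeps no running stats
    and behaves the same for any batch size.
    """
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            channels = child.num_features
            # num_groups must divide the channel count; fall back to 1 group otherwise
            groups = num_groups if channels % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(groups, channels, affine=child.affine))
        else:
            replace_batchnorm_with_groupnorm(child, num_groups)
    return module


# Toy stand-in for a UNet block (assumption, not the model from the thread)
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(16, 12, kernel_size=3, padding=1), nn.BatchNorm2d(12)),
)
replace_batchnorm_with_groupnorm(model, num_groups=8)
print(model)
out = model(torch.randn(2, 3, 32, 32))  # works even with a batch of 2
```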


Thank you for answering @ptrblck.
That would explain a lot about why the model was barely improving, if at all.

I have a batch size of 16 and accumulate over 4 batches before passing the gradients to the parameter server for an optimizer step. I can’t increase the batch size any further due to memory constraints. When I tried InstanceNorm2d, my predictions became NaN even with gradient clipping, for a reason I can’t figure out either. Any insight into why that would be?
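For reference, a minimal sketch of the setup described above (micro-batches of 16, an optimizer step every 4 of them), using a toy conv model and random masks as stand-ins for the UNet and data, and a local optimizer step in place of the parameter server. It shows the mismatch being discussed: the optimizer effectively sees 16 * 4 = 64 samples per step, but each BatchNorm2d forward pass only sees 16, and its running stats are updated once per micro-batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a segmentation UNet (assumption: 3-channel input, 1-channel mask logits)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=1),
)
bn = model[1]
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

micro_batch = 16   # what fits in memory; this is all BatchNorm2d ever sees
accum_steps = 4    # effective batch for the optimizer: 16 * 4 = 64
num_micro_batches = 8
optimizer_steps = 0

model.train()
optimizer.zero_grad()
for i in range(num_micro_batches):
    images = torch.randn(micro_batch, 3, 32, 32)
    masks = torch.randint(0, 2, (micro_batch, 1, 32, 32)).float()

    loss = criterion(model(images), masks)
    # Scale so the summed gradient equals the mean loss gradient over all 64 samples
    (loss / accum_steps).backward()

    if (i + 1) % accum_steps == 0:
        optimizer.step()  # in the original setup, the grads go to a parameter server instead
        optimizer.zero_grad()
        optimizer_steps += 1

print(optimizer_steps)              # 2 optimizer steps
print(int(bn.num_batches_tracked))  # 8: running stats updated once per micro-batch
```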

How do you change the momentum in all batchnorm layers?

No idea at the moment, but you could call torch.autograd.set_detect_anomaly(True) at the beginning of the script to get a stack trace, which should hopefully point to the operation that first creates the NaN.
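A minimal sketch of where the call goes, using a deliberately broken toy op (the square root of a negative number) rather than the UNet from this thread: in anomaly mode, autograd raises a RuntimeError as soon as a backward function returns NaN, and it also emits a warning with the traceback of the forward call that produced it. Anomaly detection slows training down noticeably, so it is best enabled only while debugging.

```python
import torch

# Enable once, at the top of the training script
torch.autograd.set_detect_anomaly(True)

x = torch.tensor([-1.0, 4.0], requires_grad=True)
y = torch.sqrt(x)  # sqrt(-1) silently produces NaN in the forward pass

message = ""
try:
    y.sum().backward()  # the sqrt backward returns NaN for the first element
except RuntimeError as err:
    # The error names the backward function that returned NaN
    message = str(err)
    print(message)

# Turn it off again once the source of the NaN has been found
torch.autograd.set_detect_anomaly(False)
```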

momentum is an argument of all batchnorm layers, which you can pass directly to them:

```python
import torch.nn as nn

nn.BatchNorm2d(num_features=3, momentum=0.1)
```
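The constructor argument covers newly created layers. For a model that is already built (such as the UNet here), a small sketch that walks model.modules() and overwrites the momentum attribute could look like the following. set_bn_momentum is a hypothetical helper name, the toy model is only a stand-in, and layers like SyncBatchNorm would need to be added to the tuple if they are used. Since PyTorch updates the running stats as running = (1 - momentum) * running + momentum * batch_stat, a smaller momentum gives smoother, slower-moving estimates.

```python
import torch.nn as nn

# Batchnorm variants whose running-stat momentum should be changed
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def set_bn_momentum(model: nn.Module, momentum) -> int:
    """Hypothetical helper: set momentum on every batchnorm layer of a built model.

    Returns the number of layers changed. momentum=None makes PyTorch use a
    cumulative moving average (a plain average over all batches seen so far).
    """
    count = 0
    for module in model.modules():  # recursively visits all submodules
        if isinstance(module, BN_TYPES):
            module.momentum = momentum
            count += 1
    return count


model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
)
changed = set_bn_momentum(model, 0.01)  # smaller value -> smoother running stats
print(changed)  # 2
```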

So the momentum becomes another hyperparameter to tune, or is there a methodical approach to choosing an appropriate value?

I’m not aware of any “automatic” way.
You could try to check the running stats and compare them to the input batch stats to see how noisy they are. However, this would be a manual process.
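One way to do this manual check, sketched under the assumption of a toy model and random inputs: a forward pre-hook on each BatchNorm2d layer reads the incoming batch's per-channel mean and variance before the layer updates its running stats, so the two can be compared. make_stats_logger and the logged keys are hypothetical names chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_stats_logger(name, log):
    """Hypothetical helper: forward pre-hook comparing the incoming batch stats
    with the layer's running stats (read before this batch updates them)."""
    def hook(module, inputs):
        x = inputs[0].detach()
        batch_mean = x.mean(dim=(0, 2, 3))
        batch_var = x.var(dim=(0, 2, 3))  # unbiased, like running_var
        log.append({
            "layer": name,
            "mean_diff": (batch_mean - module.running_mean).abs().mean().item(),
            "var_diff": (batch_var - module.running_var).abs().mean().item(),
        })
    return hook


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
log = []
handles = [
    m.register_forward_pre_hook(make_stats_logger(n, log))
    for n, m in model.named_modules()
    if isinstance(m, nn.BatchNorm2d)
]

model.train()
for _ in range(20):
    model(torch.randn(16, 3, 32, 32))  # micro-batches of 16 random images

for h in handles:
    h.remove()  # stop logging once done

for entry in log[-3:]:
    print(entry)
```

The differences start large because running_mean and running_var are initialized to zeros and ones; if they stay large and jumpy after the first iterations, the per-batch stats are noisy and a smaller momentum may be worth trying.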