Gradient Accumulation with BatchNorm

I am implementing a paper that requires a high batch size, but since I only have 1 GPU, I'm trying to use gradient accumulation to reach the required effective batch size.

However, I think gradient accumulation might be a problem when using BatchNorm, since the batch statistics (and the running estimates updated with momentum) are computed on each small micro-batch and aren't accumulated across the accumulation steps. Is there a workaround to this problem other than switching to a different normalization layer (e.g. GroupNorm, InstanceNorm)? Thanks!
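
For reference, here is a minimal gradient-accumulation sketch (the toy model, data, and hyperparameters below are illustrative, not from any particular paper) showing where the BatchNorm issue arises: each forward pass normalizes with the statistics of its own micro-batch only, even though the gradients are summed over several micro-batches before the optimizer step.

```python
import torch
import torch.nn as nn

# Toy model containing BatchNorm; all names here are illustrative.
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accumulation_steps = 4   # effective batch size = micro_batch * accumulation_steps
micro_batch = 8

optimizer.zero_grad()
for step in range(accumulation_steps * 3):
    data = torch.randn(micro_batch, 16)
    target = torch.randint(0, 2, (micro_batch,))

    output = model(data)                       # BatchNorm statistics use only this micro-batch
    loss = criterion(output, target) / accumulation_steps
    loss.backward()                            # gradients accumulate in .grad across micro-batches

    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                       # one update per effective batch
        optimizer.zero_grad()
```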


You could try torch.utils.checkpoint to trade compute for memory.
This approach recomputes intermediate activations during the backward pass instead of storing them, so a larger real batch may fit on a single GPU.
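
A minimal sketch of this idea, assuming a toy sequential model (the model, shapes, and segment count below are only for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical deep model; checkpointing recomputes activations in the
# backward pass instead of storing them all, reducing peak memory.
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)])
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

data = torch.randn(256, 64)      # a larger batch than would otherwise fit
target = torch.randn(256, 64)

# Split the model into 2 checkpointed segments; only the segment boundary
# activations are kept, the rest are recomputed during backward().
output = checkpoint_sequential(model, 2, data, use_reentrant=False)
loss = criterion(output, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The trade-off is extra forward computation per backward pass, but since the whole batch is processed in one forward pass, BatchNorm sees the full batch statistics.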