Ghost batch norm / Virtual batch size


I was trying to replicate some experiments done in TensorFlow and noticed that they use something called a virtual batch size. Some papers have shown that the per-device batch size, and the accuracy of the batch norm statistics that comes with it, can matter, and is often a reason why large-batch training does not perform as well as training with smaller batches. At the same time, training with larger batches, especially on lower-dimensional data (e.g. 32x32 images), often yields better GPU utilization.

Is there a way to replicate this ghost batch norm in PyTorch? E.g., can I have a batch norm layer that automatically subdivides the batch into smaller micro-batches and computes statistics on each one individually? Right now the per-device batch size is coupled to the total batch size and the number of GPUs I am using, which makes this hard to experiment with: e.g., if I want a total batch size of 1024 and a virtual batch size of 64, I need to use 16 GPUs.
I found one repo that does this, but it actually splits the batch and performs multiple forward passes, which is super inefficient.
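For reference, here is a single-pass sketch of what I have in mind (the `GhostBatchNorm2d` name and implementation are my own, not an existing PyTorch layer): instead of looping over micro-batches, it folds the virtual batches into the channel dimension via a zero-copy `view`, so one fused `F.batch_norm` call computes separate statistics per virtual batch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GhostBatchNorm2d(nn.BatchNorm2d):
    """Sketch of ghost batch norm (hypothetical layer, not part of torch.nn).

    Folds the virtual batches into the channel dimension so that a single
    fused batch_norm call computes separate statistics for each virtual
    batch -- no Python-level loop over micro-batches.
    """

    def __init__(self, num_features, virtual_batch_size=64, **kwargs):
        super().__init__(num_features, **kwargs)
        self.virtual_batch_size = virtual_batch_size

    def forward(self, x):
        if not self.training:
            # eval: use the (averaged) running statistics as usual
            return super().forward(x)
        N, C, H, W = x.shape
        splits = max(N // self.virtual_batch_size, 1)
        assert N % splits == 0, "batch size must be divisible by the number of virtual batches"
        # Zero-copy view: virtual batch j holds samples {j, j+splits, ...}.
        # The groups are strided rather than contiguous, which is
        # statistically equivalent for shuffled training batches.
        xg = x.view(N // splits, splits * C, H, W)
        # Repeat the per-channel buffers/params across virtual batches.
        rm = self.running_mean.repeat(splits) if self.running_mean is not None else None
        rv = self.running_var.repeat(splits) if self.running_var is not None else None
        w = self.weight.repeat(splits) if self.weight is not None else None
        b = self.bias.repeat(splits) if self.bias is not None else None
        # F.batch_norm updates rm/rv in place when training=True.
        # (momentum=None means cumulative averaging in nn.BatchNorm;
        # for simplicity this sketch falls back to 0.1 in that case.)
        out = F.batch_norm(xg, rm, rv, w, b, True, self.momentum or 0.1, self.eps)
        if self.track_running_stats:
            # Fold the per-virtual-batch running stats back by averaging.
            self.running_mean.copy_(rm.view(splits, C).mean(0))
            self.running_var.copy_(rv.view(splits, C).mean(0))
            self.num_batches_tracked += 1
        return out.view(N, C, H, W)
```

With this, a total batch of 1024 and a virtual batch size of 64 should work on a single GPU, since the 16 "ghost" batches are just channel groups of one tensor.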

Thanks for your help,