Synchronize batch mean and std of BatchNorm layer in pytorch Data Parallel


I am trying to implement a synchronized BatchNorm layer, and to do that I need to modify Data Parallel.

The first step is to gather all inputs of the BatchNorm layer, compute the mean and std, then pass them back to the BatchNorm layer.

But I do not know how to get the feature maps of the nets on different GPUs, or how to pass the global mean/std back to them.
For example, if we use ResNet, the batch norm layers are hidden inside the class ‘Bottleneck’. How can I get their outputs and pass the mean/std back to them?
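One way around the "hidden inside Bottleneck" problem is to not touch the model definition at all, but to walk the module tree and swap every BatchNorm2d for a custom subclass. This is only a sketch: SyncBatchNorm2d here is a placeholder (a real one would exchange statistics across GPUs in forward()), and replace_batchnorm is my own helper name, not a PyTorch API.

```python
import torch
import torch.nn as nn

class SyncBatchNorm2d(nn.BatchNorm2d):
    """Placeholder subclass; a real implementation would override
    forward() to all-reduce batch statistics across GPUs."""
    pass

def replace_batchnorm(module):
    # Recursively swap every nn.BatchNorm2d child (even ones nested
    # deep inside Bottleneck blocks) for SyncBatchNorm2d, copying the
    # learned parameters and running statistics over.
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            new = SyncBatchNorm2d(child.num_features, child.eps,
                                  child.momentum, child.affine,
                                  child.track_running_stats)
            new.load_state_dict(child.state_dict())
            setattr(module, name, new)
        else:
            replace_batchnorm(child)
    return module
```

For a torchvision ResNet you would call `replace_batchnorm(model)` once after construction; because the swap happens via `named_children`/`setattr`, the Bottleneck class itself never needs editing.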

Does anyone have a suggestion?

Thank you.

You should create a custom BatchNorm layer, maybe called BatchNormMultiGPU, which performs such an exchange internally.

It’s not easy to hack this into the existing BatchNorm layer / model definitions.
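To illustrate what such an exchange would compute: each GPU contributes the per-channel sum, sum of squares, and element count of its shard; these are reduced (with torch.distributed.all_reduce in a real multi-GPU run), and every GPU derives the same global mean/var. A single-process sketch, with the shards as a plain list standing in for the per-GPU tensors (synced_stats is my own name):

```python
import torch

def synced_stats(shards):
    """Combine per-GPU shards the way a synchronized BatchNorm would.
    Each shard is (N_i, C, H, W); statistics are per channel C."""
    # Per-device partial sums; in a real implementation each of these
    # would be computed locally and then all-reduced across processes.
    total_sum = sum(s.sum(dim=(0, 2, 3)) for s in shards)
    total_sqsum = sum((s ** 2).sum(dim=(0, 2, 3)) for s in shards)
    count = sum(s.numel() // s.size(1) for s in shards)
    mean = total_sum / count
    var = total_sqsum / count - mean ** 2  # biased variance, as BN uses
    return mean, var
```

Because only three small per-channel tensors cross the GPU boundary, this avoids gathering the feature maps themselves.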

On a similar note, how does one get weight norm to work in vanilla/distributed DataParallel? It leaks memory in the distributed case, and is not transferred to the appropriate GPUs in the vanilla case.

A synchronized cross-GPU batch normalization implementation is available here.