I am trying to implement a Synchronized BatchNorm layer, and to do so I need to modify DataParallel.
The first step is to gather the inputs to the BatchNorm layers from every GPU, compute the global mean and std, then pass them back to each BatchNorm layer.
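The reduction step above is just pooling per-device sums. A minimal sketch of the math in plain Python (the helper names `partial_stats` and `global_mean_var` are hypothetical, not part of any PyTorch API): each replica reports its count, sum, and sum of squares, and the global mean/variance follow from the pooled totals.

```python
def partial_stats(x):
    """Per-device partial statistics for a flat list of activations."""
    n = len(x)
    s = sum(x)
    ss = sum(v * v for v in x)
    return n, s, ss

def global_mean_var(parts):
    """Reduce (count, sum, sumsq) tuples from all devices into global stats."""
    n = sum(p[0] for p in parts)
    s = sum(p[1] for p in parts)
    ss = sum(p[2] for p in parts)
    mean = s / n
    var = ss / n - mean * mean  # E[x^2] - E[x]^2, biased variance as in BN
    return mean, var
```

Because only three scalars per channel leave each device, this is cheap to all-reduce, and the result is identical to computing BN statistics over the whole batch on one device.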
But I do not know how to get the feature maps from the replicas on different GPUs, or how to pass the global mean/std back to them.
For example, in ResNet the batch norm layers are hidden inside the `Bottleneck` class. How can I get their outputs and pass the mean/std back to them?
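One way to reach layers buried inside `Bottleneck` blocks is `named_modules()`, which walks the module tree recursively. A hedged sketch (the helper name `find_batchnorms` is my own, not a PyTorch function) that collects every `BatchNorm2d` regardless of nesting depth:

```python
import torch.nn as nn

def find_batchnorms(model):
    """Return {qualified_name: module} for all BatchNorm2d layers,
    however deeply they are nested (e.g. inside Bottleneck blocks)."""
    return {name: m for name, m in model.named_modules()
            if isinstance(m, nn.BatchNorm2d)}
```

With the qualified names in hand, one could then swap each layer for a custom synchronized variant by calling `setattr` on its parent module, or register forward hooks on them to intercept the feature maps.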
On a similar note, how does one get weight normalization to work with vanilla/distributed DataParallel? It leaks memory in the distributed case, and its parameters do not get transferred to the appropriate GPUs in the vanilla case.