DDP - Batch Norm Issue

Hi,

Could you try disabling the CuDNN backend with:

```python
torch.backends.cudnn.enabled = False
```

According to posts such as *Training performance degrades with DistributedDataParallel*, this can improve training performance.

Also, have you given SyncBatchNorm (https://pytorch.org/docs/stable/nn.html#syncbatchnorm) a try? This makes batch statistics be computed across all GPUs in use, instead of separately for the batch passed to each device. (Note that, per the documentation, you'll need to change your code to spawn a single process per GPU if you're not already training that way.)
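
For example, you can convert the BatchNorm layers of an existing model with `convert_sync_batchnorm` (the model below is just a placeholder to show the call; run the conversion inside each spawned process before wrapping in DDP):

```python
import torch.nn as nn

# Placeholder model for illustration
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Replace every BatchNorm layer with SyncBatchNorm so running statistics
# are synchronized across processes (one process per GPU)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Then wrap in DistributedDataParallel as usual, e.g. (inside each process):
# model = nn.parallel.DistributedDataParallel(model.cuda(rank), device_ids=[rank])
```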