How do you update batchnorm statistics of a SWA model when using DDP?

Once training is done and the SWA model has been computed, how do you update the batchnorm statistics when using DDP?

Do you still just call torch.optim.swa_utils.update_bn(train_loader, swa_model), and are the batchnorm statistics updated correctly across the different ranks?

Since each rank holds a full replica of the model, I believe you can just call update_bn on each rank as you would for a single-process model, and it should work.
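
Here is a minimal sketch of that approach, assuming a standard DDP + SWA setup. The names `finalize_swa`, `model` (the DDP-wrapped network), `train_loader`, and `device` are illustrative and would come from your own training script:

```python
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.swa_utils import AveragedModel, update_bn


def finalize_swa(model: DDP, train_loader, device):
    # Build the averaged model from the underlying module rather than the
    # DDP wrapper, so swa_model forwards don't go through DDP machinery.
    swa_model = AveragedModel(model.module).to(device)

    # (During training you would call swa_model.update_parameters(model.module)
    # after each optimizer step or epoch, per the usual SWA recipe.)

    # After training, call update_bn on every rank, exactly as you would for
    # a single-process model. Each rank recomputes the BN running statistics
    # with a forward pass over its own loader; with a DistributedSampler that
    # is only the local data shard, so the stats on each rank are estimated
    # from different (but similarly distributed) samples.
    update_bn(train_loader, swa_model, device=device)
    return swa_model
```

One caveat implied above: because each rank sees only its shard, the recomputed statistics can differ slightly between ranks. If you need them bit-identical (e.g. before checkpointing from rank 0), one option is to run update_bn on rank 0 only and broadcast the model's buffers to the other ranks.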