Once training is done and the SWA model is computed, how do you update the batch norm statistics when using DDP?
Do you still just call torch.optim.swa_utils.update_bn(train_loader, swa_model), and are the statistics updated consistently across the different ranks?
Since each rank holds a full copy of the model, I believe you can just call update_bn on each rank, the same way you would for a single (non-distributed) model, and it should work.
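Here is a minimal sketch of what that could look like. It assumes a standard one-process-per-GPU DDP setup where `model`, `train_loader`, and `device` already exist on each rank, and that `swa_model` is the unwrapped averaged module (the names and the `finalize_swa` helper are just illustrative, not an official recipe):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

def finalize_swa(model: torch.nn.Module, train_loader, device: torch.device):
    # Wrap the already-trained (unwrapped) module; during training you would
    # have called swa_model.update_parameters(model) after each epoch/step.
    swa_model = AveragedModel(model).to(device)

    # Every rank holds an identical copy of the averaged weights, so each rank
    # simply recomputes the BatchNorm running statistics by forwarding its own
    # portion of the training data through the SWA model.
    update_bn(train_loader, swa_model, device=device)
    return swa_model
```

Note that if your train_loader uses a DistributedSampler, each rank only sees its own shard of the data, so the recomputed statistics may differ slightly between ranks; in practice this is usually close enough, but you could also run update_bn over the full dataset on one rank and broadcast the buffers if you need them to match exactly.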