SWA for distributed training

Hi there,

I am very new to PyTorch. Is Stochastic Weight Averaging supported in distributed training (more specifically, the update_bn function)?

Thanks!

I see no reason why it wouldn’t work. SWA holds a separate copy of the averaged model on each rank, and since every rank applies the same all-reduced gradients, the per-rank model copies (and therefore their weight averages) stay in sync, so it should work in principle.

I would suggest giving it a try, and letting us know if you experience any issues.
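If it helps, here is a minimal sketch of how the update_bn step could look in a DDP setup (the loader, model and local_rank below are placeholders, so treat it as an untested sketch, not a verified recipe). update_bn just runs one forward pass over the loader to recompute the BatchNorm running statistics of the averaged model, so each rank can run it on its local data shard, or you can run it only on rank 0 before saving:

import torch
from torch.optim.swa_utils import AveragedModel, update_bn

loader, model, local_rank = ...  # DistributedSampler-based loader, DDP-wrapped model
swa_model = AveragedModel(model)

# ... training loop, calling swa_model.update_parameters(model) after swa_start ...

# recompute BatchNorm running stats on the averaged model; this step
# needs no optimizer and no backward pass, just forward passes
update_bn(loader, swa_model, device=torch.device("cuda", local_rank))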

I have tried SWA with DDP and run into key errors when I try to load the state_dict of the saved SWA model.

Below is my skeletal code:


from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
model = torch.nn.parallel.DistributedDataParallel(module=model, broadcast_buffers=False, device_ids=[local_rank])
model.train()
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# saving swa_model
torch.save(swa_model.module.state_dict(), 'swa_model.pth')
# saving model at last epoch
torch.save(model.module.state_dict(), 'last_epoch_model.pth')
# The code runs fine and both models are saved
# During inference

model = get_model()
model.load_state_dict(torch.load('swa_model.pth'))

Now there is a mismatch of keys in the state dictionary, because I used DDP and passed the DDP-wrapped model to AveragedModel. It results in the following error:
Missing key(s) in state_dict: “conv1.weight”…; Unexpected key(s) in state_dict: “module.conv1.weight”…

But “last_epoch_model.pth” loads correctly. Any thoughts?

When wrapping your model in both DDP and AveragedModel, you need to go through swa_model.module.module to get back to the underlying model. Right now, with only one “.module”, you are saving the DDP-wrapped copy, which is why the state_dict keys carry the unexpected “module.” prefix when you load them.
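For example, building on your code above (just a sketch of the fix, not tested):

# swa_model.module is the (copied) DDP wrapper, so swa_model.module.module
# is the underlying plain model with un-prefixed parameter names
torch.save(swa_model.module.module.state_dict(), 'swa_model.pth')

# During inference the keys now match an unwrapped model
model = get_model()
model.load_state_dict(torch.load('swa_model.pth'))

Alternatively, you could strip the leading “module.” from each key of the already-saved state_dict before loading, but unwrapping at save time is the cleaner fix.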