Error in SWA Blogpost

In the blogpost titled “PyTorch 1.6 now includes Stochastic Weight Averaging” there is a section of code that does not match up with its description.

The blogpost can be found HERE.

Under the section “HOW TO USE SWA IN PYTORCH?” there is a description: “In the example below, swa_model is the SWA model that accumulates the averages of the weights. We train the model for a total of 300 epochs, and we switch to the SWA learning rate schedule and start to collect SWA averages of the parameters at epoch 160.”

However, the example code trains for only 100 epochs and starts collecting SWA averages after epoch 5. I wanted to check whether this is an error in the post or a misunderstanding on my part.

Code shown in the blogpost:

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)
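
For comparison, here is a self-contained sketch of what the loop would look like with the constants from the description (300 epochs, SWA starting at epoch 160). The toy model, data, and optimizer are stand-ins I've made up so the snippet runs end to end; only the `epochs`/`swa_start` values differ from the quoted code:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, TensorDataset

# Toy setup so the sketch is runnable (assumptions, not from the blog).
model = nn.Linear(10, 1)
loader = DataLoader(TensorDataset(torch.randn(32, 10), torch.randn(32, 1)),
                    batch_size=8)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epochs, swa_start = 300, 160  # constants from the blog's description
swa_model = AveragedModel(model)
scheduler = CosineAnnealingLR(optimizer, T_max=swa_start)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(epochs):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        # Accumulate the running average of the weights and use the SWA LR.
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        # Before swa_start, follow the cosine annealing schedule.
        scheduler.step()

# Recompute batch-norm statistics for the averaged model.
update_bn(loader, swa_model)
```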

This indeed doesn't look right. CC @vincentqb


Thanks for catching this! Yes, this is a typo in the blog's example; see the documentation :slight_smile: I'll look into getting the post updated.
