SWA AveragedModel proper usage

Hi, I want to use the SWA technique, and it is finally official in PyTorch! https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

But one thing is not clear. In the example they posted, the AveragedModel is created outside the training loop, before training starts. My question is: if we want to start averaging from the latest model weights, as we should, and not from the initial random weights, shouldn't we create the AveragedModel just before SWA starts? Something like this (following the example's nomenclature):

if (epoch+1) == swa_start:
    swa_model = AveragedModel(model)

Bump. I still haven't found a solution :confused: Modifying the example from the page, I think it should look like this:

import torch
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR

loader, optimizer, model, loss_fn = ...
swa_model = None  # created lazily, once SWA starts
scheduler = CosineAnnealingLR(optimizer, T_max=100)
swa_start = 5
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        if swa_model is None:
            # First SWA epoch: snapshot the current (already trained) weights
            swa_model = AveragedModel(model)
        else:
            swa_model.update_parameters(model)
            swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)

Perhaps @ptrblck can help with this?

I think your code is correct: the initial “checkpoint” would be created once swa_start epochs are already done, and the updates would take place afterwards.
I’m unsure whether this is “necessary” or whether the posted example would also work fine, since the initial updates become less important as training progresses.

Did you run your code and see a difference between the two approaches?
If so, which one performed better?
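
One way to check whether the creation point matters is to look at what the first update_parameters call actually does. The snippet below is only a minimal sketch: the one-weight nn.Linear model and the set_weight helper are hypothetical stand-ins, and the weights are set by hand instead of being trained. In the releases I'm aware of, the first call copies the current weights rather than averaging them with the snapshot taken at construction, but please verify this on your own version:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# Hypothetical one-weight model so the average is easy to read off.
model = nn.Linear(1, 1, bias=False)

def set_weight(value):
    # Stand-in for "the optimizer trained the model for one more epoch".
    with torch.no_grad():
        model.weight.fill_(value)

set_weight(0.0)                    # plays the role of the initial weights
swa_model = AveragedModel(model)   # created before any SWA update

history = []
for value in (1.0, 3.0, 5.0):      # weights after three pretend SWA epochs
    set_weight(value)
    swa_model.update_parameters(model)
    history.append(swa_model.module.weight.item())

print(history)
```

If the first call does copy, history comes out as roughly [1.0, 2.0, 3.0], which is the plain mean of the weights seen by update_parameters. The 0.0 present when AveragedModel was constructed would then never enter the average, regardless of where the object is created.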

I did the same as you, and it seems you are right.
I checked right after swa_model = AveragedModel(model): the weights of the swa_model parameters and the model parameters are exactly the same. So creating swa_model = AveragedModel(model) before the loop starts is not a good idea.
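
For anyone who wants to repeat that check, here is a minimal sketch. The small nn.Linear model is a hypothetical stand-in for the real network, and the in-place add_ only simulates a training step:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

torch.manual_seed(0)
model = nn.Linear(4, 2)  # hypothetical stand-in for the real network

def same_weights(swa_model, model):
    # AveragedModel keeps its own copy of the network in `.module`,
    # so the parameters line up one-to-one with the source model.
    return all(
        torch.equal(p_swa, p)
        for p_swa, p in zip(swa_model.module.parameters(), model.parameters())
    )

swa_model = AveragedModel(model)
same_at_creation = same_weights(swa_model, model)

# Simulate a training step that changes the model's weights.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

same_after_step = same_weights(swa_model, model)
print(same_at_creation, same_after_step)
```

The second comparison shows that swa_model holds a separate copy, so it does not follow the model until update_parameters is called.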

How would I add the validation phase to this code?

@Mario_Parreno @ptrblck
Do we also need to update the BN stats before validation?
Given that we do need it for testing, I suspect we also need it for validation, correct?

The batchnorm stats will be updated by default during training (the model is by default in training mode or you can additionally call model.train()) while these running stats will be used during validation after calling model.eval(). I’m unsure how you would like to update these stats so could you explain the use case and question a bit more?

Sure! So in SWA two models are maintained: the model and the swa_model. The latter is the averaged model. What we “train” is model (we backprop through it, update its weights, etc.), and only every now and then do we update swa_model from model by averaging. That’s why the batchnorm stats in swa_model need separate updating. From the PyTorch website:

One important detail is the batch normalization. Batch normalization layers compute running statistics of activations during training. Note that the SWA averages of the weights are never used to make predictions during training. So the batch normalization layers do not have the activation statistics computed at the end of training. We can compute these statistics by doing a single forward pass on the train data with the SWA model.

And here’s the code snippet provided by PyTorch, where we see the BN stats of swa_model being updated at the end, right before testing:

for epoch in range(100):
    for input, target in loader:
        optimizer.zero_grad()
        loss_fn(model(input), target).backward()
        optimizer.step()
    if epoch > swa_start:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data
preds = swa_model(test_input)

The code snippet doesn’t show the validation step. So my question was whether we should also update_bn() the swa_model for validating it.

Thanks for clarifying the question!
Using the provided code snippet to update the batchnorm stats using the training dataset sounds right. I wouldn’t update the stats using the validation dataset, as I would consider it a data leak (similar to updating them without the SWA util).

Yes, I wouldn’t use the validation dataset. What I meant is to update the BN stats using the training set before running validation, similar to the code example above (loader is the training set).

OK, so looking at swa/train.py at commit 411b2fcad59bec60c6c9eb1eb19ab906540e5ea2 of timgaripov/swa on GitHub, I think I was right that we need to update the BN stats of the swa_model before we validate. The relevant lines are 158 and 159. @hktxt @ptrblck
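
A validation step along those lines could look roughly like this. It is only a sketch under assumptions: the toy data, the small BatchNorm model, and the validate_swa helper are all hypothetical, and plain lists of (input, target) batches stand in for real DataLoaders. The BN stats are recomputed from the training data only, and the validation data is used purely for evaluation:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn

torch.manual_seed(0)

# Hypothetical toy data: lists of (input, target) batches in place of DataLoaders.
train_loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
val_loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(2)]

# Hypothetical model with one BatchNorm layer.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 1))
loss_fn = nn.MSELoss()

swa_model = AveragedModel(model)
swa_model.update_parameters(model)  # pretend one SWA update has happened

def validate_swa(swa_model, train_loader, val_loader, loss_fn):
    # Recompute the BN running stats of the averaged weights
    # with one pass over the *training* data.
    update_bn(train_loader, swa_model)
    swa_model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            total += loss_fn(swa_model(inputs), targets).item() * len(inputs)
            count += len(inputs)
    return total / count

val_loss = validate_swa(swa_model, train_loader, val_loader, loss_fn)
print(val_loss)
```

Note that update_bn does a full forward pass over the training loader, so calling it before every validation run adds roughly one extra inference pass over the training set each time. Validating the SWA model every few epochs instead of every epoch is one way to keep that cost down.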