Stochastic Weight Averaging

I was hoping to start using the SWA implementation in torchcontrib (a rough sketch of my training loop is below). Two questions:

  • Does it matter if swa.bn_update is before or after opt.swap_swa_sgd?
  • If we want to save checkpoints, do we need to execute the update functions above?
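
For reference, here is roughly what I have in mind: a minimal sketch assuming the "auto" mode of the torchcontrib SWA wrapper. The model, the data, and the swa_start / swa_freq / swa_lr values are placeholders of mine, not recommendations.

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torchcontrib.optim import SWA

    # Toy model and data just to make the sketch self-contained.
    model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
    loss_fn = nn.CrossEntropyLoss()
    train_loader = DataLoader(
        TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))),
        batch_size=32, shuffle=True)

    base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    # Example SWA hyperparameters (auto mode averages every swa_freq steps after swa_start).
    opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)

    for epoch in range(20):
        for x, y in train_loader:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()

    # End of training. This is the order I am unsure about (question 1):
    opt.swap_swa_sgd()                  # put the averaged (SWA) weights into the model
    opt.bn_update(train_loader, model)  # recompute the BatchNorm running statistics

In the sketch I call swap_swa_sgd before bn_update, but I am not sure that order is right, hence the first question.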

Hello! Have you found the answer?

Hello, I am dealing with the same issues here. Did you find a solution? And a related question:
does calling swa.bn_update affect training if we continue training the model afterwards?

So, to answer my own question: according to the author's (pending) contribution to the PyTorch library here, the docstring of the newly added method swap_swa_sgd_update_bn reads:

    def swap_swa_sgd_update_bn(self, loader, model, device=None):
        r"""Swaps variables and swa buffers and updates BatchNorm buffers.
        This method is equivalent to subsequently calling 
        :meth:`opt.swap_swa_sgd()` and 
        :meth:`opt.bn_update(loader, model, device)`. It's meant to be called in
        the end of training to use the collected swa running averages and update
        the Batch Normalization running activation statistics in the model. It 
        can also be used to evaluate the SWA running averages during training; 
        to continue training `swap_swa_sgd` should be called again.
        """
        self.swap_swa_sgd()
        self.bn_update(loader, model, device)
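
Reading that docstring, the intended order seems to be opt.swap_swa_sgd() first and opt.bn_update(loader, model) second, so that the BatchNorm statistics are recomputed for the averaged weights. For checkpoints, the last sentence suggests you can swap the SWA weights in, update BatchNorm, save, and then call swap_swa_sgd again to continue training. A small helper along those lines, using the opt / model / train_loader names from my earlier snippet; save_swa_checkpoint, keep_training, and the file paths are names I made up, not part of torchcontrib:

    import torch

    def save_swa_checkpoint(opt, model, train_loader, path, keep_training=False):
        """Swap the SWA weights into the model, refresh the BatchNorm running
        statistics for them, save a checkpoint, and optionally swap the SGD
        weights back so training can continue."""
        opt.swap_swa_sgd()                   # model now holds the averaged weights
        opt.bn_update(train_loader, model)   # recompute BN stats for those weights
        torch.save(model.state_dict(), path)
        if keep_training:
            opt.swap_swa_sgd()               # restore the SGD weights, per the docstring

    # At the end of training:
    save_swa_checkpoint(opt, model, train_loader, "model_swa_final.pt")

    # During training, to checkpoint/evaluate the SWA weights and keep going:
    save_swa_checkpoint(opt, model, train_loader, "model_swa_epoch.pt", keep_training=True)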