I was hoping to start using the SWA implementation in torchcontrib. Two questions:
- Does it matter if swa.bn_update is called before or after opt.swap_swa_sgd?
- If we want to save checkpoints, do we need to execute the update functions above?
Hello! Have you found the answer?
Hello, I am dealing with the same issues here. Did you find the solution? And a related question: does calling swa.bn_update affect the training when we retrain the model?
So, to answer my own question: according to the author's (pending) contribution to the PyTorch library here, the documentation of the newly added method swap_swa_sgd_update_bn says:
def swap_swa_sgd_update_bn(self, loader, model, device=None):
    r"""Swaps variables and swa buffers and updates BatchNorm buffers.

    This method is equivalent to subsequently calling
    :meth:`opt.swap_swa_sgd()` and
    :meth:`opt.bn_update(loader, model, device)`. It's meant to be called
    in the end of training to use the collected swa running averages and
    update the Batch Normalization running activation statistics in the
    model. It can also be used to evaluate the SWA running averages during
    training; to continue training `swap_swa_sgd` should be called again.
    """
    self.swap_swa_sgd()
    self.bn_update(loader, model, device)