I have written a class/method for early stopping, but I would like some clarification on how to load the weights after early stopping. Here is what I have:
model = Net()
optim = torch.optim.Adam(model.parameters())
lrs = <some_lrs>
for epoch in range(epochs):
    # train for one epoch, then compute the validation loss; if it has decreased, save the weights:
    torch.save(model.state_dict(), "es_ckpt.pt")
    # check for early stopping based on the number of epochs with no improvement; if triggered, stop training:
    break
# after the loop, load the best weights into the model:
model.load_state_dict(torch.load("es_ckpt.pt"))
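For reference, here is a minimal self-contained version of the loop above that runs end to end. It is a sketch under assumptions: `Net` is a hypothetical one-layer stand-in, the data are random tensors, the scheduler is `StepLR`, and `patience` is a made-up counter for epochs without improvement. It saves the weights whenever the validation loss improves and reloads the best ones into the same instance after stopping.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)


class Net(nn.Module):
    """Hypothetical stand-in for the real Net: a tiny linear regressor."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x)


# Synthetic train/validation data (illustrative only)
x_train, y_train = torch.randn(64, 4), torch.randn(64, 1)
x_val, y_val = torch.randn(32, 4), torch.randn(32, 1)

model = Net()
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
lrs = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.5)  # assumed scheduler
loss_fn = nn.MSELoss()

ckpt_path = os.path.join(tempfile.mkdtemp(), "es_ckpt.pt")
patience = 5  # hypothetical: epochs without improvement before stopping
best_val = float("inf")
no_change = 0

for epoch in range(200):
    # One full-batch training step per "epoch"
    model.train()
    optim.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optim.step()
    lrs.step()

    # Validation loss without tracking gradients
    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()

    if val_loss < best_val:
        best_val = val_loss
        no_change = 0
        torch.save(model.state_dict(), ckpt_path)  # overwrite with the best weights so far
    else:
        no_change += 1
        if no_change >= patience:
            break  # early stop

# Load the best weights into the existing instance and re-evaluate
model.load_state_dict(torch.load(ckpt_path))
model.eval()
with torch.no_grad():
    restored_val = loss_fn(model(x_val), y_val).item()
```

Because the checkpoint is overwritten only on improvement, the file always holds the best weights seen so far, and the first epoch always writes it since `best_val` starts at infinity.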
Is this it? Do I need to create a fresh instance of the model, or do I simply load the weights into the existing instance as above?
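A small check of the two options, assuming the same kind of hypothetical `Net`: loading into the existing instance and loading into a fresh instance both yield exactly the saved parameters.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical stand-in for the real Net."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 1)

    def forward(self, x):
        return self.fc(x)


ckpt_path = os.path.join(tempfile.mkdtemp(), "es_ckpt.pt")

model = Net()
torch.save(model.state_dict(), ckpt_path)  # pretend these are the best weights
best = {k: v.clone() for k, v in model.state_dict().items()}

# Simulate further training: the weights drift away from the checkpoint
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Option A: load into the existing instance (overwrites its parameters)
model.load_state_dict(torch.load(ckpt_path))

# Option B: load into a fresh instance of the same class
fresh = Net()
fresh.load_state_dict(torch.load(ckpt_path))

same_existing = all(torch.equal(model.state_dict()[k], best[k]) for k in best)
same_fresh = all(torch.equal(fresh.state_dict()[k], best[k]) for k in best)
```

`load_state_dict` copies the saved values into the existing parameter tensors, so an optimizer built from `model.parameters()` keeps pointing at the same tensors; with a fresh instance you would need a new optimizer if training continued.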
I also presume I do not need to store the optimiser or LR scheduler parameters in es_ckpt.pt, since I already have them?
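If the goal were to resume training from the checkpoint rather than just use the best weights, one common pattern is to save the optimizer and scheduler states alongside the model in a single dictionary. A sketch, assuming a plain `nn.Linear` in place of `Net` and a `StepLR` scheduler; the dictionary keys are arbitrary names chosen here.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # hypothetical stand-in for Net()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
lrs = torch.optim.lr_scheduler.StepLR(optim, step_size=10, gamma=0.5)  # assumed scheduler

# One training step so Adam has non-empty state (exp_avg, exp_avg_sq)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optim.step()
lrs.step()

# Save model, optimizer and scheduler state together
ckpt_path = os.path.join(tempfile.mkdtemp(), "es_full_ckpt.pt")
torch.save(
    {
        "epoch": 1,
        "model": model.state_dict(),
        "optim": optim.state_dict(),
        "lrs": lrs.state_dict(),
    },
    ckpt_path,
)

# Later: rebuild the objects, then restore all three states
model2 = nn.Linear(4, 1)
optim2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
lrs2 = torch.optim.lr_scheduler.StepLR(optim2, step_size=10, gamma=0.5)

ckpt = torch.load(ckpt_path)
model2.load_state_dict(ckpt["model"])
optim2.load_state_dict(ckpt["optim"])
lrs2.load_state_dict(ckpt["lrs"])
```

For inference or evaluation alone, the model `state_dict` is enough. If only the model weights are restored after early stopping, Adam's moment estimates and the scheduler's step count still reflect the last epoch rather than the best one; that only matters if training continues.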
I just want to make sure there are no gotchas here, since it seems too simple.