How to resume training a model from an interrupted epoch?

I had saved a model up to epoch 7:

Summary

Epoch: 1 Training Loss: 4.736156 Validation Loss: 5.014328
Validation has decreased
Saving Model…
Epoch: 2 Training Loss: 4.607684 Validation Loss: 5.020730
Epoch: 3 Training Loss: 4.722754 Validation Loss: 4.918604
Validation has decreased
Saving Model…
Epoch: 4 Training Loss: 4.724532 Validation Loss: 4.631703
Validation has decreased
Saving Model…
Epoch: 5 Training Loss: 4.621485 Validation Loss: 4.857414
Epoch: 6 Training Loss: 4.601058 Validation Loss: 5.013983
Epoch: 7 Training Loss: 4.561738 Validation Loss: 4.354510
Validation has decreased
Saving Model…
Epoch: 8 Training Loss: 4.606921 Validation Loss: 4.746311
Epoch: 9 Training Loss: 4.655972 Validation Loss: 4.657127
Epoch: 10 Training Loss: 4.661118 Validation Loss: 4.698534
Epoch: 11 Training Loss: 4.605792 Validation Loss: 4.700062
Epoch: 12 Training Loss: 4.711023 Validation Loss: 4.680906
Epoch: 13 Training Loss: 4.442619 Validation Loss: 4.678782
Epoch: 14 Training Loss: 4.598712 Validation Loss: 4.609448
Epoch: 15 Training Loss: 4.525731 Validation Loss: 4.983483
Epoch: 16 Training Loss: 4.530140 Validation Loss: 4.464461
Epoch: 17 Training Loss: 4.501666 Validation Loss: 4.358377
Epoch: 18 Training Loss: 4.614322 Validation Loss: 4.664511
Epoch: 19 Training Loss: 4.537676 Validation Loss: 4.367801
Epoch: 21 Training Loss: 4.442828 Validation Loss: 4.889397

Unfortunately, my internet connection dropped.
Now I want to resume the interrupted training. I had saved the model using

state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
         'optimizer': optimizer.state_dict(), 'loss': loss}
torch.save(state, 'saved.pt')

How do I continue training from saved.pt, after epoch 7?
Thank you for your answer.

Maybe my answer here is too basic but this is how I generally do this:

Load the model:
from model_location import model

Instantiate the model:
m = model()

Load the state dictionary required:
m.load_state_dict(torch.load('state_dictionary_saved_at_epoch_7.pt'))

(Since you saved everything in saved.pt, you will have to load that dictionary first and pass its 'state_dict' entry to load_state_dict, instead of loading 'state_dictionary_saved_at_epoch_7.pt' directly. Perhaps this is the subtlety of your question that most of this answer does not address properly. You will also have to restore the optimizer state dictionary and the loss value, but I don’t know how to do that.)

Set the training state of the model m to true (perhaps this is done by default):
m.train(True)

…and then continue training using the model now called m in this example.
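On the "perhaps this is done by default" point: a freshly created nn.Module is already in training mode, which can be checked through its training attribute. A small sketch, with nn.Linear standing in as a placeholder for your own model class:

```python
import torch.nn as nn

m = nn.Linear(4, 2)          # placeholder for your own model class
default_mode = m.training    # True: new modules start in training mode

m.eval()                     # e.g. before a validation pass
eval_mode = m.training       # False

m.train(True)                # same as m.train()
train_mode = m.training      # True again
```

So the call only matters if the model was switched to eval mode at some point, for example during validation.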

Sorry if that’s too basic an answer!


Thank you for your explanation :) but I don’t understand how to replace
‘state_dictionary_saved_at_epoch_7.pt’ with what is in ‘saved.pt’, and everything after that. Could you also give a code example?

Hi!

The answer by @spacemeerkat seems to be the correct way to go (see PyTorch: Saving and Loading Models)

import torch
from your_model import model  # Replace your_model with the module that defines your model class
m = model()  # Create a new model
state = torch.load('saved.pt')  # Load the whole dictionary, as that's what you have saved
m.load_state_dict(state['state_dict'])  # Restore the weights stored under the 'state_dict' key
# There is no need for m.train() if you want to continue with training
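To also restore the optimizer and pick up the epoch counter, the other entries of the same dictionary can be used. This is only a minimal sketch, assuming the checkpoint keys from the question ('epoch', 'state_dict', 'optimizer', 'loss'). The nn.Linear model, the SGD optimizer, the random data and the helper name load_checkpoint are placeholders of mine, not the original poster's code:

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer in place; return (next epoch, saved loss)."""
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    return state['epoch'], state['loss']


# Write a checkpoint the same way as in the question (placeholder model and values).
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epoch, loss = 6, 4.354510
path = os.path.join(tempfile.mkdtemp(), 'saved.pt')
torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(), 'loss': loss}, path)

# Later, in a fresh process: rebuild the same architecture, then restore.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
start_epoch, last_loss = load_checkpoint(path, new_model, new_optimizer)
weights_restored = torch.equal(new_model.weight, model.weight)

n_epochs = 10
criterion = nn.MSELoss()
for epoch in range(start_epoch, n_epochs):  # continues at the stored epoch
    new_model.train()
    inputs, targets = torch.randn(8, 4), torch.randn(8, 2)  # stand-in for a real DataLoader
    new_optimizer.zero_grad()
    loss = criterion(new_model(inputs), targets)
    loss.backward()
    new_optimizer.step()
```

Because the question stores epoch + 1, the loaded value is already the next epoch to run, so it can go straight into range().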

P.S.
You are always saving the state to the same file, so every save overwrites the previous one and only the most recently saved model state is kept.
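One way around the overwriting is to put the epoch number into the file name and keep a separate file for the best validation loss. A rough sketch: the file name pattern, the temporary directory and the helper checkpoint_path are my own choices, not something PyTorch requires, and the losses are the first three validation losses from the log above:

```python
import os
import tempfile

import torch
import torch.nn as nn


def checkpoint_path(directory, epoch):
    # Hypothetical naming scheme: one file per epoch, zero-padded so files sort correctly.
    return os.path.join(directory, f'checkpoint_epoch_{epoch:03d}.pt')


directory = tempfile.mkdtemp()
model = nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

best_valid_loss = float('inf')
valid_losses = [5.014328, 5.020730, 4.918604]  # epochs 1-3 from the log
for epoch, valid_loss in enumerate(valid_losses, start=1):
    state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'loss': valid_loss}
    torch.save(state, checkpoint_path(directory, epoch))  # one file per epoch
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(state, os.path.join(directory, 'best.pt'))  # rewritten only on improvement

saved_files = sorted(os.listdir(directory))
```

With per-epoch files you can resume from the last finished epoch even when it was not the best one, at the cost of more disk space.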


Does it continue from the previously trained weights, or does it start from scratch?

Yes, loading a saved state dictionary loads pretrained weights so it won’t be starting from new.

Just note that with this method, be careful not to change the model architecture and then try to load the state dictionary to a new model as obviously they will no longer share the same shapes…something that has stung me in the past.
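To illustrate the pitfall: load_state_dict raises a RuntimeError when a parameter in the saved dictionary has a different shape from the one in the model. A small sketch, with two made-up nn.Linear layers standing in for the old and the changed architecture:

```python
import torch
import torch.nn as nn

old_model = nn.Linear(4, 2)      # architecture used when the checkpoint was written
changed_model = nn.Linear(4, 3)  # same layer type, but a different output size

saved_state = old_model.state_dict()

try:
    changed_model.load_state_dict(saved_state)
    load_failed = False
except RuntimeError as err:
    # PyTorch reports a size mismatch for the offending parameters.
    load_failed = True
    print(err)

# Comparing shapes up front shows which entries differ.
mismatched = [name for name, tensor in changed_model.state_dict().items()
              if name in saved_state and saved_state[name].shape != tensor.shape]
```

Here mismatched lists the names of the parameters whose shapes no longer agree, which is usually enough to spot what changed in the model definition.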


Thank you for helping. :)
