Saving and Loading Model Checkpoint - Epoch & Validation

I have a large dataset to train on and am short on cloud RAM and disk space. I think one approach to training on the whole dataset is to create a checkpoint that saves the best model parameters based on validation, and probably the last epoch as well. I would be glad for guidance on implementing this, i.e. ensuring that training continues from the last epoch with the best saved model parameters from the previous training session.

Hi @moreshud

You could do something like:

if accuracy_val > max_accuracy_val:
    # New best validation accuracy: save everything needed to resume later
    checkpoint = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }
    torch.save(checkpoint, 'path/to/folder/filename.pth')
    max_accuracy_val = accuracy_val

That way you save the state of the model whenever you have reached a new maximum accuracy on the validation set.
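Since the question also asks about the last epoch, here is a minimal sketch of keeping two files: one for the best model so far and one that is overwritten every epoch. The helper `save_checkpoint`, the file names `best.pth` and `last.pth`, the extra `max_accuracy_val` key, the toy `nn.Linear` model, and the hard-coded accuracies are all assumptions made for illustration, not part of the snippet above.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, epoch, model, optimizer, max_accuracy_val):
    """Write everything needed to resume training to `path` (hypothetical helper)."""
    checkpoint = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        # Assumption: also store the best accuracy so a resumed run can keep comparing.
        'max_accuracy_val': max_accuracy_val,
    }
    torch.save(checkpoint, path)


# Toy stand-ins so the sketch runs on its own.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
checkpoint_dir = tempfile.mkdtemp()
best_path = os.path.join(checkpoint_dir, 'best.pth')
last_path = os.path.join(checkpoint_dir, 'last.pth')

max_accuracy_val = 0.0
fake_accuracies = [0.5, 0.7, 0.6]  # placeholder for real validation results

for epoch, accuracy_val in enumerate(fake_accuracies):
    # ... train for one epoch and evaluate on the validation set here ...
    if accuracy_val > max_accuracy_val:
        max_accuracy_val = accuracy_val
        save_checkpoint(best_path, epoch, model, optimizer, max_accuracy_val)
    # Overwrite the "last" checkpoint every epoch so an interruption loses at most one epoch.
    save_checkpoint(last_path, epoch, model, optimizer, max_accuracy_val)
```

Overwriting the same two files keeps disk usage constant, which may matter when space is tight.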

When you want to continue training you can do:

import torch

loaded_checkpoint = torch.load('path/to/folder/filename.pth')
loaded_epoch = loaded_checkpoint['epoch']
loaded_model = MyModel()  # instantiate your model class (MyModel is a placeholder name)
loaded_model.load_state_dict(loaded_checkpoint['model_state'])
# lr and momentum here are placeholders; load_state_dict restores the saved values
loaded_optimizer = torch.optim.SGD(loaded_model.parameters(), lr=0, momentum=0)  # or whatever optimizer you use
loaded_optimizer.load_state_dict(loaded_checkpoint['optimizer_state'])

You can then continue training with the loaded model and loaded optimizer, starting from epoch loaded_epoch + 1.
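To make the resume step concrete, here is a rough sketch of a training script that picks up where a previous session stopped. It assumes a toy `nn.Linear` model, a temporary checkpoint path, a total of 8 epochs, and an optional `max_accuracy_val` key in the checkpoint; none of these come from the code above, and the first few lines only fake an earlier session so the example runs on its own.

```python
import os
import tempfile

import torch
from torch import nn

# Fake "previous session": write a checkpoint as if epoch 4 had just finished.
checkpoint_path = os.path.join(tempfile.mkdtemp(), 'last.pth')
old_model = nn.Linear(4, 2)
old_optimizer = torch.optim.SGD(old_model.parameters(), lr=0.1, momentum=0.9)
torch.save({
    'epoch': 4,
    'model_state': old_model.state_dict(),
    'optimizer_state': old_optimizer.state_dict(),
    'max_accuracy_val': 0.8,  # assumed extra key
}, checkpoint_path)

# New session: build fresh objects, then restore their state if a checkpoint exists.
num_epochs = 8
start_epoch = 0
max_accuracy_val = 0.0
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1  # the saved epoch is already done
    max_accuracy_val = checkpoint.get('max_accuracy_val', 0.0)

epochs_run = []
for epoch in range(start_epoch, num_epochs):
    model.train()  # make sure layers such as dropout are in training mode
    # ... one epoch of training and validation goes here ...
    epochs_run.append(epoch)
```

Restoring the best accuracy along with the weights means the resumed run does not overwrite a good checkpoint with a worse one on its first epoch.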

DISCLAIMER: I am relatively new to PyTorch myself. This approach works for me but I cannot guarantee that there are no better options :wink:

Anyways, I hope it helps!

All the best
snowe


Thanks for the contribution.
