I have a large dataset to train on and I am short of cloud RAM and disk space. I think one approach to training on the whole dataset is to create a checkpoint that saves the best model parameters based on validation (and likely the last epoch). I would be glad for guidance on implementing this, i.e. ensuring training continues from the last epoch with the best saved model parameters from the previous training session.
Hi @moreshud
You could do something like:
if accuracy_val > max_accuracy_val:
    checkpoint = {
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }
    torch.save(checkpoint, 'path/to/folder/filename.pth')
    max_accuracy_val = accuracy_val
That way you save the state of the model (and the optimizer) whenever you reach a new maximum accuracy on the validation set.
When you want to continue training you can do:
loaded_checkpoint = torch.load('path/to/folder/filename.pth')
loaded_epoch = loaded_checkpoint['epoch']
loaded_model = MyModel()  # instantiate a fresh instance of your model class
loaded_model.load_state_dict(loaded_checkpoint['model_state'])
loaded_optimizer = torch.optim.SGD(loaded_model.parameters(), lr=0, momentum=0)  # or whatever optimizer you use; load_state_dict restores the saved hyperparameters
loaded_optimizer.load_state_dict(loaded_checkpoint['optimizer_state'])
You can then continue training with the loaded model and loaded optimizer.
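To show how the two snippets above fit together, here is a minimal end-to-end sketch of saving a checkpoint and then resuming from it, restarting the loop one epoch after the saved one so no epoch is repeated. The tiny linear model, file name, and hyperparameters are placeholders, not your actual setup:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for your real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# --- pretend we just finished epoch 5 with a new best validation accuracy ---
checkpoint = {
    'epoch': 5,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
}
torch.save(checkpoint, 'best_checkpoint.pth')

# --- later, in a fresh session: rebuild the objects, then restore their state ---
loaded_checkpoint = torch.load('best_checkpoint.pth')
resumed_model = nn.Linear(4, 2)  # must match the saved architecture
resumed_model.load_state_dict(loaded_checkpoint['model_state'])
resumed_optimizer = torch.optim.SGD(resumed_model.parameters(), lr=0.01, momentum=0.9)
resumed_optimizer.load_state_dict(loaded_checkpoint['optimizer_state'])

# Resume one epoch after the saved one so the saved epoch is not repeated.
start_epoch = loaded_checkpoint['epoch'] + 1
num_epochs = 10
for epoch in range(start_epoch, num_epochs):
    pass  # your usual train / validate / save-best-checkpoint steps go here

print(start_epoch)  # 6
```

If the checkpoint was saved on a GPU machine and you reload on CPU, pass `map_location='cpu'` to `torch.load`.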
DISCLAIMER: I am relatively new to PyTorch myself. This approach works for me but I cannot guarantee that there are no better options
Anyways, I hope it helps!
All the best
snowe
Thanks for the contribution.