You can create a dictionary with everything you need and save it using torch.save(). For example:
```python
import torch

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_sched': lr_sched.state_dict(),  # save the scheduler's state, not the scheduler object
}
torch.save(checkpoint, 'checkpoint.pth')
```
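Then you can load the checkpoint with checkpoint = torch.load('checkpoint.pth'). This only returns the dictionary, so you still need to restore the model and optimizer by passing their entries to load_state_dict(), e.g. model.load_state_dict(checkpoint['model']).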
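A minimal save-and-resume sketch, assuming a toy nn.Linear model, SGD and a StepLR scheduler as stand-ins for your own objects. The idea is that you first rebuild the model, optimizer and scheduler with the same constructor arguments, then restore each one's state from the checkpoint. map_location='cpu' is an assumption so the example runs without a GPU.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup standing in for your real model/optimizer/scheduler.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

# One training step so the optimizer (momentum buffers) and scheduler have state.
x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_sched.step()
epoch = 0

checkpoint = {
    'epoch': epoch,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'lr_sched': lr_sched.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')

# Later (e.g. in a new process): recreate the objects, then load their state.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
new_sched = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1, gamma=0.5)

checkpoint = torch.load('checkpoint.pth', map_location='cpu')
new_model.load_state_dict(checkpoint['model'])
new_optimizer.load_state_dict(checkpoint['optimizer'])
new_sched.load_state_dict(checkpoint['lr_sched'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch

new_model.train()  # put the model back in training mode before resuming
```

Because only state dicts (tensors and plain Python values) are saved, this also works with torch.load's weights_only=True mode, which newer PyTorch versions use by default; pickling the scheduler object itself would not.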
More info here: Loading a saved model for continue training