What @kevinzakka said.
After saving a checkpoint with something like
```python
state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
         'optimizer': optimizer.state_dict(), 'losslogger': losslogger}
torch.save(state, filename)
```
(`losslogger` is just something I use to keep track of the loss history; you could replace it with a TensorBoard writer, or drop it entirely.)
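For context, here's a hedged sketch of where that save might sit in a training loop; `train_one_epoch` and `num_epochs` are hypothetical stand-ins for your own code:

```python
# Sketch only: train_one_epoch and num_epochs are hypothetical stand-ins.
# start_epoch is 0 for a fresh run, or comes from load_checkpoint (below)
# when resuming.
for epoch in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)
    losslogger.append(loss)  # assuming losslogger is a plain list here
    state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'losslogger': losslogger}
    torch.save(state, 'checkpoint.pth.tar')
```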
You can then reload the model weights, the optimizer state, and anything else you stashed by calling something like:
```python
import os

import torch


def load_checkpoint(model, optimizer, losslogger, filename='checkpoint.pth.tar'):
    # Note: the input model & optimizer should be pre-defined;
    # this routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        losslogger = checkpoint['losslogger']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, losslogger
```
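And a quick usage sketch for resuming; `MyModel` is a hypothetical stand-in for your own network, and the device-moving loop at the end is a common community workaround (not part of the snippet above) for optimizer state tensors ending up on the wrong device when you train on GPU:

```python
# Usage sketch: MyModel is a hypothetical stand-in for your own network.
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
losslogger = []  # e.g. a plain list of per-epoch losses

model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer,
                                                            losslogger)

# If you train on GPU, move the model over; depending on your PyTorch
# version you may also need to move the restored optimizer state by hand.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
for state in optimizer.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)
```

From there you just continue training with `for epoch in range(start_epoch, num_epochs):` as in the loop sketched earlier.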