Hi, I am trying to save checkpoints during training using the following code:
info_dict = {
    'epoch' : epoch,
    'numpy_random_state' : np.random.get_state(),
    'torch_random_state' : torch.random.get_rng_state(),
    'net_state' : net.module.state_dict(),
    'optimizer_state' : optimizer.state_dict(),
    'train_hist' : history_plot['train'],
    'val_hist' : history_plot['val'],
}
torch.save(info_dict, chkpnts)
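For reference, the saving step above can be wrapped in a self-contained helper. This is a minimal sketch: `save_checkpoint`, the toy `nn.Linear` model, the SGD optimizer and the temporary file path are hypothetical stand-ins for the real `net`, `optimizer`, `history_plot` and `chkpnts`. It assumes `net` may or may not be wrapped in `nn.DataParallel`; `getattr(net, "module", net)` picks the inner module when there is one, so the saved keys never carry a `module.` prefix.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


def save_checkpoint(path, epoch, net, optimizer, history):
    """Hypothetical helper: bundle everything needed to resume training."""
    # If net is wrapped (e.g. nn.DataParallel), save the inner module so the
    # state_dict keys carry no "module." prefix; otherwise save net itself.
    model = getattr(net, "module", net)
    info_dict = {
        "epoch": epoch,
        "numpy_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
        "net_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "train_hist": history["train"],
        "val_hist": history["val"],
    }
    torch.save(info_dict, path)
    return info_dict


# Toy usage: a tiny wrapped model and optimizer stand in for the real ones.
net = nn.DataParallel(nn.Linear(4, 2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
history_plot = {"train": [0.9, 0.7], "val": [1.0, 0.8]}

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
ckpt = save_checkpoint(ckpt_path, epoch=2, net=net, optimizer=optimizer,
                       history=history_plot)
```

Saving the unwrapped module's `state_dict` keeps the checkpoint loadable into a plain, unwrapped model later.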
And I load them using:
chkpnt = torch.load(chkpnts)
net.load_state_dict(chkpnt['net_state'])
optimizer.load_state_dict(chkpnt['optimizer_state'])
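A matching load step, again as a minimal sketch with a hypothetical `load_checkpoint` helper and a toy model. Beyond the two `load_state_dict` calls, it also restores the NumPy and torch RNG states that were saved, loads into the unwrapped module for the same `module.`-prefix reason, and passes `weights_only=False` to `torch.load` (the argument exists since PyTorch 1.13, and its default became `True` in 2.6) because the checkpoint contains a NumPy RNG state, which the restricted unpickler may reject. Only load files you trust this way.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


def load_checkpoint(path, net, optimizer):
    """Hypothetical helper: restore model, optimizer and RNG states in place."""
    # weights_only=False: the file holds a NumPy RNG state, which the
    # restricted weights_only unpickler may refuse. Only use on trusted files.
    chkpnt = torch.load(path, map_location="cpu", weights_only=False)
    model = getattr(net, "module", net)  # load into the unwrapped model
    model.load_state_dict(chkpnt["net_state"])
    optimizer.load_state_dict(chkpnt["optimizer_state"])
    # Restore RNG states too, so random draws continue exactly as after saving.
    np.random.set_state(chkpnt["numpy_random_state"])
    torch.random.set_rng_state(chkpnt["torch_random_state"])
    return chkpnt["epoch"] + 1  # epoch to resume from


# Toy round trip: save, perturb the weights and RNGs, then restore.
net = nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save({
    "epoch": 4,
    "numpy_random_state": np.random.get_state(),
    "torch_random_state": torch.random.get_rng_state(),
    "net_state": net.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}, path)
expected_np, expected_torch = np.random.rand(), torch.rand(1)  # draws right after saving

saved_weight = net.weight.detach().clone()
with torch.no_grad():
    net.weight.zero_()
start_epoch = load_checkpoint(path, net, optimizer)
```

After `load_checkpoint`, the next `np.random.rand()` and `torch.rand(1)` calls reproduce the draws made right after saving, which is the point of storing the RNG states.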
The code works fine, but it emits the UserWarning mentioned above. I am new to PyTorch, so any help is appreciated.