Error loading checkpoints using the sample code online

Hi! I’ve been trying to use the code provided here to save and load model checkpoints:

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
}, PATH)
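For reference, here is a minimal, self-contained sketch of the full save/load round trip. It assumes a small nn.Linear as a stand-in for the real model (JARVISTransformer_channelInt), a plain SGD optimizer, and placeholder values for epoch and train_loss. The dictionary is written with torch.save and read back with torch.load, so the file on disk is a real PyTorch checkpoint and not text.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model and optimizer (assumption: the real model is JARVISTransformer_channelInt)
model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epoch, train_loss = 5, 0.123  # placeholder values for illustration only

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

# Save: torch.save serializes the dict (as a zip archive by default since PyTorch 1.6)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': train_loss,
}, path)

# Load: rebuild objects with the same architecture, then restore their state
new_model = nn.Linear(10, 10)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
checkpoint = torch.load(path, map_location="cpu")
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
new_model.eval()
```

The model and optimizer must be constructed before calling load_state_dict, since the checkpoint only stores their parameters and state, not the objects themselves.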

But I get the following error when loading:

checkpoint = torch.load(DATA_ROOT_DIR + "/" + file_name)


UnpicklingError                           Traceback (most recent call last)
Input In [10], in <cell line: 3>()
      1 model = JARVISTransformer_channelInt(d_model = 4 + 12, nhead = 2, dim_feedforward = 64, num_layers = 2,dropout = 0.5,extra_feats = 12)
----> 3 checkpoint = torch.load(DATA_ROOT_DIR + "/" + file_name)
      4 model.load_state_dict(checkpoint['model_state_dict'])
      6 model.eval()

File ~/.conda/envs/cgrds/lib/python3.10/site-packages/torch/serialization.py:795, in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
    793     except RuntimeError as e:
    794         raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
--> 795 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File ~/.conda/envs/cgrds/lib/python3.10/site-packages/torch/serialization.py:1002, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    996 if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
    997     raise RuntimeError(
    998         "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
    999         f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
   1000         "functionality.")
-> 1002 magic_number = pickle_module.load(f, **pickle_load_args)
   1003 if magic_number != MAGIC_NUMBER:
   1004     raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: invalid load key, '{'.

I haven’t found any similar issue online. Have any of you run into something like this? Is it possible to load the model at all, and where is the error in saving? Thanks!

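I’m not sure exactly why your code raises this error, but pickle will fail to load the file if it was stored in another format. The traceback shows torch.load falling back to _legacy_load, which means the file is not the zip archive that torch.save writes by default, and the first byte pickle reads is '{', which it does not recognize.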
For example, this small snippet raises the same error:

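import torch
import torch.nn as nn

model = nn.Linear(10, 10)
sd = model.state_dict()

# Deliberately write a plain-text string instead of calling torch.save
with open("test.pt", "w") as f:
    f.write("{'model_state_dict': sd}")

torch.load("test.pt")
# UnpicklingError: invalid load key, '{'.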

As you can see, I’ve stored the state_dict as a plain string.
Could you check whether the file you are trying to load was actually created via torch.save?
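One way to check this is to look at the first bytes of the file. This sketch relies on two properties of the formats: the default torch.save format (PyTorch 1.6 and later) is a zip archive, while the legacy format and plain pickle files start with a pickle stream whose first byte is the PROTO opcode 0x80 (pickle protocol 2 and later). The helper name describe_checkpoint is hypothetical, and it uses only the standard library, so it works without loading the file through torch.

```python
import zipfile


def describe_checkpoint(path):
    """Guess how a checkpoint file was written from its leading bytes (hypothetical helper)."""
    # Default torch.save output (PyTorch >= 1.6) is a zip archive
    if zipfile.is_zipfile(path):
        return "zip archive (default torch.save format since PyTorch 1.6)"
    with open(path, "rb") as f:
        head = f.read(16)
    # Pickle protocol 2+ streams begin with the PROTO opcode b"\x80"
    if head.startswith(b"\x80"):
        return "pickle stream (legacy torch.save format or plain pickle.dump)"
    # Anything else: try to show it as text, e.g. a dict written with f.write(...)
    try:
        text = head.decode("utf-8")
    except UnicodeDecodeError:
        return f"unknown binary data, first bytes: {head!r}"
    return f"text, not a torch.save file; starts with: {text!r}"
```

For example, describe_checkpoint(DATA_ROOT_DIR + "/" + file_name) would report whether the failing file is text rather than a real checkpoint. If it is text, the save side is most likely writing the dictionary with something other than torch.save (for instance via f.write or json.dump).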