RuntimeError when loading PyTorch model from .pkl

Hi all,

I am using a cloud service that requires my model to be serialized with pickle. I’ve trained a U-Net model and saved the full model both as a .pth file and as a .pkl file, but when I try to load the model from the .pkl file I get the following RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-81-14dbf4cf4d26> in <module>()
----> 1 unet = torch.load('unet_bus.pkl')
      2 unet.eval()

~/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py in load(f, map_location, pickle_module)
    356         f = open(f, 'rb')
    357     try:
--> 358         return _load(f, map_location, pickle_module)
    359     finally:
    360         if new_fd:

~/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/serialization.py in _load(f, map_location, pickle_module)
    532     magic_number = pickle_module.load(f)
    533     if magic_number != MAGIC_NUMBER:
--> 534         raise RuntimeError("Invalid magic number; corrupt file?")
    535     protocol_version = pickle_module.load(f)
    536     if protocol_version != PROTOCOL_VERSION:

RuntimeError: Invalid magic number; corrupt file?

This is how I’m saving and loading the pickle model:

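import pickle
import torch

# Save the full model object with the standard-library pickle module
pkl_filename = "unet_bus.pkl"
with open(pkl_filename, 'wb') as file:
    pickle.dump(model, file)

# Load it back with torch.load (this is the call that raises the RuntimeError)
unet = torch.load('unet_bus.pkl')
unet.eval()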

Any help appreciated!


Could you try to use torch.save instead of pickle.dump?
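A minimal sketch of that suggestion, assuming a recent PyTorch: the weights_only argument of torch.load exists since 1.13 and defaults to True since 2.6, so loading a whole nn.Module needs weights_only=False there. The small nn.Sequential is a hypothetical stand-in for the U-Net, and the file name is only illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained U-Net
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU())

path = os.path.join(tempfile.mkdtemp(), "unet_bus.pt")

# torch.save writes the file format that torch.load knows how to read
torch.save(model, path)

# Unpickling a whole nn.Module requires weights_only=False on versions where it
# defaults to True; only do this for files you trust.
unet = torch.load(path, weights_only=False)
unet.eval()
```

Saving model.state_dict() instead of the whole module is usually the more portable choice, because loading a full module also requires the model's class to be importable under the same name.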


Nice, thanks @ptrblck, that worked! Do you know why saving it the former way is an issue?

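It seems PyTorch uses a magic number to identify the file format or protocol, as seen in torch/serialization.py (the check in _load in your traceback). If you save the model with another library, e.g. pickle.dump, and then try to load it with torch.load, the first value it reads is not that magic number, so you’ll encounter this error.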
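To illustrate the mechanism, here is a simplified, standard-library-only sketch of a legacy-style format that writes a magic number first and checks it on load. toy_save and toy_load are hypothetical helpers, not PyTorch's API; the real torch.save also stores tensor storages separately and, since PyTorch 1.6, writes a zip archive by default.

```python
import io
import pickle

# Same constant as torch.serialization.MAGIC_NUMBER (legacy, non-zip format)
MAGIC_NUMBER = 0x1950A86A20F9469CFC6C


def toy_save(obj, f):
    """Simplified legacy-style save: write the magic number, then the object."""
    pickle.dump(MAGIC_NUMBER, f)
    pickle.dump(obj, f)


def toy_load(f):
    """Simplified legacy-style load: the first pickled value must be the magic number."""
    magic_number = pickle.load(f)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    return pickle.load(f)


model = {"weights": [0.1, 0.2]}  # hypothetical stand-in for a model

# Round trip through the toy format works
buf = io.BytesIO()
toy_save(model, buf)
buf.seek(0)
restored = toy_load(buf)

# A plain pickle.dump file starts with the model itself, not the magic number
plain = io.BytesIO()
pickle.dump(model, plain)
plain.seek(0)
error_message = None
try:
    toy_load(plain)
except RuntimeError as err:
    error_message = str(err)

# Reading the plain pickle back with its matching reader works
plain.seek(0)
via_pickle = pickle.load(plain)
```

So if the cloud service insists on pickle, pickle.load is the consistent counterpart to pickle.dump. Keep in mind that pickle.load can execute arbitrary code, so only load files you trust.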


Ahh, cool! Didn’t know that. Thanks for the quick reply!

How can I solve this error? I get the same RuntimeError when resuming training with PyTorch Lightning:
Restoring states from the checkpoint path at /home/ww/Coding/AeDet/data/nuscenes/nuscenes_12hz_infos_train.pkl
Traceback (most recent call last):
  File "/home/ww/Coding/AeDet/exps/aedet/aedet_lss_r50_256x704_128x128_24e_2key.py", line 109, in
    run_cli()
  File "/home/ww/Coding/AeDet/exps/aedet/aedet_lss_r50_256x704_128x128_24e_2key.py", line 105, in run_cli
    main(args)
  File "/home/ww/Coding/AeDet/exps/aedet/aedet_lss_r50_256x704_128x128_24e_2key.py", line 75, in main
    trainer.fit(model, ckpt_path=args.ckpt_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 771, in fit
    self._call_and_handle_interrupt(
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1180, in _run
    self._restore_modules_and_callbacks(ckpt_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1140, in _restore_modules_and_callbacks
    self._checkpoint_connector.resume_start(checkpoint_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 84, in resume_start
    self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 88, in _load_and_validate_checkpoint
    loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 316, in load_checkpoint
    return self.checkpoint_io.load_checkpoint(checkpoint_path)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/plugins/io/torch_plugin.py", line 85, in load_checkpoint
    return pl_load(path, map_location=map_location)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py", line 47, in load
    return torch.load(f, map_location=map_location)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/torch/serialization.py", line 608, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/ww/.conda/envs/aedet2/lib/python3.8/site-packages/torch/serialization.py", line 780, in _legacy_load
    raise RuntimeError("Invalid magic number; corrupt file?")
RuntimeError: Invalid magic number; corrupt file?
Any help appreciated!
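The log shows the checkpoint path pointing at nuscenes_12hz_infos_train.pkl, which by its name looks like a dataset info file rather than a checkpoint written by torch.save or Lightning; that is an assumption, but it would trigger the same magic-number check described earlier in the thread. A hypothetical standard-library helper to guess which writer produced a file before passing it as ckpt_path (sniff_checkpoint and its return labels are illustrative names only):

```python
import pickle
import zipfile

# Same constant as torch.serialization.MAGIC_NUMBER (legacy, non-zip torch.save format)
TORCH_LEGACY_MAGIC = 0x1950A86A20F9469CFC6C


def sniff_checkpoint(path):
    """Heuristically guess how `path` was written (hypothetical helper).

    Warning: pickle.load can execute arbitrary code; only use on trusted files.
    """
    if zipfile.is_zipfile(path):
        # torch.save writes a zip archive by default since PyTorch 1.6
        return "torch-zip"
    with open(path, "rb") as f:
        try:
            first = pickle.load(f)
        except Exception:
            # Empty, truncated, or not a pickle stream at all
            return "unknown"
    # Legacy torch.save files start with the magic number as their first pickle
    if isinstance(first, int) and first == TORCH_LEGACY_MAGIC:
        return "torch-legacy"
    # Anything else was most likely written directly with pickle.dump
    return "plain-pickle"
```

If this reports plain-pickle for the file given as ckpt_path, torch.load will reject it, and ckpt_path would need to point at a checkpoint saved during training instead.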