RuntimeError when loading a PyTorch model from .pkl

Hi all,

I am using a cloud service that requires my model to be serialized with pickle. I’ve trained a U-Net model and saved the full model both as a .pth file and as a .pkl file, but when I try to load the model from the .pkl file I get the following RuntimeError:

RuntimeError                              Traceback (most recent call last)
<ipython-input-81-14dbf4cf4d26> in <module>()
----> 1 unet = torch.load('unet_bus.pkl')
      2 unet.eval()

~/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/ in load(f, map_location, pickle_module)
    356         f = open(f, 'rb')
    357     try:
--> 358         return _load(f, map_location, pickle_module)
    359     finally:
    360         if new_fd:

~/anaconda3/envs/pytorch/lib/python3.5/site-packages/torch/ in _load(f, map_location, pickle_module)
    532     magic_number = pickle_module.load(f)
    533     if magic_number != MAGIC_NUMBER:
--> 534         raise RuntimeError("Invalid magic number; corrupt file?")
    535     protocol_version = pickle_module.load(f)
    536     if protocol_version != PROTOCOL_VERSION:

RuntimeError: Invalid magic number; corrupt file?

This is how I’m saving and loading the pickle model:

import pickle
import torch

pkl_filename = "unet_bus.pkl"
with open(pkl_filename, 'wb') as file:
    pickle.dump(model, file)  # plain pickle: no PyTorch file header is written

unet = torch.load(pkl_filename)  # raises the RuntimeError shown above

Any help appreciated!

Could you try to use instead of pickle.dump?
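A minimal sketch of that suggestion, assuming the function meant is `torch.save`, PyTorch's own serializer, which writes the object in PyTorch's pickle-based file format that `torch.load` expects. The small `nn.Sequential` below is a hypothetical stand-in for the trained U-Net. The `weights_only` handling is an added assumption for newer PyTorch releases, where `torch.load` defaults to `weights_only=True` and refuses to unpickle full `nn.Module` objects; older releases lack the parameter entirely.

```python
import inspect

import torch
import torch.nn as nn

# Hypothetical stand-in for the trained U-Net.
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU())

# torch.save writes PyTorch's own format, so torch.load can read it back.
torch.save(model, "unet_bus.pkl")

# Newer PyTorch defaults to weights_only=True, which rejects full nn.Module
# objects; only pass weights_only=False for files you trust.
load_kwargs = {}
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False

unet = torch.load("unet_bus.pkl", **load_kwargs)
unet.eval()
```

If the model is a custom class, its definition has to be importable when the file is loaded, just as with plain pickle. Whether the cloud service accepts a `torch.save` file as "serialized with pickle" is worth confirming with the service.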

Nice, thanks @ptrblck, that worked! Do you know why saving it the former way is an issue?

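PyTorch expects its serialized files to start with a magic number that identifies the file format and protocol version (see the `magic_number != MAGIC_NUMBER` check in the traceback above). A file written directly with another library, such as `pickle.dump`, does not contain this header, so loading it with `torch.load` raises this error.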
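To make the mechanism concrete, here is a pure-Python sketch of a magic-number check modeled on the `_load` lines in the traceback. The header layout and the `MAGIC_NUMBER` value are hypothetical and do not reproduce PyTorch's actual format; `save_with_header` and `load_with_header` are illustrative names.

```python
import io
import pickle

# Hypothetical header values; PyTorch's real constants and layout differ.
MAGIC_NUMBER = 0x1234ABCD
PROTOCOL_VERSION = 1


def save_with_header(obj, f):
    """Write a magic number and protocol version before the payload."""
    pickle.dump(MAGIC_NUMBER, f)
    pickle.dump(PROTOCOL_VERSION, f)
    pickle.dump(obj, f)


def load_with_header(f):
    """The first pickled object must be the magic number, as in the traceback."""
    magic_number = pickle.load(f)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle.load(f)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")
    return pickle.load(f)


good = io.BytesIO()
save_with_header({"w": [1, 2]}, good)
good.seek(0)
print(load_with_header(good))  # {'w': [1, 2]}

bad = io.BytesIO()
pickle.dump({"w": [1, 2]}, bad)  # plain pickle, no header
bad.seek(0)
try:
    load_with_header(bad)
except RuntimeError as e:
    print(e)  # Invalid magic number; corrupt file?
```

With the plain-pickle file, the loader's first `pickle.load` returns the model object itself instead of the magic number, so the comparison fails, which is the same situation as `torch.load` reading the `pickle.dump` output.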

Ahh, cool! Didn’t know that. Thanks for your promptness 🙂