Pickling ScriptModule

Hello everyone.

Issue

I’ve been wondering whether it’s possible to somehow pickle/unpickle a ScriptModule rather than using torch.jit.save/torch.jit.load.

import io

import torch
import torch.jit
import torch.nn as nn

m = nn.Linear(1, 10).eval()
torch.save(m, io.BytesIO())  # OK

m_ = torch.jit.script(m)
torch.save(m_, io.BytesIO())  # Not OK: compiled module does not have a __getstate__ method defined!

Is there any workaround or monkey-patch to make a scripted module picklable?
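For illustration, the most obvious candidate I can think of is a wrapper that routes pickling through torch.jit.save/torch.jit.load. A minimal sketch (the class name is my own invention):

import io
import pickle

import torch
import torch.jit
import torch.nn as nn

class PickleableScriptModule:
    """Wraps a ScriptModule so pickling round-trips through
    torch.jit.save / torch.jit.load."""

    def __init__(self, script_module):
        self.module = script_module

    def __getstate__(self):
        buf = io.BytesIO()
        torch.jit.save(self.module, buf)  # serialize the module to bytes
        return buf.getvalue()

    def __setstate__(self, state):
        self.module = torch.jit.load(io.BytesIO(state))  # rebuild on unpickle

m = torch.jit.script(nn.Linear(1, 10).eval())
data = pickle.dumps(PickleableScriptModule(m))  # works now
restored = pickle.loads(data).module            # a usable ScriptModule again

But this only hides torch.jit.save/torch.jit.load behind the pickle protocol; what I am really hoping for is a way to avoid their cost.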

Pre-history

The original question arose from the following situation.

I have a scripted model of a 3D classification network weighing 500 MB, and torch.jit.load takes 0.8 seconds to complete.

If I remove all parameters and buffers of the scripted module (substituting them with empty torch.Tensor()) and save it again, the new file weighs only 0.2 MB, but torch.jit.load still takes 0.45 seconds to complete.
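Roughly what I did, as a sketch (model.pt is a placeholder path; I am assuming the scripted module exposes named_parameters()/named_buffers() like a regular nn.Module):

import torch
import torch.jit

scripted = torch.jit.load("model.pt")  # placeholder path

# Replace every parameter and buffer with an empty tensor, so only the
# code/graph structure remains in the saved archive.
for _, p in scripted.named_parameters():
    p.data = torch.Tensor()
for _, b in scripted.named_buffers():
    b.data = torch.Tensor()

torch.jit.save(scripted, "model_empty.pt")  # ~0.2 MB instead of 500 MB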

I’m not deeply familiar with how torch.jit.load works internally, but 0.45 seconds seems rather slow.

And I thought that if I could make it picklable, I could store it in shared memory and load it faster.
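What I had in mind, roughly, as a single-script sketch (the segment name is made up; multiprocessing.shared_memory requires Python 3.8+):

import io
from multiprocessing import shared_memory

import torch
import torch.jit

# Producer: serialize the ScriptModule once and place the bytes in shared memory.
buf = io.BytesIO()
torch.jit.save(scripted, buf)  # 'scripted' is an existing ScriptModule
raw = buf.getvalue()
shm = shared_memory.SharedMemory(create=True, size=len(raw), name="jit_model")
shm.buf[: len(raw)] = raw

# Consumer (in another process): read the bytes back without touching the disk.
# The byte count has to be passed along separately, since the segment may be
# rounded up to a whole number of pages.
shm2 = shared_memory.SharedMemory(name="jit_model")
m = torch.jit.load(io.BytesIO(bytes(shm2.buf[: len(raw)])))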

System version

Here are the Python and PyTorch versions.

sys.version_info(major=3, minor=8, micro=13, releaselevel='final', serial=0)
1.10.1

I don’t think you can fully avoid torch.jit.load - or at least the things it uses - if you want to get a ScriptModule from it.

Best regards

Thomas

That’s a pity. Do you know of any hacks to somehow pickle a ScriptModule once it has been obtained from torch.jit.load?

So my reading of this is:

  • torch.jit.load needs to do a lot of work to create the ScriptModule, and quite likely much of that work goes into getting the module to a usable state (I did not benchmark it, so there could be inefficiencies). Imagine it has two parts, “read from disk” and “create objects and reach state”; I would venture that the second part is where most of the time is spent.
  • pickle is a protocol. It relies on object creation and the like to build the objects it is deserializing.
  • So even if there were a way to pickle/unpickle, it would still need the second part of torch.jit.load.

This is roughly why I don’t expect there to be a solution faster than torch.jit.load.
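If you want to check where the time actually goes, a quick experiment (my sketch, with a placeholder file name) would be to feed torch.jit.load the bytes from memory, which removes the “read from disk” part entirely:

import io
import time

import torch
import torch.jit

with open("model.pt", "rb") as f:  # placeholder path
    raw = f.read()

start = time.perf_counter()
m = torch.jit.load(io.BytesIO(raw))  # no disk I/O in this call
print(f"create objects and reach state: {time.perf_counter() - start:.3f}s")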

Best regards

Thomas

Yes, this seems very logical. Thanks!