Hello everyone.
Issue
I’ve been wondering whether it’s possible to pickle/unpickle a ScriptModule instead of using torch.jit.save/torch.jit.load?
import io
import torch
import torch.nn as nn

m = nn.Linear(1, 10).eval()
torch.save(m, io.BytesIO())  # OK: a regular nn.Module pickles fine
m_ = torch.jit.script(m)
torch.save(m_, io.BytesIO())  # Not OK: compiled module does not have a __getstate__ method defined!
Is there any workaround (for example, a monkey-patch) to make a scripted module picklable?
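A minimal sketch of one possible workaround, assuming it is acceptable to pickle a wrapper object instead of the ScriptModule itself. The class PicklableScriptModule is hypothetical (not a PyTorch API): its __getstate__/__setstate__ round-trip the module through torch.jit.save/torch.jit.load on an in-memory io.BytesIO buffer, so both pickle and torch.save can handle it.

```python
import io
import pickle

import torch
import torch.nn as nn


class PicklableScriptModule:
    """Hypothetical wrapper that pickles a ScriptModule as TorchScript archive bytes."""

    def __init__(self, module: torch.jit.ScriptModule):
        self.module = module

    def __getstate__(self):
        # Serialize the scripted module into an in-memory TorchScript archive.
        buffer = io.BytesIO()
        torch.jit.save(self.module, buffer)
        return {"archive": buffer.getvalue()}

    def __setstate__(self, state):
        # Rebuild the ScriptModule from the archive bytes (this still runs torch.jit.load).
        self.module = torch.jit.load(io.BytesIO(state["archive"]))

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)


scripted = torch.jit.script(nn.Linear(1, 10).eval())
wrapped = PicklableScriptModule(scripted)

torch.save(wrapped, io.BytesIO())  # the wrapper is picklable, unlike the raw ScriptModule
restored = pickle.loads(pickle.dumps(wrapped))

x = torch.randn(4, 1)
print(torch.allclose(restored(x), scripted(x)))
```

Because unpickling still goes through torch.jit.load, this makes the module picklable but does not by itself remove the load cost described in the Pre-history.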
Pre-history
The original question arose from the following situation.
I have a scripted 3D classification network that weighs 500 MB, and torch.jit.load takes 0.8 seconds to load it.
If I remove all parameters and buffers from the scripted module (replacing them with an empty torch.Tensor()) and save it again, the new file weighs only 0.2 MB, yet torch.jit.load still takes 0.45 seconds to complete.
I’m not familiar with how torch.jit.load works internally, but 0.45 seconds seems rather slow for a 0.2 MB file.
So I thought that if I could make the module picklable, I could store it in shared memory and load it faster.
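One way to check whether shared memory would help is to time torch.jit.load on an archive that is already in memory (io.BytesIO), which takes disk I/O out of the measurement. If the in-memory time stays close to the on-disk time, the cost lies in deserializing and reconstructing the TorchScript module rather than in reading the file. A sketch, assuming a small stand-in model (not the 3D network from the post) and a hypothetical helper time_jit_load:

```python
import io
import time

import torch
import torch.nn as nn


def time_jit_load(archive: bytes, repeats: int = 5) -> float:
    """Hypothetical helper: best wall-clock time (seconds) of torch.jit.load on in-memory bytes."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        torch.jit.load(io.BytesIO(archive))  # no disk access: the bytes are already in RAM
        best = min(best, time.perf_counter() - start)
    return best


# Stand-in model; swap in the real scripted network to compare against the timings above.
model = torch.jit.script(
    nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2)).eval()
)
buffer = io.BytesIO()
torch.jit.save(model, buffer)
archive = buffer.getvalue()

print(f"archive size: {len(archive) / 1e6:.3f} MB")
print(f"torch.jit.load from memory: {time_jit_load(archive):.4f} s")
```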
System version
Here are the Python and PyTorch versions:
Python: sys.version_info(major=3, minor=8, micro=13, releaselevel='final', serial=0)
PyTorch: 1.10.1