Greetings comrades,
I would expect the following to run smoothly, since it’s a valid pattern with raw pickle:
In [203]: with open('file.pt', 'wb') as f:
     ...:     torch.save(torch.rand(4, 4), f)
     ...:     torch.save(torch.rand(4, 4), f)
     ...:

In [204]: with open('file.pt', 'rb') as f:
     ...:     print(torch.load(f))
     ...:     print(torch.load(f))
     ...:
tensor([[ 0.0685, 0.9256, 0.7928, 0.5120],
[ 0.8782, 0.6241, 0.7041, 0.0264],
[ 0.3485, 0.4451, 0.7395, 0.4364],
[ 0.6275, 0.4573, 0.2679, 0.7629]])
Traceback (most recent call last):
File "/home/clemaire/miniconda-mio/envs/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-204-07261aa68cb4>", line 3, in <module>
print(torch.load(f))
File "/home/clemaire/miniconda-mio/envs/venv/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
return _load(f, map_location, pickle_module)
File "/home/clemaire/miniconda-mio/envs/venv/lib/python3.6/site-packages/torch/serialization.py", line 459, in _load
magic_number = pickle_module.load(f)
_pickle.UnpicklingError: invalid load key, '\x10'.
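For comparison, here is the multi-record pattern with the stdlib pickle module alone, which does work and is what made me expect `torch.save` to behave the same way (using an in-memory buffer just to keep the snippet self-contained):

```python
import io
import pickle

# Each pickle.dump appends an independent, self-delimiting record
# to the stream, so sequential loads pick them up one by one.
buf = io.BytesIO()
pickle.dump({'a': 1}, buf)
pickle.dump([1, 2, 3], buf)

buf.seek(0)
print(pickle.load(buf))  # {'a': 1}
print(pickle.load(buf))  # [1, 2, 3]
```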
My use case is being able to read only the metadata associated with a training checkpoint without having to load all the network weights in memory. I would save the file like this:
with open('checkpoint.pt', 'wb') as f:
    torch.save({'epoch': 123, ...}, f)
    torch.save(net.state_dict(), f)
Why doesn’t this work? Do you think it should be supported?
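In the meantime, one workaround I could imagine (file names here are just placeholders) is writing the metadata to a small sidecar file next to the checkpoint, so inspecting it never has to deserialize the weights at all:

```python
import os
import pickle
import tempfile

# Hypothetical sidecar scheme: metadata lives in its own small file,
# separate from the (large) weights checkpoint.
ckpt_dir = tempfile.mkdtemp()
meta_path = os.path.join(ckpt_dir, 'checkpoint.meta.pkl')

meta = {'epoch': 123}  # placeholder metadata
with open(meta_path, 'wb') as f:
    pickle.dump(meta, f)

# Later: read the metadata only; the weights file is never opened.
with open(meta_path, 'rb') as f:
    print(pickle.load(f))  # {'epoch': 123}
```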
Thanks!