torch.save/torch.load does not support partial reads

Greetings comrades,

I would expect the following to run smoothly, since it’s a valid pattern with raw pickle:

In[203]: with open('file.pt', 'wb') as f:
    ...:     torch.save(torch.rand(4, 4), f)
    ...:     torch.save(torch.rand(4, 4), f)
    ...:     
In[204]: with open('file.pt', 'rb') as f:
    ...:     print(torch.load(f))
    ...:     print(torch.load(f))
    ...:     
tensor([[ 0.0685,  0.9256,  0.7928,  0.5120],
        [ 0.8782,  0.6241,  0.7041,  0.0264],
        [ 0.3485,  0.4451,  0.7395,  0.4364],
        [ 0.6275,  0.4573,  0.2679,  0.7629]])
Traceback (most recent call last):
  File "/home/clemaire/miniconda-mio/envs/venv/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-204-07261aa68cb4>", line 3, in <module>
    print(torch.load(f))
  File "/home/clemaire/miniconda-mio/envs/venv/lib/python3.6/site-packages/torch/serialization.py", line 303, in load
    return _load(f, map_location, pickle_module)
  File "/home/clemaire/miniconda-mio/envs/venv/lib/python3.6/site-packages/torch/serialization.py", line 459, in _load
    magic_number = pickle_module.load(f)
_pickle.UnpicklingError: invalid load key, '\x10'.
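
For comparison, here is the raw pickle version of the same pattern, which reads both objects back without complaint. This is just a minimal sketch; the temporary path and the sample objects are placeholders I made up for illustration:

```python
import os
import pickle
import tempfile

# Placeholder path inside a throwaway directory (illustrative only).
path = os.path.join(tempfile.mkdtemp(), "file.pkl")

# Write two independent pickles back to back into the same file.
with open(path, "wb") as f:
    pickle.dump({"epoch": 123}, f)
    pickle.dump([1.0, 2.0, 3.0], f)

# Each pickle.load consumes exactly one pickled object and leaves the
# file position at the start of the next one.
with open(path, "rb") as f:
    first = pickle.load(f)
    second = pickle.load(f)

print(first)
print(second)
```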

My use case is reading only the metadata associated with a training checkpoint, without having to load all the network weights into memory. I would save the file like this:

with open('checkpoint.pt', 'wb') as f:
    torch.save({'epoch': 123, ...}, f)
    torch.save(net.state_dict(), f)

Why doesn’t this work? Do you think it should be supported?
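
In the meantime, one workaround I can think of is to frame each object myself: serialize it into an in-memory buffer with torch.save, then write it to the file behind a length prefix so a reader can load or skip it. This is only a sketch under my own assumptions. The helpers `write_record`, `read_record` and `skip_record` are hypothetical names, not part of torch, and the `nn.Linear` stands in for a real network:

```python
import io
import os
import struct
import tempfile

import torch

def write_record(f, obj):
    """Serialize obj with torch.save and write it behind an 8-byte length prefix."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    payload = buf.getvalue()
    f.write(struct.pack("<Q", len(payload)))
    f.write(payload)

def read_record(f):
    """Read one length-prefixed record and deserialize it with torch.load."""
    (size,) = struct.unpack("<Q", f.read(8))
    return torch.load(io.BytesIO(f.read(size)))

def skip_record(f):
    """Jump over one record without reading its payload."""
    (size,) = struct.unpack("<Q", f.read(8))
    f.seek(size, os.SEEK_CUR)

# Placeholder network and path, for illustration only.
net = torch.nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

with open(path, "wb") as f:
    write_record(f, {"epoch": 123})
    write_record(f, net.state_dict())

# Read only the metadata; the weights record is never deserialized.
with open(path, "rb") as f:
    meta = read_record(f)

# Or skip the metadata and load only the weights.
with open(path, "rb") as f:
    skip_record(f)
    state = read_record(f)
```

The obvious downside is that the result is no longer a plain torch.save file, so anything that calls torch.load on it directly will not understand it. Keeping the metadata in a separate small file would be the simpler alternative.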

Thanks!

My question is similar, so I am posting it here. I want to save a partial state dict together with the architecture and be able to load it back, for example saving only the encoder's weights and reusing them as part of another model.
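
To make the weights part concrete, here is a minimal sketch of what I have in mind. The `Autoencoder` and `Classifier` classes are hypothetical toy modules that happen to share an identically shaped `encoder` submodule, and the in-memory buffer stands in for a file on disk:

```python
import io

import torch
from torch import nn

class Autoencoder(nn.Module):
    """Toy source model (hypothetical): encoder plus decoder."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 3)
        self.decoder = nn.Linear(3, 8)

class Classifier(nn.Module):
    """Toy target model (hypothetical): same encoder shape, different head."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 3)
        self.head = nn.Linear(3, 2)

source = Autoencoder()

# Save only the encoder's parameters, not the whole model.
buf = io.BytesIO()  # stands in for open('encoder.pt', 'wb')
torch.save(source.encoder.state_dict(), buf)

# Load them into the encoder submodule of a different model.
buf.seek(0)
target = Classifier()
target.encoder.load_state_dict(torch.load(buf))
```

This only covers the weights; the architecture still has to be rebuilt from code before the state dict can be loaded into it.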

Thanks!