Load part of big model

Hi,

I have a model saved with torch.save to a file. The file is quite big (say, 100 GB), and torch.load crashes on nodes with moderate CPU RAM. The good part is that I don't need all the tensors in CPU memory at once.

Is there a way to customize torch.load so that it doesn't produce a model with all tensors deserialized? Ideally, most tensors would be dropped right after being deserialized, or never deserialized at all.

The docs suggest that passing a callable as the map_location argument could possibly do the trick, but I cannot find any examples of that.
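
For reference, here is a minimal sketch of the callable form of map_location. Per the torch.load docs, the callable is invoked once per serialized storage with the storage (already deserialized on the CPU) and its location tag (e.g. 'cpu' or 'cuda:1'), and must return either a storage or None (None falls back to the default behavior). Because each storage is already in CPU memory when the callable runs, it controls where tensors end up rather than whether they are read. The helper name record_and_keep and the small stand-in checkpoint are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

seen_locations = []

def record_and_keep(storage, location):
    # Hypothetical helper: called once per serialized storage.
    # `storage` has already been read into CPU memory at this point.
    seen_locations.append(location)
    return storage  # returning None would fall back to the default behavior

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tmp.pt")
    # Small stand-in for the real (large) checkpoint
    torch.save(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).state_dict(), path)
    sd = torch.load(path, map_location=record_and_keep)

print(seen_locations)  # one 'cpu' tag per tensor storage in the checkpoint
print(sorted(sd))
```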

Any help is appreciated!

Thanks,
Max.

I assume you are using the proper approach of storing the state_dict instead of the complete model.
If that’s the case, you could remove the keys you don’t want to store to disk.
After loading this smaller state_dict from disk again, load it to your model with strict=False:

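import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 1),
    nn.Linear(1, 2))

sd = model.state_dict()

# Delete unnecessary keys
del sd['0.weight']
del sd['0.bias']

# Store
torch.save(sd, 'tmp.pt')

# Load
model = nn.Sequential(
    nn.Linear(1, 1),
    nn.Linear(1, 2))
sd = torch.load('tmp.pt')
# strict=False tolerates the deleted keys; the returned tuple lists them
result = model.load_state_dict(sd, strict=False)
print(result.missing_keys)  # ['0.weight', '0.bias']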

Thank you for your answer! I should add that I don't have control over how the model is saved. My input is a file with the saved model (you are right, it is a state_dict).

Oh, that’s trickier.
If I’m not mistaken, PyTorch uses pickle to store and load the data.
It might be possible to load only part of the file via pickle, e.g. by using an io.BytesIO object and seeking manually.
However, this sounds like a cumbersome approach and you would also need to dive into the serialization code.

Could you load the state_dict once on a machine with enough memory and then store parts of it to use the first approach?
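
A minimal sketch of that workaround: load the full state_dict once on the large-memory machine, split it by key prefix, and save each group to its own file. Each small file can then be loaded on the smaller nodes and passed to load_state_dict(..., strict=False) as in the first snippet. The helper name shard_state_dict, the prefix-based grouping, and the output file names are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

def shard_state_dict(src_path, out_dir, prefixes):
    """Hypothetical helper: split one state_dict file into one file per key prefix.

    Must run on a machine with enough RAM to hold the full state_dict.
    Keys matching none of the prefixes are not written to any shard.
    """
    sd = torch.load(src_path, map_location="cpu", weights_only=True)
    os.makedirs(out_dir, exist_ok=True)
    paths = {}
    for prefix in prefixes:
        # clone() so each shard stores only its own data, even if a tensor
        # is a view into a larger storage shared with other keys
        part = {k: v.clone() for k, v in sd.items() if k.startswith(prefix)}
        path = os.path.join(out_dir, prefix.rstrip(".") + ".pt")
        torch.save(part, path)
        paths[prefix] = path
    return paths

# Usage: shard a small stand-in checkpoint, then load just one shard.
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "full.pt")
    torch.save(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2)).state_dict(), src)
    paths = shard_state_dict(src, os.path.join(tmp, "shards"), ["0.", "1."])

    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 2))
    result = model.load_state_dict(torch.load(paths["1."]), strict=False)
    print(result.missing_keys)  # ['0.weight', '0.bias']
```

Including the trailing dot in each prefix keeps "1." from also matching keys such as "10.weight".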

This is the workaround I am currently implementing 🙂
Thanks for your help! I will probably use this workaround for a while; things might change (for example, I might get some control over the model-saving code).