How can I load/read only part of a model checkpoint?

Hello, I am trying to do inference with a large model that cannot fit into my CPU RAM.
Is there any way to load only part of the model checkpoint? Is it possible to load only the layer names from a model, and later the weights of specific layers?
For example, I am trying to do something like this:

import torch 

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.act = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.act2 = torch.nn.Sigmoid()
    
    def forward(self, x):
        x = self.linear(x)
        x = self.act(x)
        x = self.linear2(x)
        x = self.act2(x)
        return x

class MyLoader():
    """
    This class is a custom loader for PyTorch models
    that can be used to load only specific layers. 
    """
    def __init__(self, pth_file):
        self.pth_file = pth_file

    def get_layer_names(self):
        ...

    def load_layer_by_name(self, names):
        out = {}
        for name in names:
            out[name] = self.load_layer(name)
        return out

    def load_layer(self, name):
        ...


if __name__ == "__main__":
    m = MyModel()
    torch.save(m.state_dict(), 'example.pth')

    # This will fail with OOM error
    #model_dict = torch.load( 'example.pth', torch.device('cpu') )
    #print(model_dict)

    # Custom loader call
    loader = MyLoader('example.pth')
    n = loader.get_layer_names()
    print(n)
    loader.load_layer_by_name(['linear.weight', 'linear.bias'])

So far I have tried to understand how models are loaded by debugging the torch.load function, but many functions from the pickle module are not exposed.

Thank you for your help! :slight_smile:

I’m not sure I completely understand your use case, but it sounds like you would like to load only a subset of the model at a time and delete the rest, since your host RAM is limited?
You could technically load only a specific layer from the state_dict by accessing its keys, but you would then need to delete the rest of the model. That could generally cause a large bottleneck, since you would be loading from your drive in every iteration.
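To make that concrete, here is a minimal sketch of what I mean (using a toy checkpoint with the same key names as your example; note that the full dict is still materialized in RAM momentarily, so this only helps once the full dict is deleted):

```python
import torch

# Create a toy checkpoint matching the model in the question,
# so the snippet is self-contained ('example.pth' is illustrative).
torch.save({'linear.weight': torch.randn(10, 10),
            'linear.bias': torch.randn(10),
            'linear2.weight': torch.randn(10, 10),
            'linear2.bias': torch.randn(10)}, 'example.pth')

# Load the full state_dict, keep only the keys you need,
# then delete the rest to free host RAM.
full_state = torch.load('example.pth', map_location='cpu')
wanted = ['linear.weight', 'linear.bias']
subset = {k: full_state[k] for k in wanted}
del full_state  # free everything except the selected layers
```

You could then call `model.load_state_dict(subset, strict=False)` to load just those layers into the module.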

Yes, I want to load only a subset of the model, but torch.load raises an OOM error for very large models. Another way of looking at it is that I want to load just one shard from the HDD into RAM and run that.
I understand that there could be a bottleneck, but that could be alleviated if the subsets are executed across a network of machines, each with less RAM than the model size.