Hello, I am trying to run inference with a large model that cannot fit into my CPU RAM.
Is there any way I can load only a part of the model checkpoint? Is it possible to first load only the layer names from a checkpoint and later load the weights of specified layers?
For example, I am trying to do something like this:
import torch
class MyModel(torch.nn.Module):
    """Small example network: Linear(10, 10) -> ReLU -> Linear(10, 10) -> Sigmoid.

    The submodule attribute names (``linear``, ``act``, ``linear2``, ``act2``)
    determine the keys of ``state_dict()``, e.g. ``linear.weight`` and
    ``linear.bias``.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.act = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(10, 10)
        self.act2 = torch.nn.Sigmoid()

    def forward(self, x):
        """Apply the four layers in order and return the result.

        ``x`` must have a last dimension of size 10, as required by the
        first ``Linear(10, 10)`` layer.
        """
        x = self.linear(x)
        x = self.act(x)
        x = self.linear2(x)
        x = self.act2(x)
        return x
class MyLoader():
"""
This class is a custom loader for PyTorch models
that can be used to load only specific layers.
"""
def __init__(self, pth_file):
self.pth_file = pth_file
def get_layer_names(self):
...
def load_layer_by_name(self, names):
out = {}
for name in names:
out[name] = self.load_layer(name)
return out
def load_layer(self, name):
...
if __name__ == "__main__":
    # Write a small example checkpoint to disk.
    m = MyModel()
    torch.save(m.state_dict(), 'example.pth')

    # With the real (large) model, loading the whole checkpoint at once
    # fails with an OOM error:
    #model_dict = torch.load( 'example.pth', torch.device('cpu') )
    #print(model_dict)

    # Custom loader call: list the entry names first, then load only the
    # entries that are needed.
    loader = MyLoader('example.pth')
    n = loader.get_layer_names()
    print(n)
    loader.load_layer_by_name(['linear.weight', 'linear.bias'])
So far, I have tried to understand how models are loaded by debugging the torch.load function, but many functions from the pickle module are not exposed.
Thank you for your help!