Loading a specific layer from checkpoint

Is it possible to load one specific layer of a model from a pretrained checkpoint?


If you just want to load a single layer, this code should work fine:

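import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 20)  # in_features must match fc1's out_features
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x

# Save a checkpoint of the full model
model = MyModel()
torch.save(model.state_dict(), 'tmp.pth')

# Load it and copy only the parameters of fc1 into a fresh model
state_dict = torch.load('tmp.pth')
model = MyModel()
with torch.no_grad():  # in-place copies into leaf parameters need no_grad
    model.fc1.weight.copy_(state_dict['fc1.weight'])
    model.fc1.bias.copy_(state_dict['fc1.bias'])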
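If a layer has many parameters and buffers, copying them one by one gets tedious. A minimal sketch of an alternative, assuming the checkpoint keys follow the usual `<submodule>.<param>` naming: keep only the keys that start with the submodule's prefix, strip that prefix, and hand the result to the submodule's own `load_state_dict`. The helper name `load_submodule` is hypothetical, and the checkpoint is saved to an in-memory buffer here only to keep the example self-contained.

```python
import io

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 20)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


def load_submodule(module, state_dict, prefix):
    """Hypothetical helper: load entries of `state_dict` starting with `prefix` into `module`.

    `prefix` is the submodule's attribute name followed by a dot, e.g. "fc1.".
    """
    sub_state = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    # strict=True (the default) raises if the submodule expects a key that is missing
    module.load_state_dict(sub_state)
    return module


# Save a "pretrained" checkpoint (in memory instead of a file)
pretrained = MyModel()
buffer = io.BytesIO()
torch.save(pretrained.state_dict(), buffer)
buffer.seek(0)

# Restore only fc1 into a freshly initialized model; fc2 keeps its random init
state_dict = torch.load(buffer, map_location="cpu")
model = MyModel()
load_submodule(model.fc1, state_dict, "fc1.")
```

Because the submodule's `load_state_dict` checks key names and shapes, a typo in the prefix or a mismatched layer raises an error instead of silently doing nothing.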

Thanks a lot! That’s exactly what I was looking for.


I have one other question. I am using this method to load the layers, and I see two weird things.
First, loading one weight vs. loading eight weights makes hardly any difference in processing time. Second, when I change the batch size, the time spent loading the parameters changes as well, even though the two should be unrelated. Do you have any idea what could cause these issues?

Could you post some (pseudo)code so we can see how you've measured the timings?

Here is the pseudocode of what I do.
Let's say I am loading only three layers of the network, and each has two sublayers called l1 and l2:
with torch.no_grad():
    for i in range(3):
        self.layer[i].l1.weight.copy_(state_dict["fc_{}.l1.weight".format(i)], non_blocking=True)
        self.layer[i].l2.weight.copy_(state_dict["fc_{}.l2.weight".format(i)], non_blocking=True)

This should not be related to the batch size, but with different batch sizes I get different timings. I use time.time() from Python's time module to measure the timing. Also, I am copying from CPU to GPU (my state_dict is on the CPU and my model is on the GPU).

Where is the batch size used in your code?
As you said, the batch size should be unrelated to the copying of parameters.
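Since you are using non_blocking=True, make sure to call torch.cuda.synchronize() before starting and stopping the timer to get valid results. CUDA operations are executed asynchronously, so without synchronization the host timer may only capture the launch overhead of the copies, or include unrelated work that is still queued on the GPU (such as the computation for the previous batch), which would explain why the measured time depends on the batch size.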
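A minimal sketch of synchronized timing for the layer-wise copy above. The `Block`/`Net` modules, the 64-unit layer size, and the CPU checkpoint with `fc_{i}.l1.weight` keys are hypothetical stand-ins for the setup in the question; `time.perf_counter()` is swapped in for `time.time()` because it is the higher-resolution clock intended for measuring intervals. Synchronization is skipped when no GPU is available so the sketch also runs on the CPU.

```python
import time

import torch
import torch.nn as nn


class Block(nn.Module):
    # Hypothetical layer with two sublayers, mirroring l1/l2 in the question
    def __init__(self, dim=64):
        super().__init__()
        self.l1 = nn.Linear(dim, dim)
        self.l2 = nn.Linear(dim, dim)


class Net(nn.Module):
    def __init__(self, num_layers=8, dim=64):
        super().__init__()
        self.layer = nn.ModuleList(Block(dim) for _ in range(num_layers))


def sync(device):
    # Wait for all queued CUDA work on `device`; nothing to wait for on the CPU
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def timed_copy(model, state_dict, num_layers, device):
    """Copy the l1/l2 weights of the first `num_layers` layers and return elapsed seconds."""
    sync(device)  # don't count work queued before the copy (e.g. the previous batch)
    start = time.perf_counter()
    with torch.no_grad():
        for i in range(num_layers):
            model.layer[i].l1.weight.copy_(state_dict["fc_{}.l1.weight".format(i)], non_blocking=True)
            model.layer[i].l2.weight.copy_(state_dict["fc_{}.l2.weight".format(i)], non_blocking=True)
    sync(device)  # wait until the asynchronous copies have actually finished
    return time.perf_counter() - start


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)

# Hypothetical checkpoint kept on the CPU, using the key pattern from the question
state_dict = {}
for i in range(8):
    state_dict["fc_{}.l1.weight".format(i)] = torch.randn(64, 64)
    state_dict["fc_{}.l2.weight".format(i)] = torch.randn(64, 64)

for n in (1, 8):
    print(n, "layer(s):", timed_copy(model, state_dict, n, device), "s")
```

With both synchronization points in place, the measured time should reflect only the copies themselves, so it is expected to scale with the number of layers copied rather than with the batch size used elsewhere in the training loop.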

Sorry for the late response. The problem was the missing CUDA synchronization, which produced misleading timings. It was solved by adding torch.cuda.synchronize().