Loading a model when layer names match

I have a small model and a large model. The large model is built on top of the small one, so I would like to load the pre-trained small model to initialize some layers of the large model. However, when I try large_model.load_state_dict(small_model.state_dict()), I always get an error like Missing key(s) in state_dict:. Is there a way to ignore the layers that do not exist in the small model and load only the ones that do?
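Roughly what I am doing, as a stripped-down sketch (SmallModel / LargeModel are stand-ins for my real models):

import torch
from torch import nn

class SmallModel(nn.Module):           # stand-in for my real small model
    def __init__(self):
        super(SmallModel, self).__init__()
        self.linear1 = nn.Linear(3, 4, bias=False)

class LargeModel(nn.Module):           # built on top of SmallModel, shares linear1
    def __init__(self):
        super(LargeModel, self).__init__()
        self.linear1 = nn.Linear(3, 4, bias=False)
        self.linear2 = nn.Linear(4, 5, bias=False)

# strict loading fails because linear2.weight is absent from the small state dict
LargeModel().load_state_dict(SmallModel().state_dict())
# RuntimeError: Error(s) in loading state_dict for LargeModel:
#     Missing key(s) in state_dict: "linear2.weight"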

Thanks!

Use the keyword argument strict=False with big_model.load_state_dict(),
i.e. big_model.load_state_dict(torch.load('file.pth'), strict=False). Keys that are missing from the loaded state dict are simply ignored, and those parameters keep their current values. A full example:

import torch
from torch import nn

# define a small model and a big model that shares its first layer
class small_model(nn.Module):
    def __init__(self):
        super(small_model, self).__init__()
        self.linear1 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        pass  # forward pass not needed for this demo

class big_model(nn.Module):
    def __init__(self):
        super(big_model, self).__init__()
        self.linear1 = nn.Linear(3, 4, bias=False)  # same name and shape as in small_model
        self.linear2 = nn.Linear(4, 5, bias=False)  # extra layer, absent from small_model

    def forward(self, x):
        pass  # forward pass not needed for this demo

def print_params(model):
    for name, param in model.named_parameters():
        print(name, param)

# create small model
small = small_model()
print('small model params')
print_params(small)

# save the small model
torch.save(small.state_dict(), 'small.pth')

# create big model
big = big_model()

print('big model params before copying')
print_params(big)
# strict=False: linear1.weight is copied, the missing linear2.weight is ignored
big.load_state_dict(torch.load('small.pth'), strict=False)

assert torch.equal(big.linear1.weight, small.linear1.weight), 'params do not match after copying'
print('big model params after copying')
print_params(big)
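Two related notes, not shown in the example above. On reasonably recent PyTorch versions (1.1+, I believe), load_state_dict returns a named tuple listing what was skipped, so you can double-check that only the expected keys were missing:

result = big.load_state_dict(torch.load('small.pth'), strict=False)
print(result.missing_keys)     # ['linear2.weight'] here
print(result.unexpected_keys)  # [] here

And if a shared layer name ever differs in shape between the two models (strict=False still raises an error on shape mismatches), you can filter the checkpoint by hand before loading. A sketch, building on the models above:

pretrained = torch.load('small.pth')
model_dict = big.state_dict()
# keep only entries whose name and shape both match the big model
filtered = {k: v for k, v in pretrained.items()
            if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(filtered)
big.load_state_dict(model_dict)  # strict load now succeeds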