I have a small model and a large model. The large model is built on top of the small one, so I would like to load the pre-trained small model to initialize some layers of the large model. However, when I try large_model.load_state_dict(small_model.state_dict()), I always get an error like "Missing key(s) in state_dict:". Is there a way to ignore the layers that do not exist in the small model and load only those that do?
Pass strict=False to load_state_dict(),
i.e. big_model.load_state_dict(torch.load('file.pth'), strict=False). With strict=False, keys that are missing from the checkpoint (or unexpected in it) are skipped instead of raising an error, and they are reported in the named tuple that load_state_dict returns.
import torch
from torch import nn
from torch.autograd import Variable, grad
#define network weights and input
class small_model(nn.Module):
    """Toy model with a single bias-free linear layer (3 -> 4).

    The layer is named ``linear1`` on purpose: it must match the name of
    the corresponding layer in ``big_model`` so the state_dict keys line up
    when the weights are copied over.
    """

    def __init__(self):
        super(small_model, self).__init__()
        self.linear1 = nn.Linear(3, 4, bias=False)

    def forward(self, x):
        """Apply the single linear layer to *x* (last dim must be 3)."""
        return self.linear1(x)
class big_model(nn.Module):
    """Toy model that extends ``small_model`` with one extra layer.

    ``linear1`` has the same name and shape as ``small_model.linear1`` so
    its weights can be pre-set from a saved small-model checkpoint;
    ``linear2`` exists only in the big model.
    """

    def __init__(self):
        super(big_model, self).__init__()
        self.linear1 = nn.Linear(3, 4, bias=False)
        self.linear2 = nn.Linear(4, 5, bias=False)

    def forward(self, x):
        """Run both layers in sequence: (…, 3) -> (…, 4) -> (…, 5)."""
        return self.linear2(self.linear1(x))
def print_params(model):
    """Print every named parameter of *model*, one ``name tensor`` per line.

    Args:
        model: any ``nn.Module``; its ``named_parameters()`` are iterated.
    """
    for name, param in model.named_parameters():
        print(name, param)
# --- Demo: copy the overlapping weights from the small model to the big one ---

# Create the small model and save its weights to disk.
small = small_model()
print('small model params')
print_params(small)
torch.save(small.state_dict(), 'small.pth')

# Create the big model; linear1 starts with its own random init.
big = big_model()
print('big model params before copying')
print_params(big)

# strict=False skips state_dict keys that do not line up instead of raising.
# The return value reports exactly which keys were skipped on each side.
result = big.load_state_dict(torch.load('small.pth'), strict=False)
print('missing keys (in model but not checkpoint):', result.missing_keys)
print('unexpected keys (in checkpoint but not model):', result.unexpected_keys)

# The shared layer must now hold the small model's weights.
assert torch.equal(big.linear1.weight, small.linear1.weight), 'params do not match after copying'
print('big model params after copying')
print_params(big)