Hello,
Because the weight shape of new_last_layer will not match the weight shape of old_last_layer, I think you have to do part of this manually: pad the old last-layer weight to the same shape as the new last-layer weight (with random values, zeros, or ones), and then call load_state_dict as usual.
I wrote a simple demo for you:

```python
import torch
import torch.nn as nn


class LinearNet(nn.Module):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(5, 5, bias=False)
        self.last_linear_layer = nn.Linear(5, 10, bias=False)

    def forward(self, x):
        return self.last_linear_layer(self.linear(x))


class New_LinearNet(nn.Module):
    def __init__(self):
        super(New_LinearNet, self).__init__()
        self.linear = nn.Linear(5, 5, bias=False)
        self.last_linear_layer = nn.Linear(5, 15, bias=False)  # 15 outputs instead of 10

    def forward(self, x):
        return self.last_linear_layer(self.linear(x))
```
Train the old model for a few steps and save its checkpoint:

```python
old_model = LinearNet()
random_input = torch.randn(5, 5)
random_target = torch.randn(10,)
criterion = nn.MSELoss()
opt = torch.optim.SGD(old_model.parameters(), lr=0.001)

for i in range(5):
    opt.zero_grad()
    output = old_model(random_input)
    loss = criterion(output, random_target)
    loss.backward()
    opt.step()

torch.save(old_model.state_dict(), 'old_model.pth')
```
Then pad the saved last-layer weight to the new shape and load it into the new model:

```python
ckpt = torch.load('old_model.pth')

# 5 extra output rows to grow the weight from (10, 5) to (15, 5);
# use torch.zeros(5, 5) or torch.ones(5, 5) if you prefer another init.
new_part = torch.randn(5, 5)
ckpt['last_linear_layer.weight'] = torch.cat([ckpt['last_linear_layer.weight'], new_part], dim=0)

new_model = New_LinearNet()
new_model.load_state_dict(ckpt)
```
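If you want to verify that the trained weights survived the surgery, a quick sanity check (using the names from the demo above) is:

```python
# The first 10 output rows of the new layer should be exactly
# the old model's trained weights.
assert torch.equal(new_model.last_linear_layer.weight[:10],
                   old_model.last_linear_layer.weight)
```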
Then fine-tune the new model as usual:

```python
opt_new = torch.optim.SGD(new_model.parameters(), lr=0.001)
random_target_new = torch.randn(15,)

for i in range(5):
    opt_new.zero_grad()
    output = new_model(random_input)
    loss = criterion(output, random_target_new)
    loss.backward()
    opt_new.step()
```
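By the way, if you would rather not edit the checkpoint dict at all, a sketch of the same idea is to copy the trained rows directly into the freshly initialized new model in-place:

```python
# Alternative: copy the old trained weights into the new model,
# leaving the 5 extra output rows at their default initialization.
with torch.no_grad():
    new_model.linear.weight.copy_(old_model.linear.weight)
    new_model.last_linear_layer.weight[:10].copy_(old_model.last_linear_layer.weight)
```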
I am not sure whether this snippet meets your needs, so please let me know if it works for you~