How to add new nodes at the last layer (fully connected layer)?

Hello,

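Because the weight shape of new_last_layer will not match the weight shape of old_last_layer, you need to do part of the work manually. An `nn.Linear` weight has shape `(out_features, in_features)`, so each new output node is a new row (dim 0). Pad the old_last_layer weight with extra rows until it has the same shape as the new_last_layer weight; the extra rows can be random values, zeros, or ones. Then call `load_state_dict` as usual. If the layer has a bias, pad the bias vector the same way.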
I wrote a simple demo for you

class LinearNet(nn.Module):
    """Original ("old") two-layer, bias-free linear network.

    Args:
        in_features: Size of each input sample; also the width of the hidden
            ``linear`` layer. Defaults to 5.
        out_features: Number of output nodes of ``last_linear_layer``.
            Defaults to 10.

    The attribute names ``linear`` and ``last_linear_layer`` determine the
    ``state_dict`` keys (``linear.weight`` and ``last_linear_layer.weight``).
    Any code that edits checkpoints by key depends on them, so keep them
    unchanged.
    """

    def __init__(self, in_features=5, out_features=10):
        super().__init__()
        self.linear = nn.Linear(in_features, in_features, bias=False)
        # nn.Linear stores its weight as (out_features, in_features), so
        # adding output nodes later means appending rows to this matrix.
        self.last_linear_layer = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Map ``(..., in_features)`` inputs to ``(..., out_features)`` outputs."""
        return self.last_linear_layer(self.linear(x))

class New_LinearNet(nn.Module):
    """Widened ("new") two-layer, bias-free linear network.

    It has the same structure and attribute names as the original model, but
    its last layer has more output nodes. Its ``state_dict`` keys are
    ``linear.weight`` and ``last_linear_layer.weight``. A checkpoint from the
    smaller model therefore loads once its last-layer weight has been padded
    with extra rows.

    Args:
        in_features: Size of each input sample; also the width of the hidden
            ``linear`` layer. Defaults to 5.
        out_features: Number of output nodes of ``last_linear_layer``.
            Defaults to 15.
    """

    def __init__(self, in_features=5, out_features=15):
        super().__init__()
        self.linear = nn.Linear(in_features, in_features, bias=False)
        # Weight shape is (out_features, in_features): rows are output nodes.
        self.last_linear_layer = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Map ``(..., in_features)`` inputs to ``(..., out_features)`` outputs."""
        return self.last_linear_layer(self.linear(x))

# --- 1. Train the original model briefly and save its checkpoint. ---
old_model = LinearNet()
random_input = torch.randn(5, 5)  # batch of 5 samples, 5 features each
# The target must match the model output shape (batch, out_features) = (5, 10).
# A (10,) target would be silently broadcast by MSELoss (with a UserWarning),
# training every sample against the same row.
random_target = torch.randn(5, 10)
criterion = nn.MSELoss()
opt = torch.optim.SGD(old_model.parameters(), lr=0.001)
for _ in range(5):
    opt.zero_grad()
    output = old_model(random_input)
    loss = criterion(output, random_target)
    loss.backward()
    opt.step()

torch.save(old_model.state_dict(), 'old_model.pth')
ckpt = torch.load('old_model.pth')

# --- 2. Pad the last-layer weight so it fits the wider model. ---
# Build the new model first and reuse its freshly initialised extra rows for
# the new output nodes. They then follow nn.Linear's default init scale
# instead of N(0, 1) noise, and no row count has to be hard-coded.
new_model = New_LinearNet()
old_weight = ckpt['last_linear_layer.weight']  # (old_out_features, in_features)
new_part = new_model.last_linear_layer.weight.detach()[old_weight.size(0):]
ckpt['last_linear_layer.weight'] = torch.cat([old_weight, new_part], dim=0)
new_model.load_state_dict(ckpt)

# --- 3. Continue training the widened model. ---
opt_new = torch.optim.SGD(new_model.parameters(), lr=0.001)
random_target_new = torch.randn(5, 15)  # matches new output shape (5, 15)
for _ in range(5):
    opt_new.zero_grad()
    output = new_model(random_input)
    loss = criterion(output, random_target_new)
    loss.backward()
    opt_new.step()

I am not sure whether this snippet meets your needs; if it works, please let me know~

1 Like