Create a new model from some layers of an already pre-trained model

import torch
import torch.nn as nn
import torch.nn.functional as F

class Autoencoder(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        #self.linear = nn.Linear(input_size, output_size)
        self.layer1 = nn.Linear(input_size, 10)
        self.layer2 = nn.Linear(10, 6)
        self.layer3 = nn.Linear(6, output_size)
        self.act = nn.ReLU()

    def forward(self, xb):
        #out = self.linear(xb)
        xb = F.relu(self.layer1(xb))
        xb = F.relu(self.layer2(xb))
        out = torch.sigmoid(self.layer3(xb))
        #out = self.layer3(xb)
        return out

Can I now create a new torch model that uses the learned weights of only these two layers of the model above?
self.layer1 = nn.Linear(input_size, 10)
self.layer2 = nn.Linear(10,6)

Hello Anuj, check this link out.

As an illustration, suppose this is the other model you want to create (it has its own layer1 and layer2).

class OtherModel(nn.Module):
    def __init__(self, input_size, output_size):
        super(OtherModel, self).__init__()
        self.layer_new_0 = nn.Linear(input_size, input_size)
        self.layer1 = nn.Linear(input_size, 10)
        self.layer2 = nn.Linear(10,6)
        self.layer_new_1 = nn.Linear(6, output_size)
        self.act = nn.ReLU()

    def forward(self, xb):
        xb = F.relu(self.layer_new_0(xb))
        xb = F.relu(self.layer1(xb))
        xb = F.relu(self.layer2(xb))
        out = F.sigmoid(F.relu(self.layer_new_1(xb)))
        return out

Instantiate both the old model and the new one.

model = Autoencoder(10, 3)
model_new = OtherModel(10, 10)

Let’s have a look at the old model’s state (parameters) by using the state_dict method.

print(model.state_dict().keys())
odict_keys(['layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias', 'layer3.weight', 'layer3.bias'])
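Each key has the form "<attribute name>.<parameter name>", and an entry can only be loaded into a layer with the same shape. A minimal sketch that lists every key with its shape, assuming the Autoencoder(input_size, output_size) signature implied by the Autoencoder(10, 3) call above (shapes is just a hypothetical helper name):

```python
import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    # Assumed signature, matching the Autoencoder(10, 3) call in this thread
    def __init__(self, input_size, output_size):
        super().__init__()
        self.layer1 = nn.Linear(input_size, 10)
        self.layer2 = nn.Linear(10, 6)
        self.layer3 = nn.Linear(6, output_size)

    def forward(self, xb):
        xb = torch.relu(self.layer1(xb))
        xb = torch.relu(self.layer2(xb))
        return torch.sigmoid(self.layer3(xb))


model = Autoencoder(10, 3)

# Each key is "<submodule attribute>.<parameter>"; nn.Linear stores weight as (out, in)
shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
for name, shape in shapes.items():
    print(name, shape)
```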

The parameters we care about are the ones whose names start with "layer1." or "layer2." (the trailing dot keeps a name such as "layer10.weight" from matching).

First, we confirm that the new model’s layer1.weight parameter does not yet have the same values as the old model’s layer1.weight parameter.

print(torch.allclose(model.layer1.weight, model_new.layer1.weight))
False

Let’s use the load_state_dict method to set the new model’s relevant parameters equal to the old model’s. Here I’m using a dictionary comprehension, but there are many other ways to do it.

state_dict = {k: model.state_dict()[k] for k in filter(lambda x: x.startswith("layer1.") or x.startswith("layer2."), model.state_dict())}
model_new.load_state_dict(state_dict, strict=False)
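With strict=False, load_state_dict does not raise on mismatched keys; it returns a named tuple with missing_keys and unexpected_keys, which is a cheap way to check that only the intended layers were filled in. A minimal sketch with the forward methods omitted for brevity (prefixes and result are hypothetical names):

```python
import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    # forward omitted: only the parameters matter for this check
    def __init__(self, input_size, output_size):
        super().__init__()
        self.layer1 = nn.Linear(input_size, 10)
        self.layer2 = nn.Linear(10, 6)
        self.layer3 = nn.Linear(6, output_size)


class OtherModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.layer_new_0 = nn.Linear(input_size, input_size)
        self.layer1 = nn.Linear(input_size, 10)
        self.layer2 = nn.Linear(10, 6)
        self.layer_new_1 = nn.Linear(6, output_size)


model = Autoencoder(10, 3)
model_new = OtherModel(10, 10)

# Keep only the entries whose names start with "layer1." or "layer2."
prefixes = ("layer1.", "layer2.")
state_dict = {k: v for k, v in model.state_dict().items() if k.startswith(prefixes)}

# With strict=False, load_state_dict reports the keys it could not match
result = model_new.load_state_dict(state_dict, strict=False)
print("missing:", result.missing_keys)        # new layers, left at their random init
print("unexpected:", result.unexpected_keys)  # should be empty
```

If unexpected_keys is not empty, a name in the filtered dictionary does not exist in the new model, which usually means the attribute names differ between the two classes.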

Now you can confirm that the layer1.weight parameter has indeed been copied over.

print(torch.allclose(model.layer1.weight, model_new.layer1.weight))
True
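One property worth knowing: load_state_dict copies the values into the new model’s own tensors rather than sharing them, so training model_new afterwards leaves model untouched. A small sketch, using two standalone nn.Linear layers (old and new) as hypothetical stand-ins for model.layer1 and model_new.layer1:

```python
import torch
import torch.nn as nn

old = nn.Linear(10, 10)  # stands in for model.layer1
new = nn.Linear(10, 10)  # stands in for model_new.layer1

new.load_state_dict(old.state_dict())
same_values = torch.equal(old.weight, new.weight)
same_storage = old.weight.data_ptr() == new.weight.data_ptr()

# Modify the new layer in place and check that the old one is unaffected
snapshot = old.weight.detach().clone()
with torch.no_grad():
    new.weight.add_(1.0)
old_unchanged = torch.equal(old.weight, snapshot)

print(same_values, same_storage, old_unchanged)
```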

Hi Andrei_Cristea, thank you very much for taking the time to answer my question.
This solution looks very good.
In the meantime, however, I was trying to figure it out myself and got it working with the code below:

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(29, 14),
            nn.Tanh(),
            nn.Linear(14, 7),
            nn.LeakyReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(7, 7),
            nn.Tanh(),
            nn.Linear(7, 29),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = Autoencoder()

Another model, which will use the trained weights of the first (Autoencoder) model:

class Latent(nn.Module):
    def __init__(self):
        super().__init__()
        #self.linear = nn.Linear(input_size, output_size)
        self.layer1 = nn.Linear(29, 14)
        self.layer1.weight=nn.Parameter(model.encoder[0].weight.data)
        self.layer2 = nn.Linear(14,7)
        self.layer2.weight=nn.Parameter(model.encoder[2].weight.data)
        
    def forward(self, xb):
        #out = self.linear(xb)
        xb = F.tanh(self.layer1(xb))
        out = F.relu(self.layer2(xb))
        return out

hidden = Latent()
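An alternative that is not used in this thread, sketched here under the assumption that the latent model should be exactly the trained encoder: copy.deepcopy(model.encoder) produces an independent module with both weights and biases copied. The Autoencoder below reproduces the one above; in practice model would already be trained.

```python
import copy

import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(29, 14), nn.Tanh(), nn.Linear(14, 7), nn.LeakyReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(7, 7), nn.Tanh(), nn.Linear(7, 29), nn.LeakyReLU()
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


model = Autoencoder()  # in practice, the trained autoencoder

# deepcopy duplicates every parameter (weights and biases) into new memory
hidden = copy.deepcopy(model.encoder)

x = torch.randn(4, 29)
latent = hidden(x)  # shape (4, 7)
```

Unlike the Latent class, this keeps the encoder’s original activations (Tanh, then LeakyReLU), whereas Latent ends with ReLU.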

Could you please take a look and tell me whether this approach is right?

Hi Anuj,

This could definitely work, but I just wanted to point out a few things.

First, note that you matched the weights but not the biases, which might cause an issue if you’re not aware of it:

print(
    torch.allclose(model.encoder[0].weight, hidden.layer1.weight),
    torch.allclose(model.encoder[0].bias,   hidden.layer1.bias)
)
Output:
True False
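One way to transfer the weight and bias together is to load the source layer’s state_dict into the target layer, since an nn.Linear state_dict holds both. A minimal sketch, where encoder and layer1 are hypothetical stand-ins for model.encoder and hidden.layer1:

```python
import torch
import torch.nn as nn

# Stand-ins: encoder mirrors Autoencoder.encoder, layer1 mirrors Latent.layer1
encoder = nn.Sequential(nn.Linear(29, 14), nn.Tanh(), nn.Linear(14, 7), nn.LeakyReLU())
layer1 = nn.Linear(29, 14)

# Copies both "weight" and "bias" by value in one call
layer1.load_state_dict(encoder[0].state_dict())

print(
    torch.allclose(encoder[0].weight, layer1.weight),
    torch.allclose(encoder[0].bias, layer1.bias),
)  # True True
```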

Second, I think the two models now share the same weight tensors in those layers, which means that updating the values in one place will change them in the other. This might be fine if it’s what you intended, but it’s important to be aware that it’s happening.

To illustrate, suppose we make a copy of the shared tensor and confirm that the copy preserves the values:

weight_initial = model.encoder[0].weight.detach().clone()
print(torch.allclose(model.encoder[0].weight, weight_initial))
Output:
True

Then we train the “Latent” model for a few steps:

from torch import optim
criterion = nn.MSELoss()
optimizer = optim.SGD(hidden.parameters(), lr=1e-3)
for _ in range(10):
    optimizer.zero_grad()
    input = torch.randn(1, 29)
    target = torch.randn(1, 7)
    output = hidden(input)  # explicitly updating "latent" model only, not "autoencoder"
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

Even though we never explicitly manipulated the Autoencoder, you can see its weights have changed, since they were shared with the Latent model:

print(torch.allclose(model.encoder[0].weight, weight_initial))  # "autoencoder" weights have changed
Output:
False
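A quicker check that needs no training is to compare data_ptr(), which returns the address of a tensor’s first element. A minimal sketch, with source standing in for model.encoder[0]; all names here are hypothetical:

```python
import torch
import torch.nn as nn

source = nn.Linear(29, 14)  # stands in for model.encoder[0]

# The pattern used in Latent: a new Parameter wrapping the same storage
shared = nn.Parameter(source.weight.data)

# A detached clone gets its own memory
independent = nn.Parameter(source.weight.detach().clone())

# Equal data_ptr() values mean both tensors point at the same memory
shares_memory = shared.data_ptr() == source.weight.data_ptr()
copy_shares_memory = independent.data_ptr() == source.weight.data_ptr()
print(shares_memory, copy_shares_memory)  # True False
```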

Relatedly, keep in mind that when you train the Latent model, all of its parameters will be updated, including the ones you transferred over from the Autoencoder. If you don’t want this, and you only want to train some parameters (the new ones) but not others (the transferred ones), you can “freeze” the desired parameters, as described here.
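A minimal sketch of freezing, assuming a hypothetical nn.Sequential stand-in whose first layer plays the role of the transferred layer: requires_grad_(False) stops gradients for that layer, and only the parameters that still require gradients are handed to the optimizer.

```python
import torch
import torch.nn as nn
from torch import optim

torch.manual_seed(0)

# Hypothetical stand-in: layer 0 is "transferred", layer 2 is "new"
hidden = nn.Sequential(nn.Linear(29, 14), nn.Tanh(), nn.Linear(14, 7))
transferred = hidden[0]

# Freeze the transferred layer so backward() computes no gradients for it
transferred.requires_grad_(False)

# Pass only the trainable parameters to the optimizer
optimizer = optim.SGD([p for p in hidden.parameters() if p.requires_grad], lr=1e-3)

frozen_before = transferred.weight.detach().clone()
trainable_before = hidden[2].weight.detach().clone()

criterion = nn.MSELoss()
for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(hidden(torch.randn(1, 29)), torch.randn(1, 7))
    loss.backward()
    optimizer.step()
```

After the loop, the transferred layer still holds its original values while the new layer has been updated.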

Good luck!
Andrei


I also want to point out that you shouldn’t use the .data attribute, as it’s deprecated and can lead to undefined behavior. Try something like this instead:

with torch.no_grad():  # an in-place copy into a parameter must not be tracked by autograd
    self.layer1.weight.copy_(model.encoder[0].weight)
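Putting these points together, here is a sketch of the Latent model that copies weights and biases by value under torch.no_grad(), which is needed because copy_ is an in-place operation on a leaf tensor that requires grad. The trained constructor argument is a hypothetical change so the class does not read a global model; the resulting layers are independent of the Autoencoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(29, 14), nn.Tanh(), nn.Linear(14, 7), nn.LeakyReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(7, 7), nn.Tanh(), nn.Linear(7, 29), nn.LeakyReLU()
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


class Latent(nn.Module):
    # Hypothetical: receives the trained model instead of using a global
    def __init__(self, trained):
        super().__init__()
        self.layer1 = nn.Linear(29, 14)
        self.layer2 = nn.Linear(14, 7)
        # Copy weights AND biases by value; no_grad allows in-place copy_ on parameters
        with torch.no_grad():
            pairs = ((self.layer1, trained.encoder[0]), (self.layer2, trained.encoder[2]))
            for dst, src in pairs:
                dst.weight.copy_(src.weight)
                dst.bias.copy_(src.bias)

    def forward(self, xb):
        xb = torch.tanh(self.layer1(xb))
        return F.relu(self.layer2(xb))


model = Autoencoder()
hidden = Latent(model)
```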

Thanks @Andrei_Cristea for such an insightful response!

Thanks @AlphaBetaGamma96 for highlighting the issue.