Copy weights from only one layer of one model to another model (with a different structure)

Hello! I need to pretrain the embedding layer of a model in a self-supervised manner and then use this pretrained embedding layer in another model with a different structure. Both models have an embedding layer as their first layer. Is there a way to transfer the weights of this layer from the pretrained model EHR_Embedding() to my other model LSTM_model()?

Is it enough to just assign the weights in the following way (as suggested in this post: https://discuss.pytorch.org/t/copying-part-of-the-weights/14199?u=maslenkovas)?

model.embedding.weight.data = pretrained_model.embedding.weight.data

The classes of both models are below.

Thank you 🙂

import torch
import torch.nn as nn

class EHR_Embedding(nn.Module):
    def __init__(self, embedding_size, vocab_size=15463, drop=0.1):
        super(EHR_Embedding, self).__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=embedding_size, out_features=embedding_size)
        )
        self.drop = nn.Dropout(p=drop)
        
    def forward(self, tensor_demo, tensor_med, tensor_vitals, tensor_labs):   
        batch_size = tensor_med.size()[0]

        # first transformation
        emb_demo_X = self.drop(self.embedding(tensor_demo.squeeze(1)))
        emb_med_X = self.drop(self.embedding(tensor_med[:,:].squeeze(1)))
        emb_vitals_X = self.drop(self.embedding(tensor_vitals[:,:].squeeze(1)))
        emb_labs_X =  self.drop(self.embedding(tensor_labs[:,:].squeeze(1)))

        projection_demo_X = self.projection(emb_demo_X)
        projection_med_X = self.projection(emb_med_X)
        projection_vitals_X = self.projection(emb_vitals_X)
        projection_labs_X = self.projection(emb_labs_X)

        embedding_X = (emb_demo_X, emb_med_X, emb_vitals_X, emb_labs_X)
        projection_X = (projection_demo_X, projection_med_X, projection_vitals_X, projection_labs_X)

        # second transformation
        emb_demo_Y = self.drop(self.embedding(tensor_demo.squeeze(1)))
        emb_med_Y = self.drop(self.embedding(tensor_med[:,:].squeeze(1)))
        emb_vitals_Y = self.drop(self.embedding(tensor_vitals[:,:].squeeze(1)))
        emb_labs_Y =  self.drop(self.embedding(tensor_labs[:,:].squeeze(1)))

        projection_demo_Y = self.projection(emb_demo_Y)
        projection_med_Y = self.projection(emb_med_Y)
        projection_vitals_Y = self.projection(emb_vitals_Y)
        projection_labs_Y = self.projection(emb_labs_Y)

        embedding_Y = (emb_demo_Y, emb_med_Y, emb_vitals_Y, emb_labs_Y)
        projection_Y = (projection_demo_Y, projection_med_Y, projection_vitals_Y, projection_labs_Y)

        return embedding_X, projection_X, embedding_Y, projection_Y

class LSTM_model(nn.Module):

    def __init__(self, pretrained_model, H=128, max_length=max_length, max_day=7, vocab_size=vocab_size, embedding_size=200):
        super(LSTM_model, self).__init__()

        # Hyperparameters
        self.max_day = max_day
        L = (self.max_day+1) * (256 + 256 + 512) + 1280
        self.H = H
        self.max_length = max_length
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        self.embedding = pretrained_model
    

        self.fc_med = nn.Linear(max_length['medications'] * 2 * self.H, 256)  #65,280
        self.fc_vit = nn.Linear(max_length['vitals'] * 2 * self.H, 256)   #51,200
        self.fc_lab = nn.Linear(max_length['lab_tests'] * 2 * self.H, 512) #102,400

        self.lstm_day = nn.LSTM(input_size=embedding_size,
                            hidden_size=self.H,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)

        self.fc_1 = nn.Linear((self.max_day+1) * (256 + 256 + 512) + max_length['demographics']*2*H, 2048)

        self.lstm_adm = nn.LSTM(input_size=2048,
                            hidden_size=self.H,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=False)

        self.drop = nn.Dropout(p=0.5)

        self.fc_2 = nn.Linear(self.H, (self.max_day+1))
        
        # self.sigmoid = nn.Sigmoid()


    def forward(self, tensor_demo, tensor_med, tensor_vitals, tensor_labs):

        batch_size = tensor_med.size()[0]
        days = self.max_day + 1

        out_emb_med_demo = self.embedding(tensor_demo.squeeze(1))
        output_lstm_day_demo, _ = self.lstm_day(out_emb_med_demo)
        full_output = output_lstm_day_demo.reshape(batch_size, self.max_length['demographics']* 2 * self.H)


        for d in range(days):
            # embedding layer applied to all tensors
            out_emb_med = self.embedding(tensor_med[:, d, :].squeeze(1))
            out_emb_vitals = self.embedding(tensor_vitals[:, d, :].squeeze(1))
            out_emb_labs =  self.embedding(tensor_labs[:, d, :].squeeze(1))
            # lstm layer applied to embedded tensors
            output_lstm_day_med = self.fc_med(
                self.lstm_day(out_emb_med)[0]
                    .reshape(batch_size, self.max_length['medications'] * 2 * self.H))

            output_lstm_day_vitals = self.fc_vit(
                self.lstm_day(out_emb_vitals)[0]
                    .reshape(batch_size, self.max_length['vitals'] * 2 * self.H))

            output_lstm_day_labs = self.fc_lab(
                self.lstm_day(out_emb_labs)[0]
                    .reshape(batch_size, self.max_length['lab_tests'] * 2 * self.H))

            # concatenate the per-day outputs along the feature dimension
            full_output = torch.cat(
                (full_output, output_lstm_day_med, output_lstm_day_vitals, output_lstm_day_labs),
                dim=1)
        
        # print('full_output size: ', full_output.size())
        output = self.fc_1(full_output)
        output, _ = self.lstm_adm(output)
        output = self.drop(output)
        output = self.fc_2(output)
        output = torch.squeeze(output, 1)
        # if self.criterion == 'BCELoss':
        #     output = self.sigmoid(output)

        return output

It seems as if the self.embedding attribute isn’t the same layer in both models:

# A
self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)

# B
self.embedding = pretrained_model

so I would guess that assigning the .weight attribute directly from A to B won’t work, since in B self.embedding is the entire pretrained model rather than an nn.Embedding layer.
However, if you make sure self.embedding points to an nn.Embedding layer with the same dimensions in both models, you can copy the data via:

with torch.no_grad():
    modelA.embedding.weight.copy_(modelB.embedding.weight)
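
An equivalent option, assuming self.embedding is an nn.Embedding with identical dimensions in both models, is to load the layer’s state_dict, which will raise an error if the shapes don’t match:

# Alternative sketch: transfer the embedding parameters via state_dict.
# Assumes modelA.embedding and modelB.embedding are matching nn.Embedding layers.
modelA.embedding.load_state_dict(modelB.embedding.state_dict())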

You’re right! I changed the architecture, so now the embedding layers are the same in both models.

Is there any difference between

model.embedding.weight.data = pretrained_model.embedding.weight.data

and your solution? If so, which one should I use?

class EHR_Embedding(nn.Module):
    def __init__(self, embedding_size, vocab_size=15463, drop=0.1):
        super(EHR_Embedding, self).__init__()
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        
        self.projection = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=embedding_size, out_features=embedding_size)
        )
        self.drop = nn.Dropout(p=drop)
        
    def forward(self, tensor_demo, tensor_med, tensor_vitals, tensor_labs):   
        batch_size = tensor_med.size()[0]

        # first transformation
        emb_demo_X = self.drop(self.embedding(tensor_demo.squeeze(1)))
        emb_med_X = self.drop(self.embedding(tensor_med[:,:].squeeze(1)))
        emb_vitals_X = self.drop(self.embedding(tensor_vitals[:,:].squeeze(1)))
        emb_labs_X =  self.drop(self.embedding(tensor_labs[:,:].squeeze(1)))

        projection_demo_X = self.projection(emb_demo_X)
        projection_med_X = self.projection(emb_med_X)
        projection_vitals_X = self.projection(emb_vitals_X)
        projection_labs_X = self.projection(emb_labs_X)

        embedding_X = (emb_demo_X, emb_med_X, emb_vitals_X, emb_labs_X)
        projection_X = (projection_demo_X, projection_med_X, projection_vitals_X, projection_labs_X)

        # second transformation
        emb_demo_Y = self.drop(self.embedding(tensor_demo.squeeze(1)))
        emb_med_Y = self.drop(self.embedding(tensor_med[:,:].squeeze(1)))
        emb_vitals_Y = self.drop(self.embedding(tensor_vitals[:,:].squeeze(1)))
        emb_labs_Y =  self.drop(self.embedding(tensor_labs[:,:].squeeze(1)))

        projection_demo_Y = self.projection(emb_demo_Y)
        projection_med_Y = self.projection(emb_med_Y)
        projection_vitals_Y = self.projection(emb_vitals_Y)
        projection_labs_Y = self.projection(emb_labs_Y)

        embedding_Y = (emb_demo_Y, emb_med_Y, emb_vitals_Y, emb_labs_Y)
        projection_Y = (projection_demo_Y, projection_med_Y, projection_vitals_Y, projection_labs_Y)

        return embedding_X, projection_X, embedding_Y, projection_Y

class LSTM_model(nn.Module):

    def __init__(self, max_length, pred_window, vocab_size, H=128,  max_day=7, embedding_size=200):
        super(LSTM_model, self).__init__()

        # Hyperparameters
        self.max_day = max_day
        self.pred_window = pred_window
        L = (self.max_day+1) * (256 + 256 + 512) + 1280
        self.H = H
        self.max_length = max_length
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size

        # self.embedding = pretrained_model
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
    

        self.fc_med = nn.Linear(max_length['medications'] * 2 * self.H, 256)  #65,280
        self.fc_vit = nn.Linear(max_length['vitals'] * 2 * self.H, 256)   #51,200
        self.fc_lab = nn.Linear(max_length['lab_tests'] * 2 * self.H, 512) #102,400

        self.lstm_day = nn.LSTM(input_size=embedding_size,
                            hidden_size=self.H,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)

        self.fc_1 = nn.Linear((self.max_day - self.pred_window) * (256 + 256 + 512) + max_length['demographics']*2*H, 2048)

        self.lstm_adm = nn.LSTM(input_size=2048,
                            hidden_size=self.H,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=False)

        self.drop = nn.Dropout(p=0.5)

        self.fc_2 = nn.Linear(self.H, (self.max_day - self.pred_window))
        
        # self.sigmoid = nn.Sigmoid()


    def forward(self, tensor_demo, tensor_med, tensor_vitals, tensor_labs):

        batch_size = tensor_med.size()[0]
        days = self.max_day

        out_emb_med_demo = self.embedding(tensor_demo.squeeze(1))
        output_lstm_day_demo, _ = self.lstm_day(out_emb_med_demo)
        full_output = output_lstm_day_demo.reshape(batch_size, self.max_length['demographics']* 2 * self.H)


        for d in range(days - self.pred_window):
            # embedding layer applied to all tensors
            out_emb_med = self.embedding(tensor_med[:, d, :].squeeze(1))
            out_emb_vitals = self.embedding(tensor_vitals[:, d, :].squeeze(1))
            out_emb_labs =  self.embedding(tensor_labs[:, d, :].squeeze(1))
            # lstm layer applied to embedded tensors
            output_lstm_day_med = self.fc_med(
                self.lstm_day(out_emb_med)[0]
                    .reshape(batch_size, self.max_length['medications'] * 2 * self.H))

            output_lstm_day_vitals = self.fc_vit(
                self.lstm_day(out_emb_vitals)[0]
                    .reshape(batch_size, self.max_length['vitals'] * 2 * self.H))

            output_lstm_day_labs = self.fc_lab(
                self.lstm_day(out_emb_labs)[0]
                    .reshape(batch_size, self.max_length['lab_tests'] * 2 * self.H))

            # concatenate the per-day outputs along the feature dimension
            full_output = torch.cat(
                (full_output, output_lstm_day_med, output_lstm_day_vitals, output_lstm_day_labs),
                dim=1)
        
        # print('full_output size: ', full_output.size())
        output = self.fc_1(full_output)
        output, _ = self.lstm_adm(output)
        output = self.drop(output)
        output = self.fc_2(output)
        output = torch.squeeze(output, 1)
        # if self.criterion == 'BCELoss':
        #     output = self.sigmoid(output)

        return output

Don’t use the deprecated .data attribute, as it can yield unwanted side effects; stick to my posted approach.
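
Putting it together, a minimal sketch (assuming pretrained_model is the trained EHR_Embedding and model is the updated LSTM_model, both built with the same vocab_size and embedding_size) could look like this; copying under torch.no_grad() keeps autograd’s bookkeeping consistent, whereas rebinding .data bypasses it:

import torch

# Minimal sketch: transfer the pretrained embedding weights into the new model.
# Assumes pretrained_model (EHR_Embedding) and model (LSTM_model) already exist
# with matching vocab_size and embedding_size.
with torch.no_grad():
    model.embedding.weight.copy_(pretrained_model.embedding.weight)

# Optional: freeze the transferred layer so it is not updated during fine-tuning.
model.embedding.weight.requires_grad_(False)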
