Combining a loaded trained model with a new model

I trained a classifier and saved the model using:

torch.save(model, "/home/zaianir/Documents/code/tuto/classif/MNIST_model.pth")

I’m trying to train a new classifier on top of the pretrained saved model without its last layer.
I want to train only the parameters of the newly added layers (I don’t want to update the pretrained parameters).

Here is some of my code:

class VAE2(nn.Module):
    def __init__(self):
        super(VAE2, self).__init__()
        self.fc3=nn.Linear(50,20)
        self.fc4=nn.Linear(20,10)
        
    def forward(self, x):    
        x=F.relu(self.fc3(x))
        x=F.relu(self.fc4(x))
        x=F.log_softmax(x)
        return x
    
    
class VAE(nn.Module):
    def __init__(self, VAE1, VAE2):
        super(VAE, self).__init__()
        self.VAE1=VAE1
        self.VAE2=VAE2
        
    def forward(self):
        x=self.VAE1
        x=self.VAE2(x)
        return x

def loss_function(inp, target):
    l=F.nll_loss(inp, target)
    return l


def train(train_dl, model, epoch_nb,lr1):   
    optimizer=torch.optim.Adam(model.parameters(), lr=lr1)
    train_loss1=[]
    for epoch in range(1,epoch_nb):
            model.train()
            train_loss=0.0
            for idx, (data, label) in enumerate(train_dl):
                data, label= data.to(device), label.to(device)
                out=model(data)
                loss=loss_function(out,data)
                train_loss+=loss.item()
                model.zero_grad()
                loss.backward()
                optimizer.step()     
            av_loss= train_loss / len(train_dl.dataset)   
            print('Epoch: {} Average loss: {:.4f}'.format(epoch, av_loss))


train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,    
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=128, shuffle=True, **kwargs)
VAE1 = torch.load("/home/zaianir/Documents/code/tuto/classif/MNIST_model.pth")
VAE2_model=VAE2().to(device)
model=VAE(VAE1, VAE2_model)
epoch=30
learning_rate=0.001
train(train_loader, model, epoch, learning_rate)

Here are the different layers of my loaded model (VAE1):
LAYERS

I’m new to PyTorch and don’t know how to proceed. Is my approach correct?
Thank you for your help.

There are some issues in your code:

  • in VAE.forward you are not passing x to self.VAE1, so x will be the submodule itself, not its output
  • I would suggest saving and loading the state_dict instead of the complete model, as described in the Serialization docs
  • if you would like to remove the last layer from VAE1, you could e.g. replace it with an nn.Identity layer
  • to freeze the base model’s parameters, set their requires_grad flag to False, as described in the Finetuning tutorial (see the sketch below)
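
Putting these points together, here is a minimal sketch of what that could look like. Note that VAE1Net is a placeholder for the class of the classifier you originally trained, fc_out is a placeholder for the name of its last layer, and device is assumed to be defined as in your script:

import torch
import torch.nn as nn

# Recreate the pretrained classifier and load its weights
# (assuming you re-save the checkpoint as a state_dict as suggested above).
vae1 = VAE1Net()
vae1.load_state_dict(torch.load("MNIST_model_state_dict.pth"))

# Replace the last layer with an identity so the features that fed it
# (50-dim, judging by your VAE2.fc3 input size) are passed through unchanged.
vae1.fc_out = nn.Identity()

# Freeze the pretrained parameters so only the new layers get updated.
for param in vae1.parameters():
    param.requires_grad = False

class VAE(nn.Module):
    def __init__(self, vae1, vae2):
        super(VAE, self).__init__()
        self.VAE1 = vae1
        self.VAE2 = vae2

    def forward(self, x):
        x = self.VAE1(x)   # call the submodule on x instead of assigning the module itself
        x = self.VAE2(x)
        return x

model = VAE(vae1, VAE2().to(device)).to(device)

# Pass only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)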

I have trained one T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True) model for a summarization task and another for a QA task.

I am trying to load the models so I can use their trained encoders and decoders for a different task. The summarizer_model and the qa_model are loaded from model checkpoints, but this throws an error… When I use T5ForConditionalGeneration.from_pretrained() instead, the training starts… Is there something wrong with the way I’m using model checkpoints to train further on?

As far as I can tell, the issue is that I need return_dict=True for the model I’ve loaded from the checkpoint.

Is there a way to set return_dict=True for a saved model checkpoint?

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from collections import OrderedDict

alpha = 0.5

class CrossAttentionSummarizer(pl.LightningModule):
    def __init__(self):
        super(CrossAttentionSummarizer, self).__init__()
        self.summarizer_model = summarizer_model  # loaded from a checkpoint
        self.qa_encoder = qa_model                # loaded from a checkpoint
        self.multihead_attn = nn.MultiheadAttention(embed_dim=768, num_heads=4, batch_first=True)
        self.linear1 = nn.Linear(1024*768, 512)
        self.linear2 = nn.Linear(512, 2, bias=False)
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, question_passage_input_ids, question_passage_attention_mask, question_labels, input_ids, attention_mask, decoder_attention_mask, labels=None):
        summarizer_output = self.summarizer_model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )

        qa_output = self.qa_encoder(
            question_passage_input_ids,
            question_passage_attention_mask,
            question_labels
        )

        decoder_output = summarizer_output[3]
        encoder_output = qa_output[2]

        multi_attn_output, multi_attn_output_weights = self.multihead_attn(decoder_output, encoder_output, encoder_output)
        lin_output = self.linear1(multi_attn_output.reshape(-1, 1024*768))
        cls_outputs = self.linear2(lin_output)
        cls_outputs = nn.functional.softmax(cls_outputs, dim=1)
        cls_preds = torch.argmax(cls_outputs, dim=1)
        cls_pred_loss = self.ce_loss(cls_outputs, question_labels.type(torch.int64).squeeze(dim=1))
        return summarizer_output.loss, summarizer_output.logits, cls_pred_loss, cls_preds

I got it!!

You have to save the model using save_pretrained() in order to be able to load it with the return_dict=True option.

:slight_smile:
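
For reference, a minimal sketch of the round trip (the checkpoint directory name is just an example placeholder):

from transformers import T5ForConditionalGeneration

# Train as usual, then save with save_pretrained() instead of a raw checkpoint.
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
# ... fine-tune the model ...
model.save_pretrained("t5-summarizer-checkpoint")

# from_pretrained() can then reload that directory and accepts return_dict=True.
summarizer_model = T5ForConditionalGeneration.from_pretrained(
    "t5-summarizer-checkpoint", return_dict=True)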