Cutting an existing model and reusing the pre-trained weights after loading them from the bigger model

I want to reuse the decoder part of an already existing auto-encoder, but what I am doing is probably not right… Could someone clarify whether this is okay or not? I am not able to get a summary of the extracted model by calling torchsummary.


# download the pre-trained autoencoder checkpoint
! wget https://github.com/Jimut123/simply_junk/raw/main/models/autoencoder.pt

import pickle
import random

import cv2
import matplotlib.pyplot as plt  # plotting library
import numpy as np               # numerical arrays
import pandas as pd

import torch
import torch.distributions
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils
import torchvision
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
from torchvision import datasets, transforms


use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))
device = torch.device("cuda" if use_cuda else "cpu")
print("Device to be used : ",device)


class EncoderV5(nn.Module):
    def __init__(self):
        # define all the layers that the model uses
        super(EncoderV5, self).__init__()
        self.conv1 = nn.Conv2d(3, 256, 3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(256, 128, 3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.batch_norm4 = nn.BatchNorm2d(64)
        #self.flatten = nn.Flatten()

    def forward(self, x):
        # input: (N, 3, 128, 128)
        x = F.relu(self.conv1(x))
        x = self.batch_norm1(x)
        x = self.pool(x)
        x = self.pool(x)   # (N, 256, 32, 32)

        x = F.relu(self.conv2(x))
        x = self.batch_norm2(x)
        x = self.pool(x)
        x = self.pool(x)   # (N, 128, 8, 8)

        x = F.relu(self.conv3(x))
        x = self.batch_norm3(x)
        x = self.pool(x)
        x = self.pool(x)   # (N, 64, 2, 2)

        x = F.relu(self.conv4(x))
        x = self.batch_norm4(x)
        x = self.pool(x)   # (N, 64, 1, 1)
        return F.softmax(x, dim=1)


class DecoderV5(nn.Module):
    def __init__(self):
        super(DecoderV5, self).__init__()
        self.t_conv1 = nn.ConvTranspose2d(64, 64, 4, stride=4)
        self.batch_norm1 = nn.BatchNorm2d(64)
        
        self.t_conv3 = nn.ConvTranspose2d(64, 128, 4, stride=4)
        self.batch_norm2 = nn.BatchNorm2d(128)

        self.t_conv5 = nn.ConvTranspose2d(128, 256, 4, stride=4)
        self.batch_norm3 = nn.BatchNorm2d(256)
        
        self.t_conv7 = nn.ConvTranspose2d(256, 3, 2, stride=2)
      

    def forward(self, x):
        # input: (N, 64, 1, 1)
        x = F.relu(self.t_conv1(x))
        x = self.batch_norm1(x)   # (N, 64, 4, 4)

        x = F.relu(self.t_conv3(x))
        x = self.batch_norm2(x)   # (N, 128, 16, 16)

        x = F.relu(self.t_conv5(x))
        x = self.batch_norm3(x)   # (N, 256, 64, 64)

        x = self.t_conv7(x)       # (N, 3, 128, 128)
        return F.softmax(x, dim=1)


class AutoencoderV5(nn.Module):
    def __init__(self):
        super(AutoencoderV5, self).__init__()
        self.encoder = EncoderV5()
        self.decoder = DecoderV5()

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
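
To make the shapes concrete: for 128x128 RGB inputs, the encoder squeezes each image down to a (64, 1, 1) latent, which the decoder expands back to (3, 128, 128). A quick sanity check with a dummy batch (a minimal sketch; the shapes follow from the layer definitions above):


# sanity check: trace the latent and reconstruction shapes with a dummy batch
with torch.no_grad():
    ae_check = AutoencoderV5().eval()
    dummy = torch.randn(1, 3, 128, 128)
    z = ae_check.encoder(dummy)
    out = ae_check.decoder(z)
    print(z.shape)    # torch.Size([1, 64, 1, 1])
    print(out.shape)  # torch.Size([1, 3, 128, 128])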


# load a PyTorch checkpoint dict holding 'state_dict', 'optimizer' and 'epoch'
def load_ckp(checkpoint_path, model, model_opt):
    # map_location makes a GPU-trained checkpoint load on CPU-only machines too
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model_opt.load_state_dict(checkpoint['optimizer'])
    return model, model_opt, checkpoint['epoch']


checkpoint_path = '/content/autoencoder.pt'

model = AutoencoderV5()
model = model.to(device)

model_opt = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-04, weight_decay=5e-4)

print("Loading the saved model...")
model, model_opt, epoch = load_ckp(checkpoint_path, model, model_opt)
summary(model, input_size=(3, 128, 128))


# Is this the right way to extract only the decoder part of the auto-encoder?

class Decoder_Custom(nn.Module):
    def __init__(self):
        super(Decoder_Custom, self).__init__()
        # https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch
        # the decoder is the last child module of the loaded autoencoder
        self.decoder_only = list(model.children())[-1]

    def forward(self, x):
        # the extracted decoder has to be *called* on x here; returning the
        # module itself is what made the summary call below fail
        return self.decoder_only(x)

model_dec_only = Decoder_Custom()
model_dec_only = model_dec_only.to(device)

summary(model_dec_only, input_size=(64, 1, 1))

# this summary call is where the code was breaking
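
For comparison, since AutoencoderV5 keeps its two halves as named attributes, the decoder can also be pulled out by plain attribute access (this shares the loaded weights rather than copying them); a minimal sketch:


# alternative: grab the decoder submodule directly; it keeps the pre-trained weights
decoder_direct = model.decoder.to(device)
summary(decoder_direct, input_size=(64, 1, 1))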

Note: I want to extract the decoder from the autoencoder with its pre-trained weights left intact.
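
If a standalone copy is needed instead of a view into the loaded model, the decoder weights can be filtered out of the state dict by their "decoder." prefix; a minimal sketch, assuming the key layout that AutoencoderV5 produces:


# build a fresh DecoderV5 and copy in only the decoder.* weights
decoder_copy = DecoderV5()
decoder_state = {k[len('decoder.'):]: v
                 for k, v in model.state_dict().items()
                 if k.startswith('decoder.')}
decoder_copy.load_state_dict(decoder_state)
decoder_copy = decoder_copy.to(device)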

Here is the link to the Jupyter notebook on Google Colab: Google Colab