I want to use the decoder part of an already existing auto-encoder, but what I am doing is probably not right… Could someone clarify whether this approach is okay? I am not able to get a summary of the extracted decoder by calling torchsummary.
# downloading the model here
! wget https://github.com/Jimut123/simply_junk/raw/main/models/autoencoder.pt
import pickle
import random

import cv2
import matplotlib.pyplot as plt  # plotting library
import numpy as np               # numerical arrays
import pandas as pd

import torch
import torch.distributions
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
from torchvision import datasets, transforms
use_cuda = torch.cuda.is_available()
print(f"use_cuda: {use_cuda}")
device = torch.device("cuda" if use_cuda else "cpu")
print("Device to be used:", device)
class EncoderV5(nn.Module):
    def __init__(self):
        # Definitions of all the layers used in the model
        super(EncoderV5, self).__init__()
        self.conv1 = nn.Conv2d(3, 256, 3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(256, 128, 3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
        self.batch_norm4 = nn.BatchNorm2d(64)
        # self.flatten = nn.Flatten()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.batch_norm1(x)
        x = self.pool(x)
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.batch_norm2(x)
        x = self.pool(x)
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.batch_norm3(x)
        x = self.pool(x)
        x = self.pool(x)
        x = F.relu(self.conv4(x))
        x = self.batch_norm4(x)
        x = self.pool(x)
        return F.softmax(x, dim=1)
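For reference, here is a quick shape probe I used to sanity-check the latent size (my own check, assuming 128×128 RGB inputs, matching the summary call further down): the seven 2×2 max-pools shrink 128 down to 1.

# Sanity check (my own, assuming 128x128 RGB inputs): spatially
# 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 across the seven pools.
enc = EncoderV5().eval()
with torch.no_grad():
    z = enc(torch.randn(1, 3, 128, 128))
print(z.shape)  # torch.Size([1, 64, 1, 1])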
class DecoderV5(nn.Module):
    def __init__(self):
        super(DecoderV5, self).__init__()
        self.t_conv1 = nn.ConvTranspose2d(64, 64, 4, stride=4)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.t_conv3 = nn.ConvTranspose2d(64, 128, 4, stride=4)
        self.batch_norm2 = nn.BatchNorm2d(128)
        self.t_conv5 = nn.ConvTranspose2d(128, 256, 4, stride=4)
        self.batch_norm3 = nn.BatchNorm2d(256)
        self.t_conv7 = nn.ConvTranspose2d(256, 3, 2, stride=2)

    def forward(self, x):
        x = F.relu(self.t_conv1(x))
        x = self.batch_norm1(x)
        x = F.relu(self.t_conv3(x))
        x = self.batch_norm2(x)
        x = F.relu(self.t_conv5(x))
        x = self.batch_norm3(x)
        x = self.t_conv7(x)
        return F.softmax(x, dim=1)
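And the matching probe for the decoder (again my own sanity check): each stride-4 transposed conv multiplies the spatial size by 4 and the final stride-2 layer by 2, so a (64, 1, 1) latent comes back out as a 3×128×128 image.

# Sanity check (my own): spatially 1 -> 4 -> 16 -> 64 -> 128, so a
# (64, 1, 1) latent maps back to a (3, 128, 128) image.
dec = DecoderV5().eval()
with torch.no_grad():
    out = dec(torch.randn(1, 64, 1, 1))
print(out.shape)  # torch.Size([1, 3, 128, 128])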
class AutoencoderV5(nn.Module):
    def __init__(self):
        super(AutoencoderV5, self).__init__()
        self.encoder = EncoderV5()
        self.decoder = DecoderV5()

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
# load checkpoint in pytorch
def load_ckp(checkpoint_path, model, model_opt):
    # map_location keeps this working on a CPU-only runtime as well
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model_opt.load_state_dict(checkpoint['optimizer'])
    return model, model_opt, checkpoint['epoch']
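For context, load_ckp assumes the checkpoint is a dict with 'state_dict', 'optimizer', and 'epoch' keys, i.e. that it was saved along these lines (my assumption, inferred from the loader):

# Assumed save-side counterpart of load_ckp (keys inferred from the loader)
torch.save({
    'epoch': epoch,                      # last finished epoch
    'state_dict': model.state_dict(),    # weights of the full autoencoder
    'optimizer': model_opt.state_dict()  # optimizer state (Adam moments etc.)
}, checkpoint_path)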
checkpoint_path = '/content/autoencoder.pt'
model = AutoencoderV5()
model = model.to(device)
model_opt = optim.Adam([p for p in model.parameters() if p.requires_grad],
                       lr=1e-04, weight_decay=5e-4)
print("Loading the saved model...")
model, optimizer, epoch = load_ckp(checkpoint_path, model, model_opt)
summary(model, input_size=(3, 128, 128))
# Is it okay to extract only the decoder part of the auto-encoder like this?
class Decoder_Custom(nn.Module):
    def __init__(self):
        super(Decoder_Custom, self).__init__()
        # https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch
        # grab the last child of the trained model, i.e. its DecoderV5
        self.decoder_only = list(model.children())[-1]

    def forward(self, x):
        # returns the decoder module itself rather than decoder_only(x)
        return self.decoder_only

model_dec_only = Decoder_Custom()
model_dec_only = model_dec_only.to(device)
summary(model_dec_only, input_size=(64, 1, 1))
# code breaks here
Note: I want to extract the decoder while keeping the pre-trained weights of the autoencoder intact; my own guess at a fix is sketched below.
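From reading around, I suspect two things (please correct me if I am wrong): forward() should apply the submodule to x instead of returning the module, and since AutoencoderV5 already stores the decoder as a named submodule, referencing model.decoder directly might be enough to reuse the loaded weights. A minimal sketch of what I think should work, assuming the checkpoint above loaded correctly (the wrapper class name is my own):

# Option 1 (my assumption): reuse the decoder submodule directly; it already
# holds the weights loaded from the checkpoint, no copying needed.
decoder = model.decoder.to(device)
summary(decoder, input_size=(64, 1, 1))

# Option 2: keep a wrapper class, but call the module in forward()
class DecoderOnly(nn.Module):  # hypothetical name, not from the repo
    def __init__(self, trained_autoencoder):
        super(DecoderOnly, self).__init__()
        self.decoder_only = trained_autoencoder.decoder

    def forward(self, x):
        return self.decoder_only(x)  # apply to x instead of returning the module

summary(DecoderOnly(model).to(device), input_size=(64, 1, 1))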
Here is the link to the Jupyter notebook on Google Colab: Google Colab