Issue with dimensions when creating the decoder

Here is the decoder block of my autoencoder. I am not able to create the model when using the unflatten layer; there seems to be an issue with the input dimensions. P.S.: new to PyTorch here.


import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dims = 2

class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 4096)
        self.unflatten = nn.Unflatten(1, (128, 8, 8))
        self.conv_t_2d_1 = nn.ConvTranspose2d(64, 32, 3, stride=2)#, output_padding=2)
        self.conv_t_2d_2 = nn.ConvTranspose2d(32, 16, 3, stride=2)
        self.conv_t_2d_3 = nn.ConvTranspose2d(16, 1, 3, stride=2)
    
    def forward(self, z):
        z = F.relu(self.linear1(z))
        # z = z.view(-1, z.size( 1 ))
        # z = z.view(1, 2, 4096)
        print(z.size())
        z = self.unflatten(z)
        print(z.size())
        z = F.relu(self.conv_t_2d_1(z))
        z = F.relu(self.conv_t_2d_2(z))
        z = torch.sigmoid(self.conv_t_2d_3(z))
        return z  # z.reshape((-1, 1, 64, 64))

Could you post the input shapes which are causing an error?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary  # assuming the torchsummary package provides summary() below

class Encoder(nn.Module):
    def __init__(self, latent_dims):
        # This part of the code contains all the definitions
        # of the layers that we are going to use in the
        # model
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1) 
        self.batch_norm1 = nn.BatchNorm2d(16) 
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten(start_dim=1)
        self.linear1 = nn.Linear(4096, latent_dims)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # [N, 16, 64, 64]
        x = self.batch_norm1(x)
        x = self.pool(x)            # [N, 16, 32, 32]
        x = F.relu(self.conv2(x))   # [N, 32, 32, 32]
        x = self.batch_norm2(x)
        x = self.pool(x)            # [N, 32, 16, 16]
        x = F.relu(self.conv3(x))   # [N, 64, 16, 16]
        x = self.batch_norm3(x)
        x = self.pool(x)            # [N, 64, 8, 8]
        x = self.flatten(x)         # [N, 4096]
        return F.softmax(self.linear1(x), dim=1)


class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 4096)
        self.unflatten = nn.Unflatten(1, (128, 8, 8))
        self.conv_t_2d_1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.conv_t_2d_2 = nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)
        self.conv_t_2d_3 = nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1)
        # self.conv_2d_1 = nn.Conv2d(16, 1, 3, stride=2, padding=1)
    
    def forward(self, z):
        z = F.relu(self.linear1(z))
        # z = z.view(-1, z.size( 1 ))
        z = z.view(1, 8192)
        # print(z.size())
        # print(z.view(-1, z.shape[-1]).shape[0])
        z = self.unflatten(z)
        # print(z.size())
        z = F.relu(self.conv_t_2d_1(z))
        # print(z.size())
        z = F.relu(self.conv_t_2d_2(z))
        # print(z.size())
        z = self.conv_t_2d_3(z)
        # print(z.size())
        return torch.sigmoid(z)  # z.reshape((-1, 1, 64, 64))

class Autoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)


latent_dims = 2
autoencoder = Autoencoder(latent_dims)
summary(autoencoder, input_size=(1, 64, 64))

When we change the batch size it doesn't work, except for a batch size of 2.
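For reference, a minimal sketch (assuming the 64×64 single-channel inputs from the summary call above) that reproduces this: only a batch size of 2 yields the 2 * 4096 = 8192 elements that the hard-coded view expects.

import torch

autoencoder = Autoencoder(latent_dims=2)

# Batch size 2 runs: 2 * 4096 latent features can be viewed as (1, 8192).
# Note that the output batch dimension collapses to 1 in the process.
out = autoencoder(torch.randn(2, 1, 64, 64))
print(out.shape)  # torch.Size([1, 1, 64, 64])

# Any other batch size fails, e.g. 4 * 4096 = 16384 elements:
out = autoencoder(torch.randn(4, 1, 64, 64))  # RuntimeError in z.view(1, 8192)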

The issue is caused by:

z = z.view(1, 8192)

which won’t work for a variable batch size, since you are hard-coding the shape to [1, 8192].
Use:

z = z.view(z.size(0), -1)

to flatten the feature dimensions and to keep the batch size.
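As a quick illustration, with a batch of 4 latent activations:

z = torch.randn(4, 4096)
print(z.view(z.size(0), -1).shape)  # torch.Size([4, 4096]) -- batch size preserved
z.view(1, 8192)  # RuntimeError: shape '[1, 8192]' is invalid for input of size 16384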
Afterwards, you’ll run into a shape mismatch since nn.Unflatten(1, (128, 8, 8)) is incompatible with a feature size of 4096 (which z will have now).
You could change it to nn.Unflatten(1, (128, 8, 4)) and it would work, but the output would then differ from the input in its spatial dimensions.
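Putting both changes together, a minimal sketch of the corrected forward pass (keeping the layer definitions from the post above, with only the unflatten target changed to (128, 8, 4)):

def forward(self, z):
    z = F.relu(self.linear1(z))      # [N, 4096]
    z = z.view(z.size(0), -1)        # keeps the batch dimension: [N, 4096]
    z = self.unflatten(z)            # nn.Unflatten(1, (128, 8, 4)) -> [N, 128, 8, 4]
    z = F.relu(self.conv_t_2d_1(z))  # [N, 64, 16, 8]
    z = F.relu(self.conv_t_2d_2(z))  # [N, 32, 32, 16]
    z = self.conv_t_2d_3(z)          # [N, 1, 64, 32]
    return torch.sigmoid(z)

Each stride-2 transposed convolution doubles the spatial size, so the output ends up as [N, 1, 64, 32] instead of [N, 1, 64, 64]. One way to get square 64×64 outputs back would be to widen the linear layer to nn.Linear(latent_dims, 8192) so that the original nn.Unflatten(1, (128, 8, 8)) fits.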


You are awesome! I have solved the problem with your help!
