GAN => mapping random noise to an output space!

Hi guys,
I have some difficulty grasping the concept, so I'd appreciate your feedback!

I need to map random noise of shape (100, 1, 1) to an image tensor of shape (3, 64, 64).
The generator to implement consists of 5 transposed convolutional layers:

Use 512, 256, 128, 64, and 3 output channels for these layers.

Use a kernel size of 4.

Use a padding of 0 in the first transposed convolution and a padding of 1 for all subsequent transposed convolutions.

Do not use a bias.

Use a stride of 1 for the first transposed convolution and a stride of 2 for all subsequent transposed convolutions.

Apply batch normalization after each transposed convolutional layer except for the last one.

Apply a ReLU activation function after each batch normalization layer. Use torch.nn.ReLU for this.

Apply a hyperbolic tangent function after the last transposed convolutional layer.

Return the output after this function.
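The spatial sizes implied by these settings can be checked with the output-size formula given in the PyTorch ConvTranspose2d documentation: H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1, with dilation = 1 and output_padding = 0 here. The plain-Python sketch below walks the five layers starting from the 1x1 noise; the helper name conv_transpose2d_out_size is hypothetical, not a PyTorch function.

```python
def conv_transpose2d_out_size(size_in, kernel_size, stride, padding,
                              output_padding=0, dilation=1):
    """Output height (or width) of a ConvTranspose2d layer.

    Hypothetical helper implementing the shape formula from the PyTorch docs.
    """
    return ((size_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# (stride, padding) per layer as described in the task; kernel_size is 4 everywhere
layer_params = [(1, 0), (2, 1), (2, 1), (2, 1), (2, 1)]
channels = [512, 256, 128, 64, 3]

size = 1  # the noise tensor is (100, 1, 1), so height = width = 1
sizes = []
for (stride, padding), c in zip(layer_params, channels):
    size = conv_transpose2d_out_size(size, kernel_size=4, stride=stride, padding=padding)
    sizes.append(size)
    print(f"-> ({c}, {size}, {size})")
```

The first layer (stride 1, padding 0) turns 1x1 into 4x4, and every later layer with kernel 4, stride 2, padding 1 exactly doubles height and width: 1 -> 4 -> 8 -> 16 -> 32 -> 64.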

My Implementation:

from torch import nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            #-----------------------------------------------
            #  1. Layer
            #-----------------------------------------------
            nn.ConvTranspose2d(in_channels=100, out_channels=512,
                               kernel_size=(4, 4), stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            #-----------------------------------------------
            #  2. Layer
            #-----------------------------------------------
            nn.ConvTranspose2d(in_channels=512, out_channels=256,
                               kernel_size=(4, 4), stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            #-----------------------------------------------
            #  3. Layer
            #-----------------------------------------------
            nn.ConvTranspose2d(in_channels=256, out_channels=128,
                               kernel_size=(4, 4), stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            #-----------------------------------------------
            #  4. Layer
            #-----------------------------------------------
            nn.ConvTranspose2d(in_channels=128, out_channels=64,
                               kernel_size=(4, 4), stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            #-----------------------------------------------
            #  5. Layer
            #-----------------------------------------------
            nn.ConvTranspose2d(in_channels=64, out_channels=3*64*64,
                               kernel_size=(4, 4), stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, data):
        return self.main(data)

Is this implementation correct? Thanks!
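One quick way to see what the last layer in the question produces is to apply it on its own to a dummy layer-4 activation. This is a minimal sketch, assuming PyTorch is installed; the input is shrunk to 4x4 spatially (an assumption made only to keep memory low, since out_channels=3*64*64 already means about 12.6M weights). The channel behaviour is the same at 32x32.

```python
import torch
from torch import nn

# Dummy stand-in for layer-4 activations: (batch, 64, H, W).
# Spatial size reduced to 4x4 (assumption) to keep memory usage small.
x = torch.randn(1, 64, 4, 4)

# Last layer as written in the question: out_channels = 3*64*64
wide = nn.ConvTranspose2d(64, 3 * 64 * 64, kernel_size=4, stride=2, padding=1, bias=False)
# Same layer with out_channels = 3
narrow = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False)

with torch.no_grad():
    wide_out = wide(x)
    narrow_out = narrow(x)

print(wide_out.shape)    # torch.Size([1, 12288, 8, 8])
print(narrow_out.shape)  # torch.Size([1, 3, 8, 8])
```

In both cases stride 2 doubles the height and width; out_channels only sets the channel dimension, so 3*64*64 yields 12288 channels rather than a 3x64x64 image.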

Almost there! Here is a commented version with the small changes needed:

import torch 
from torch import nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        

        # using a dict for all layers other than the first one, to reduce redundancy
        # it's the same as using `kernel_size=4, stride=2, padding=1, bias=False` in each layer
        sub_seq_kwargs = {
            "kernel_size": 4,
            "stride": 2,
            "padding": 1,
            "bias": False
        }

        self.main = nn.Sequential(
            # layer 1
            # PyTorch includes a bias in conv layers by default, so we need to set bias=False
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0, bias=False), # => (bs, 512, 4, 4)
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # layer 2
            nn.ConvTranspose2d(512, 256, **sub_seq_kwargs), # => (bs, 256, 8, 8)
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # layer 3
            nn.ConvTranspose2d(256, 128, **sub_seq_kwargs), # => (bs, 128, 16, 16)
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # layer 4
            nn.ConvTranspose2d(128, 64, **sub_seq_kwargs), # => (bs, 64, 32, 32)
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # layer 5
            # in_channels & out_channels only affect the C dim of your tensor
            # so to get a (bs, 3, 64, 64) output we can use out_channels=3
            # spatial dims (height & width) are set by the kernel size, stride & padding:
            # with kernel_size=4, stride=2, padding=1 each layer doubles them (32 -> 64)
            nn.ConvTranspose2d(64, 3, **sub_seq_kwargs), # => (bs, 3, 64, 64)
            nn.Tanh()
        )

    def forward(self, data):
        return self.main(data)


g = Generator()
print(g(torch.randn(1, 100, 1, 1)).size())  # torch.Size([1, 3, 64, 64])
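To verify a generator like this layer by layer, you can iterate over the nn.Sequential and record the shape after each transposed convolution, then check the no-bias and tanh-range requirements. The sketch below assumes PyTorch is installed; it rebuilds the same layer stack in a loop so the block is self-contained, and the helper name build_generator is hypothetical.

```python
import torch
from torch import nn


def build_generator(nz=100, channels=(512, 256, 128, 64, 3)):
    """Hypothetical loop-based rebuild of the same layer stack as Generator.main."""
    layers = []
    in_ch = nz
    for i, out_ch in enumerate(channels):
        first, last = i == 0, i == len(channels) - 1
        layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4,
                                         stride=1 if first else 2,
                                         padding=0 if first else 1,
                                         bias=False))
        if last:
            layers.append(nn.Tanh())  # final activation, no batch norm
        else:
            layers += [nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
        in_ch = out_ch
    return nn.Sequential(*layers)


main = build_generator()
x = torch.randn(2, 100, 1, 1)  # batch of 2 noise vectors

shapes = []
with torch.no_grad():
    for layer in main:
        x = layer(x)
        if isinstance(layer, nn.ConvTranspose2d):
            shapes.append(tuple(x.shape))
            print(type(layer).__name__, tuple(x.shape))

# No transposed convolution should carry a bias term
assert all(m.bias is None for m in main if isinstance(m, nn.ConvTranspose2d))
# Tanh keeps every output value in [-1, 1]
assert x.min().item() >= -1.0 and x.max().item() <= 1.0
```

The recorded shapes follow (2, 512, 4, 4) -> (2, 256, 8, 8) -> (2, 128, 16, 16) -> (2, 64, 32, 32) -> (2, 3, 64, 64), matching the comments in the answer above.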

Many thanks for the prompt feedback!
