Problem with GAN(Pip2Pix) discriminator and generator loss

Hi I am trying to convert pip2pix GAN from Keras to Pytoch. I copied the discriminator and generator architecture but I ve got problem with loss of discriminator, it is decreasing faster than in keras and also generator loss is not decrasing as in keras (It stucks at 8, in keras it goes lower than 2). I will appreciate any help, because I am stuck on that for o long time.

Here is my pytorch discriminator:

class Discriminator(nn.Module):
    def __init__(self, file_shape: tuple, output_filters=8):
        super(Discriminator, self).__init__()

        self.model = Sequential(
            *self.build_block(file_shape[0] * 2, output_filters, normalization=False),
            *self.build_block(output_filters, output_filters * 2),
            *self.build_block(output_filters * 2, output_filters * 4),
            *self.build_block(output_filters * 4, output_filters * 8),

            ZeroPad2d((0, 0, 1, 0)),
            Conv2d(output_filters * 8, 1, kernel_size=(4, 1), padding=(1, 0), stride=1),
        )

    def build_block(self, in_filters: int, out_filters: int, normalization=True):
        layers = [Conv2d(in_filters, out_filters, kernel_size=(4, 1), stride=(2, 1), padding=(1, 0)),
                  LeakyReLU(0.2, inplace=True)]

        if normalization:
            layers.append(BatchNorm2d(num_features=out_filters, momentum=0.2, eps=1e-3))

        return layers

    def forward(self, input_a, input_b):
        img_input = cat((input_a, input_b), 1)

        return self.model(img_input).double()

Here is my pytorch generator:

class UNetDown(nn.Module):
    def __init__(self, input_size: int, output_filters: int, normalize=True):
        super(UNetDown, self).__init__()

        self.model = Sequential(
            Conv2d(input_size, output_filters, kernel_size=(4, 1), padding=(1, 0), stride=(2, 1), bias=False),
            LeakyReLU(0.2)
        )

        if normalize:
            self.model.add_module("BatchNorm2d", BatchNorm2d(output_filters, momentum=0.2, eps=1e-3))

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, input_size: int, output_filters: int, dropout=0.0):
        super(UNetUp, self).__init__()

        self.model = Sequential(
            Upsample(scale_factor=(2, 1)),
            ZeroPad2d((0, 0, 1, 0)),
            Conv2d(input_size, output_filters, kernel_size=(4, 1), stride=1, padding=(1, 0), bias=False),
            ReLU(inplace=True),
            BatchNorm2d(output_filters, momentum=0.2, eps=1e-3),
        )

        if dropout:
            self.model.add_module("Dropout", Dropout(dropout))

    def forward(self, layer, skip_input):
        layer = self.model(layer)
        layer = cat((layer, skip_input), 1)

        return layer


"""
Implementation based on UNet generator
"""


class Generator(nn.Module):
    def __init__(self, file_shape: tuple, output_filters=8, output_channels=2):
        super(Generator, self).__init__()

        # DownSampling
        self.down1 = UNetDown(file_shape[0], output_filters, normalize=False)
        self.down2 = UNetDown(output_filters, output_filters * 2)
        self.down3 = UNetDown(output_filters * 2, output_filters * 4)
        self.down4 = UNetDown(output_filters * 4, output_filters * 8)
        self.down5 = UNetDown(output_filters * 8, output_filters * 8)
        self.down6 = UNetDown(output_filters * 8, output_filters * 8)
        self.down7 = UNetDown(output_filters * 8, output_filters * 8)

        # UpSampling
        self.up1 = UNetUp(output_filters * 8, output_filters * 8)
        self.up2 = UNetUp(output_filters * 16, output_filters * 8)
        self.up3 = UNetUp(output_filters * 16, output_filters * 8)
        self.up4 = UNetUp(output_filters * 16, output_filters * 4)
        self.up5 = UNetUp(output_filters * 8, output_filters * 2)
        self.up6 = UNetUp(output_filters * 4, output_filters)

        self.last = nn.Sequential(
            Upsample(scale_factor=(2, 4)),
            ZeroPad2d((0, 0, 1, 0)),
            Conv2d(output_filters * 2, output_channels, kernel_size=4, stride=1, padding=(1, 0)),
            Sigmoid(),
        )

    def forward(self, x):
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)

        u1 = self.up1(d7, d6)
        u2 = self.up2(u1, d5)
        u3 = self.up3(u2, d4)
        u4 = self.up4(u3, d3)
        u5 = self.up5(u4, d2)
        u6 = self.up6(u5, d1)

        return self.last(u6)

Here is my Keras discriminator:

    def build_discriminator(self):

        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""

            d = Conv2D(filters, kernel_size=(f_size, 1), strides=(2, 1), padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        file_A = Input(shape=self.file_shape)
        file_B = Input(shape=self.file_shape)

        # Concatenate file and conditioning file by channels to produce input

        combined_files = Concatenate(axis=-1)([file_A, file_B])

        d1 = d_layer(combined_files, self.df, bn=False)
        d2 = d_layer(d1, self.df * 2)
        d3 = d_layer(d2, self.df * 4)
        d4 = d_layer(d3, self.df * 8)

        validity = Conv2D(1, kernel_size=(4, 1), strides=1, padding='same')(d4)

        return Model([file_A, file_B], validity)

Here is my Keras generator:

 def build_generator(self):

        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4, bn=True):

            """Layers used during downsampling"""

            d = Conv2D(filters, kernel_size=(f_size, 1), strides=(2, 1), padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            # print(d.shape)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):

            """Layers used during upsampling"""

            u = UpSampling2D(size=(2, 1))(layer_input)
            u = Conv2D(filters, kernel_size=(f_size, 1), strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input])
            # print(u.shape)
            return u

        d0 = Input(shape=self.file_shape)

        # Downsampling

        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf * 2)
        d3 = conv2d(d2, self.gf * 4)
        d4 = conv2d(d3, self.gf * 8)
        d5 = conv2d(d4, self.gf * 8)
        d6 = conv2d(d5, self.gf * 8)
        d7 = conv2d(d6, self.gf * 8)

        # Upsampling

        u1 = deconv2d(d7, d6, self.gf * 8)
        u2 = deconv2d(u1, d5, self.gf * 8)
        u3 = deconv2d(u2, d4, self.gf * 8)
        u4 = deconv2d(u3, d3, self.gf * 4)
        u5 = deconv2d(u4, d2, self.gf * 2)
        u6 = deconv2d(u5, d1, self.gf)

        u7 = UpSampling2D(size=(2, 1))(u6)

        # print(u7.shape)

        output_file = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='sigmoid')(u7)

        return Model(d0, output_file)

I would recommend to compare the number of parameters between both models.
Unfortunately, I don’t know, how to make the Keras code executable for both models.

As the next step, I would compare the output shapes, as e.g. your Keras model uses

layer_input = np.random.randn(1, 24, 24, 1)
validity = tf.keras.layers.Conv2D(1, kernel_size=(4, 1), strides=1, padding='same')(layer_input)
print(validity.shape)
> TensorShape([1, 24, 24, 1])

as the last conv layer in the discriminator, which doesn’t seem to output the same shape in the PyTorch code for a dummy input:

x = torch.randn(1, 1, 24, 24)
conv = nn.Conv2d(1, 1, kernel_size=(4, 1), padding=(1, 0), stride=1)
out = conv(x)
print(out.shape)
>  torch.Size([1, 1, 23, 24])

Once this is done, I would look into the layer initializations, as they are most likely different between the frameworks.