WGAN-GP Very Negative Critic Loss

I am training a WGAN-GP on 123 x 123 black-and-white images. The model is training, but I get very negative crit_loss values, reaching roughly Gen_loss = 6 and Crit_Loss = -106 at 400 epochs.

So far, I have tried the following suggestions: making the critic symmetric to the generator, drawing from a larger z-dim, and adding more layers to the generator. I would appreciate any other recommendations or adjustments that might alleviate this issue.
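
For reference, the crit_loss values above come from the standard WGAN-GP critic loss. The sketch below is roughly what I am computing (the helper gradient_penalty and the weight c_lambda are placeholder names, not my exact training code); the Wasserstein term crit_fake.mean() - crit_real.mean() is unbounded, which is part of why the critic loss can drift so far negative.

import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    # Interpolate between real and generated images at random points
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    mixed = eps * real.detach() + (1 - eps) * fake.detach()
    mixed.requires_grad_(True)

    mixed_scores = critic(mixed)
    # Gradient of the critic scores with respect to the interpolated images
    grad = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0].view(batch_size, -1)
    # Penalize deviation of the per-image gradient norm from 1
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

# Critic loss (c_lambda is the penalty weight, typically 10):
#   crit_loss = crit_fake.mean() - crit_real.mean() + c_lambda * gp
# Generator loss:
#   gen_loss = -crit_fake_for_gen.mean()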

Generator Model:

import torch
from torch import nn

class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 8),
            self.make_gen_block(hidden_dim * 8, hidden_dim * 4, kernel_size=5, stride=2, padding=1),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=5, stride=2, padding=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=3, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=5, stride=2, padding=1, final_layer=False):
        # Note: padding is accepted here but not forwarded to ConvTranspose2d,
        # so the transposed convolutions actually run with padding=0.
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),  # outputs in [-1, 1]
            )

    def forward(self, noise):
        # Reshape latent vectors to (batch, z_dim, 1, 1) for the transposed convolutions
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
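
As a quick sanity check on the generator (the batch size and latent sample here are arbitrary), I can verify the output resolution:

gen = Generator(z_dim=10)
noise = torch.randn(4, 10)   # batch of 4 latent vectors
fake = gen(noise)
print(fake.shape)            # torch.Size([4, 1, 123, 123]) with the layer settings above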

Critic Model:

class Critic(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Critic, self).__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            # self.make_crit_block(hidden_dim, hidden_dim * 8, kernel_size=5, stride=2, padding=1),
            self.make_crit_block(hidden_dim, hidden_dim * 4, kernel_size=5, stride=2, padding=1),
            self.make_crit_block(hidden_dim * 4, hidden_dim * 2, kernel_size=5, stride=2, padding=1),
            self.make_crit_block(hidden_dim * 2, 1, kernel_size=3, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=1, final_layer=False):
        # Note: padding is accepted here but not forwarded to Conv2d,
        # so the convolutions actually run with padding=0.
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),  # no activation: unbounded critic score
            )

    def forward(self, image):
        crit_pred = self.crit(image)
        # Flatten the critic output to (batch, number of scores per image)
        return crit_pred.view(len(crit_pred), -1)
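
And the corresponding check on the critic, reusing the fake batch from the generator check above, to see what the flattened output looks like per image:

crit = Critic()
scores = crit(fake)
print(scores.shape)   # torch.Size([4, 25]) with the settings above, i.e. 25 scores per image rather than a single scalar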