Batchnorm issues for discriminators in DCGAN

Hi all,

I am currently pre-training a discriminator with the following architecture:

WavyDiscriminator(
  (model): Sequential(
    (0): UpsampledTrilinearBlock2d(
      (block): Sequential(
        (0): Conv2d(9, 18, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): BatchNorm2d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): SpecConvBlock2d(
      (block): Sequential(
        (0): Conv2d(18, 36, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): Dropout2d(p=0.01, inplace=False)
        (3): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): SpecConvBlock2d(
      (block): Sequential(
        (0): Conv2d(36, 72, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): Dropout2d(p=0.01, inplace=False)
        (3): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): ResNetBasicBlock(
      (blocks): Sequential(
        (0): Sequential(
          (0): Conv2dAuto(72, 72, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): ReLU(inplace=True)
        (2): Sequential(
          (0): Conv2dAuto(72, 72, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (activate): LeakyReLU(negative_slope=0.01, inplace=True)
      (shortcut): None
    )
    (4): SpecConvBlock2d(
      (block): Sequential(
        (0): Conv2d(72, 144, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): Dropout2d(p=0.01, inplace=False)
        (3): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): ResNetBasicBlock(
      (blocks): Sequential(
        (0): Sequential(
          (0): Conv2dAuto(144, 144, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): ReLU(inplace=True)
        (2): Sequential(
          (0): Conv2dAuto(144, 144, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (activate): LeakyReLU(negative_slope=0.01, inplace=True)
      (shortcut): None
    )
    (6): SpecConvBlock2d(
      (block): Sequential(
        (0): Conv2d(144, 288, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): Dropout2d(p=0.01, inplace=False)
        (3): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): ResNetBasicBlock(
      (blocks): Sequential(
        (0): Sequential(
          (0): Conv2dAuto(288, 72, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): ReLU(inplace=True)
        (2): Sequential(
          (0): Conv2dAuto(72, 72, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (activate): LeakyReLU(negative_slope=0.01, inplace=True)
      (shortcut): Sequential(
        (0): Conv2d(288, 72, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (8): SpecConvBlock3d(
      (block): Sequential(
        (0): Conv2d(72, 288, kernel_size=(5, 3, 2), stride=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): Dropout2d(p=0.01, inplace=False)
        (3): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (9): SpecConvBlock2d(
      (block): Sequential(
        (0): Conv2d(288, 72, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False, padding_mode=replicate)
        (1): LeakyReLU(negative_slope=0.1, inplace=True)
        (2): Dropout2d(p=0.01, inplace=False)
        (3): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (10): OutConvBlock2d(
      (block): Sequential(
        (0): Conv2d(72, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False, padding_mode=replicate)
      )
    )
    (11): Flatten()
  )
)

My issue is that I am trying to pre-train it in the same manner in which it will be trained during the DCGAN optimisation loop. As suggested in the reference implementation (apex/main_amp.py at master · NVIDIA/apex · GitHub), I feed an “all fake” batch and an “all real” batch separately. The pre-training loop contains the following:

# Inner optimisation loop.
with amp.autocast(enabled=enable_autocast):
    # Forward pass on the all-fake and all-real batches separately.
    y_fake = model(fake)
    y_real = model(real)
    # BCEWithLogitsLoss is a module; instantiate it before calling it.
    criterion = torch.nn.BCEWithLogitsLoss()
    loss = criterion(torch.vstack([y_fake, y_real]),
                     torch.vstack([self.fake_labels, self.true_labels]))
accum_loss = accum_loss + loss.item() / self.write_interval
scaler.scale(loss).backward()
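For reference, the variant that works for me, a single mixed batch so every batchnorm layer computes its statistics over fake and real samples together, looks roughly like this (using a hypothetical stand-in discriminator and labels, since the full training state isn't shown here):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the discriminator; the real model is the
# WavyDiscriminator printed above.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.LeakyReLU(0.1),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1),
)
criterion = nn.BCEWithLogitsLoss()

fake = torch.randn(4, 3, 16, 16)
real = torch.randn(4, 3, 16, 16)

# One mixed batch: batchnorm statistics now cover both distributions.
x = torch.cat([fake, real], dim=0)
logits = model(x)
targets = torch.cat([torch.zeros(4, 1), torch.ones(4, 1)], dim=0)
loss = criterion(logits, targets)
loss.backward()
```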

The system falls into a local optimum at a loss of ~0.693 (ln 2, i.e. predicting 0.5 for both fake and real samples). When I train it by stacking fake and real samples into the same forward pass, it works. However, when I use separate all-fake and all-real passes, as in DCGAN, it doesn’t.
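This plateau is consistent with how batchnorm behaves in train mode: each batch is normalised with its own statistics, so any overall shift between an all-fake and an all-real batch is removed before the following layers ever see it. A minimal sketch of the effect on a toy BatchNorm2d (not your model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3).train()  # train mode: normalise with batch statistics

real = torch.randn(8, 3, 4, 4)
fake = real + 5.0  # a "fake" batch that differs from "real" by a constant shift

out_real = bn(real)
out_fake = bn(fake)

# Each batch is normalised with its own mean/variance, so the shift that
# distinguishes fake from real is removed and the outputs coincide.
print(torch.allclose(out_real, out_fake, atol=1e-4))  # True
```

A discriminator built on such layers cannot use this kind of between-batch difference, so predicting 0.5 everywhere is a stable solution.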

I think it has to do with the batchnorm layers: when I test the working setup in .train() mode, it fails to produce accurate labels whenever the fake/real ratio in the batch differs from the one used during training. Moreover, when I remove the batchnorm layers from all SpecConvBlock modules (I can keep them in the ResNetBasicBlock modules), it works. This reinforces the idea that the batchnorm layers are the culprit.

Summary:
To pre-train the discriminator properly, I have to feed it “all fake” and “all real” batches so that pre-training matches the GAN optimisation loop, but the batchnorm layers cannot cope with this, and I am not sure how to solve the issue without removing them. I also don’t understand why this is not an issue for DCGAN in general, given that batch-normalising “fake” and “real” batches separately will produce significantly different representations.
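One way to keep normalisation without the dependence on batch composition is to swap BatchNorm2d for GroupNorm (or InstanceNorm), whose statistics are computed per sample. A sketch of a recursive replacement helper (`bn_to_gn` is a hypothetical name; the gcd keeps the group count a divisor of the channel count):

```python
import math
import torch.nn as nn

def bn_to_gn(module: nn.Module, groups: int = 8) -> nn.Module:
    """Recursively replace every BatchNorm2d with a GroupNorm whose
    statistics depend only on the individual sample, not on the batch."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            c = child.num_features
            setattr(module, name, nn.GroupNorm(math.gcd(groups, c), c))
        else:
            bn_to_gn(child, groups)
    return module
```

After conversion the model behaves identically for all-fake, all-real, and mixed batches, at the cost of losing cross-sample normalisation.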


I’ve actually come across the same problem; I would be keen to know how to solve it too.