Pix2pixHD reimplementation results != original implementation results

Hello, I’m working on integrating pix2pix into a framework of various generative models, but I’m running into issues replicating the results of the original implementation.

I’ve essentially copied the implementation from here into the same object-oriented structure as the rest of my models, and then added functionality to create the multiscale generator and discriminator architecture from the pix2pixHD paper.

However, when I train the model with those features disabled (which I believe makes it identical to the original pix2pix implementation), the results are much worse than those from the repo linked above.

Both my discriminator and generator have exactly the same number of parameters and the same structure as in the original repo (although a couple of modules are wrapped around them and named slightly differently):

Generator
DataParallel(
  (module): MultiscaleGenerator(
    (mainG): UnetGenerator(
      (model): UnetSkipConnectionBlock(
        (model): Sequential(
          (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (1): UnetSkipConnectionBlock(
            (model): Sequential(
              (0): LeakyReLU(negative_slope=0.2, inplace)
              (1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (3): UnetSkipConnectionBlock(
                (model): Sequential(
                  (0): LeakyReLU(negative_slope=0.2, inplace)
                  (1): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                  (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (3): UnetSkipConnectionBlock(
                    (model): Sequential(
                      (0): LeakyReLU(negative_slope=0.2, inplace)
                      (1): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                      (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                      (3): UnetSkipConnectionBlock(
                        (model): Sequential(
                          (0): LeakyReLU(negative_slope=0.2, inplace)
                          (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                          (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                          (3): UnetSkipConnectionBlock(
                            (model): Sequential(
                              (0): LeakyReLU(negative_slope=0.2, inplace)
                              (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                              (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                              (3): UnetSkipConnectionBlock(
                                (model): Sequential(
                                  (0): LeakyReLU(negative_slope=0.2, inplace)
                                  (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                  (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                  (3): UnetSkipConnectionBlock(
                                    (model): Sequential(
                                      (0): LeakyReLU(negative_slope=0.2, inplace)
                                      (1): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                      (2): ReLU(inplace)
                                      (3): ConvTranspose2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                      (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                    )
                                  )
                                  (4): ReLU(inplace)
                                  (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                                  (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                                  (7): Dropout(p=0.5)
                                )
                              )
                              (4): ReLU(inplace)
                              (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                              (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                              (7): Dropout(p=0.5)
                            )
                          )
                          (4): ReLU(inplace)
                          (5): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                          (6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                          (7): Dropout(p=0.5)
                        )
                      )
                      (4): ReLU(inplace)
                      (5): ConvTranspose2d(1024, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                      (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    )
                  )
                  (4): ReLU(inplace)
                  (5): ConvTranspose2d(512, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                  (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                )
              )
              (4): ReLU(inplace)
              (5): ConvTranspose2d(256, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (2): ReLU(inplace)
          (3): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (4): Tanh()
        )
      )
    )
  )
)
[Network G] Total number of parameters : 54.414 M
Discriminator
DataParallel(
  (module): MultiscaleDiscriminator(
    (discriminator_0): NLayerDiscriminator(
      (model): Sequential(
        (layer_0): Sequential(
          (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
          (1): LeakyReLU(negative_slope=0.2, inplace)
        )
        (layer_1): Sequential(
          (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2, inplace)
        )
        (layer_2): Sequential(
          (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2, inplace)
        )
        (layer_3): Sequential(
          (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): LeakyReLU(negative_slope=0.2, inplace)
        )
        (final_conv): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
        (sigmoid): Sigmoid()
      )
    )
  )
)
[Network D] Total number of parameters : 2.769 M
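(The parameter counts above are computed with a simple helper along these lines; count_parameters is an illustrative name, not code from either repo.)

import torch.nn as nn

def count_parameters(net: nn.Module) -> str:
    # sum over all trainable parameters, reported in millions
    total = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return '%.3f M' % (total / 1e6)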

However, my forward() code for the discriminator is different (to allow extracting intermediate features for the discriminator feature-matching loss from pix2pixHD; both methods are shown after my training code below). I think this is most likely where the issue is, as my generator loss values are very similar to the original implementation’s while my discriminator’s losses are about an order of magnitude smaller.

My training code
for epoch in range(start_epoch, num_epochs + epochs_decay + 1):
    for i, data in enumerate(dataloader):
        # get data
        real_A = data['A'].to(self.device)
        real_B = data['B'].to(self.device)
        fake_B = self.G(real_A)

        # optimize discriminator
        self.set_requires_grad(self.D, True)
        self.optimizer_D.zero_grad()

        fake_AB = th.cat((real_A, fake_B), 1)
        pred_fake = self.D(fake_AB.detach())

        loss_D_fake = 0
        for preds in pred_fake:
            loss_D_fake += self.loss_GAN(preds, self.fake_label.expand_as(preds))

        real_AB = th.cat((real_A, real_B), 1)
        pred_real = self.D.get_features(real_AB)

        loss_D_real = 0
        for preds in pred_real:
            loss_D_real += self.loss_GAN(preds[-1], self.real_label.expand_as(preds[-1]))

        loss_D = (loss_D_fake + loss_D_real) * 0.5
        loss_D.backward()
        self.optimizer_D.step()

        # optimize generator
        self.set_requires_grad(self.D, False)
        self.optimizer_G.zero_grad()
        
        fake_AB = th.cat((real_A, fake_B), 1)
        pred_fake = self.D.get_features(fake_AB)

        loss_G_GAN = 0
        for preds in pred_fake:
            loss_G_GAN += self.loss_GAN(preds[-1], self.fake_label.expand_as(preds[-1]))

        loss_G_L1 = self.loss_L1(fake_B, real_B) * self.lambda_L1

        loss_G = loss_G_GAN + loss_G_L1

        if self.feature_loss:  # False in this case
            pass  # do stuff with pred_real and pred_fake here

        loss_G.backward()
        self.optimizer_G.step()
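
For completeness, set_requires_grad is essentially the same helper as in the original repo; it toggles requires_grad on a network’s parameters so D isn’t updated during the G step:

    def set_requires_grad(self, nets, requires_grad=False):
        # enable/disable gradients for all parameters of the given network(s)
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad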

My discriminator's get_features() and forward() code
    def get_features(self, input):
        # forward with input downsampled for each scale
        # return list of lists of activations at each layer for each discriminator scale
        results = []
        for scale in range(self.n_scales): # in this case n_scales = 1
            # scale input for given discrim scale
            prev_output = interpolate(input, scale_factor=2**(scale - self.n_scales + 1))

            # get separate layers of discriminator
            discriminator = getattr(self, 'discriminator_%s'%(scale)).model
            layers = {k:v for (k,v) in discriminator.named_children() if 'layer' in k}

            # forward through the network storing discriminator features by layer
            per_scale_results = []
            for idx in range(len(layers)):
                layer = layers['layer_%s'%idx]
                prev_output = layer(prev_output)
                per_scale_results.append(prev_output)

            final_preds = getattr(discriminator, 'final_conv')(per_scale_results[-1])
            per_scale_results.append(final_preds)

            if self.use_sigmoid: # true in this case
                final_preds = getattr(discriminator, 'sigmoid')(final_preds)
                per_scale_results.append(final_preds)

            results.append(per_scale_results)

        # results look like this:
        # [[torch.Size([1, 64, 128, 128]), torch.Size([1, 128, 64, 64]), torch.Size([1, 256, 32, 32]), torch.Size([1, 512, 31, 31]), torch.Size([1, 1, 30, 30]), torch.Size([1, 1, 30, 30])]]
        # the last element of each inner list is the final prediction map, in the same format as the original pix2pix repo
        return results

    def forward(self, input):
        # forward with input downsampled for each scale
        # return list of predictions for each discriminator scale
        results = []
        for scale in range(self.n_scales):
            discriminator = getattr(self, 'discriminator_%s'%(scale))
            downsampled_input = interpolate(input, scale_factor=2**(scale - self.n_scales + 1))
            predictions = discriminator(downsampled_input)
            results.append(predictions)

        # results look like this: [[torch.Size([1, 1, 30, 30])]]
        return results
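
One sanity check I’ve run on these two code paths (a sketch, assuming the single-scale, DataParallel-wrapped discriminator printed above; D stands for that wrapped network): the post-sigmoid map returned by get_features() should match running the underlying Sequential directly.

import torch as th

x = th.randn(1, 6, 256, 256)  # dummy concatenated (A, B) input
with th.no_grad():
    feats = D.module.get_features(x)            # per-scale lists of per-layer activations
    direct = D.module.discriminator_0.model(x)  # the full Sequential, incl. sigmoid
assert th.allclose(feats[0][-1], direct), "feature path and plain forward disagree"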
Original training code
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    for i, data in enumerate(dataset):
        self.real_A = data['A'].to(self.device)
        self.real_B = data['B'].to(self.device)
        self.fake_B = self.netG(self.real_A)

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()

        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)

        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)

        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()

        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()
        self.optimizer_G.step()

loss_GAN and criterionGAN in the respective implementations are both nn.BCELoss().
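The real_label and fake_label buffers used with it are just scalar targets broadcast to the prediction shape, along these lines (a sketch; the exact construction in my code may differ):

import torch as th
import torch.nn as nn

loss_GAN = nn.BCELoss()
real_label = th.tensor(1.0)
fake_label = th.tensor(0.0)

preds = th.rand(1, 1, 30, 30)  # e.g. a sigmoid patch prediction
loss = loss_GAN(preds, real_label.expand_as(preds))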
I’ve verified that the data is normalized in exactly the same way in both implementations and that the input tensors are identical before being passed to forward().
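
Concretely, that check was along these lines (my_dataloader and original_dataloader are illustrative names; shuffling disabled in both):

batch_mine = next(iter(my_dataloader))
batch_orig = next(iter(original_dataloader))
assert th.equal(batch_mine['A'], batch_orig['A'])
assert th.equal(batch_mine['B'], batch_orig['B'])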

Just to give an idea of the difference in results:

After 45 epochs mine looks like this:

[image: my result after 45 epochs]

while the original implementation looks like this:

[image: original implementation's result after 45 epochs]
Also, I tried loading the epoch-35 checkpoint of the original implementation and continuing training with my own implementation. After 1 epoch it already looks very faded:

[image: my result after 1 epoch of continued training]

Compared with the checkpoint it was loaded from:

[image: epoch035_fake_B]

After 10 epochs of continued training with my implementation, it looks essentially the same as if it had been trained from scratch.
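
For reference, loading the original repo’s generator checkpoint into my wrapper just remaps the state_dict keys to add the extra wrapper prefixes, roughly like this (a sketch; the filename is illustrative and the prefixes follow my generator printout above):

state = th.load('35_net_G.pth')
# prepend the DataParallel ('module.') and MultiscaleGenerator ('mainG.')
# prefixes used by my wrapper
remapped = {'module.mainG.' + k: v for k, v in state.items()}
G.load_state_dict(remapped)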

Clearly my implementation is doing something wrong, but I can’t for the life of me figure out what. Does anyone have an idea what the issue could be?

The full codebase can be found here.