Unexpected behavior when summing multiple criterions

Hello,

I am trying to train a GAN in PyTorch. Training the "discriminator" requires computing a loss based on the discriminator's output on both "real" and "fake" samples. I tried doing this in a few different ways that I expected to be equivalent, but only one of them seems to work.

First approach (didn't work): sum the two losses and call backward on the combined loss.

criterion = nn.BCELoss()
d_out_real = model_d(real_input)
real_loss = criterion(d_out_real, real_label)
d_out_fake = model_d(fake_input)
fake_loss = criterion(d_out_fake, fake_label)
total_loss = real_loss + fake_loss

d_optimizer.zero_grad() # Adam optimizer to optimize model_d parameters
total_loss.backward()
d_optimizer.step()
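
For context, the surrounding setup in my code looks roughly like the following; the generator model_g, the noise tensor, and the exact shapes here are only illustrative stand-ins, and whether fake_input is detached from the generator graph depends on the actual code (it is detached here for simplicity):

import torch
import torch.nn as nn

batch_size = 16

model_d = nn.Sequential(nn.Linear(100, 1), nn.Sigmoid())  # toy discriminator, outputs probabilities
model_g = nn.Sequential(nn.Linear(10, 100))               # toy generator (illustrative only)

real_input = torch.randn(batch_size, 100)                 # stands in for a batch of real data
noise = torch.randn(batch_size, 10)
fake_input = model_g(noise).detach()                      # detached so only model_d receives gradients here

real_label = torch.ones(batch_size, 1)                    # BCELoss targets
fake_label = torch.zeros(batch_size, 1)

d_optimizer = torch.optim.Adam(model_d.parameters())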

Second approach (works fine): call backward on each loss individually, then take an optimization step.

criterion = nn.BCELoss()
d_optimizer.zero_grad() # Adam optimizer to optimize model_d parameters
d_out_real = model_d(real_input)
real_loss = criterion(d_out_real, real_label)
real_loss.backward()
d_out_fake = model_d(fake_input)
fake_loss = criterion(d_out_fake, fake_label)
fake_loss.backward()
d_optimizer.step()
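
In principle the second approach should behave like the first, because autograd accumulates gradients from successive backward() calls into each parameter's .grad until the next zero_grad(). A tiny self-contained check of that accumulation behaviour (independent of my GAN code):

import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# backward on the summed loss
((w * 3).sum() + (w * 5).sum()).backward()
grad_sum = w.grad.clone()

# forward/backward each loss separately; gradients accumulate in w.grad
w.grad = None
(w * 3).sum().backward()
(w * 5).sum().backward()
grad_separate = w.grad.clone()

print(torch.allclose(grad_sum, grad_separate))  # True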

Third approach (does not work): exactly like the second approach, except that real_loss.backward() is delayed until after the model output and loss on the fake data have been computed.

criterion = nn.BCELoss()
d_optimizer.zero_grad() # Adam optimizer to optimize model_d parameters
d_out_real = model_d(real_input)
real_loss = criterion(d_out_real, real_label)
d_out_fake = model_d(fake_input)
fake_loss = criterion(d_out_fake, fake_label)
real_loss.backward()
fake_loss.backward()
d_optimizer.step()

Notes:

  • I tried using two separate BCELoss modules, but it didn’t help.
  • I also tried computing the losses with the functional API torch.nn.functional.binary_cross_entropy (roughly as in the snippet after this list); this changed nothing either.
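
For reference, the functional-API variant looks roughly like this (the tensors below are stand-ins, not my actual data):

import torch
import torch.nn.functional as F

logits = torch.randn(16, 1, requires_grad=True)  # stand-in for raw discriminator outputs
d_out_real = torch.sigmoid(logits)               # binary_cross_entropy expects probabilities in (0, 1)
real_label = torch.ones(16, 1)

real_loss = F.binary_cross_entropy(d_out_real, real_label)
real_loss.backward()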

Any explanations?

Thanks,
Moustafa

Hi,

I would expect the three to have the exact same behaviour.
Have you run each of them multiple times? Could the working/not working be due to the general instability of GAN training, which does not always converge across different runs?
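
One way to rule out GAN instability would be to compare the accumulated gradients of two of the variants on the same fixed batch, something like this sketch (assuming the variable names from your snippets are already defined; for an exact comparison you may also need to fix the RNG seed or avoid stochastic layers such as dropout):

import torch

def grads_after(backward_fn, model):
    for p in model.parameters():
        p.grad = None
    backward_fn()
    return [p.grad.clone() for p in model.parameters()]

def approach_1():
    total = criterion(model_d(real_input), real_label) + criterion(model_d(fake_input), fake_label)
    total.backward()

def approach_2():
    criterion(model_d(real_input), real_label).backward()
    criterion(model_d(fake_input), fake_label).backward()

g1 = grads_after(approach_1, model_d)
g2 = grads_after(approach_2, model_d)
print(all(torch.allclose(a, b) for a, b in zip(g1, g2)))  # expected: True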

Thanks @albanD… Yeah, I ran them many times and always get the same result (only case 2 works).

Could you give some code to reproduce the issue? Do you have custom Functions?

Thanks @albanD. My code (using the working scenario, case 2) is available on GitHub here:

To switch to case 1: comment out line 170 and replace line 184 with the following line:

errD.backward()

To switch to case 3: move line 170 from its current position to just above line 184.

Thanks,
Moustafa