def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    Both batches are scored by the module-level ``disc``; the two
    ``loss_bce`` terms are averaged and a single optimizer step is taken
    on that average.

    Args:
        optimizer: Optimizer that is zeroed and stepped here (presumably
            the one holding the discriminator's parameters -- confirm
            against the caller).
        real_data: Batch of real samples passed to ``disc``.
        fake_data: Batch of generated samples passed to ``disc``. The
            training loop passes it with ``.detach()`` so that this
            backward pass does not propagate into the generator.

    Returns:
        A tuple ``(loss, prediction_real, prediction_fake)``. ``loss`` is
        the averaged discriminator loss, detached from the autograd graph
        so that callers can accumulate it without keeping graph history.
    """
    # Clear gradients left over from the previous step; backward() adds to
    # whatever is already stored.
    optimizer.zero_grad()

    # Loss on the real batch, scored against real_data_target labels.
    prediction_real = disc(real_data)
    error_real = loss_bce(prediction_real, real_data_target(real_data.size(0)))

    # Loss on the generated batch, scored against fake_data_target labels.
    prediction_fake = disc(fake_data)
    error_fake = loss_bce(prediction_fake, fake_data_target(fake_data.size(0)))

    total_error = (error_real + error_fake) / 2
    total_error.backward()  # backprop
    optimizer.step()  # update weights

    # Reuse the loss that was actually optimized instead of recomputing it,
    # and detach it: it is only used for reporting from here on.
    return total_error.detach(), prediction_real, prediction_fake
def train_generator(optimizer, real_data, fake_data, *, adv_weight=0.001):
    """Run one generator update from a content loss plus an adversarial loss.

    Args:
        optimizer: Optimizer that is zeroed and stepped here (presumably
            the one holding the generator's parameters -- confirm against
            the caller).
        real_data: Batch of real samples the generated batch is compared
            against in the content loss.
        fake_data: Generator output for this batch. It must still be
            attached to the generator's autograd graph: if it has been
            detached, ``backward()`` cannot reach the generator and the
            optimizer step changes nothing.
        adv_weight: Multiplier applied to the adversarial term before it
            is added to the content loss. Defaults to 0.001.

    Returns:
        The combined generator loss, detached from the autograd graph so
        that callers can accumulate it without keeping graph history.
    """
    optimizer.zero_grad()

    # Content loss: loss_mse between the real batch and the generated batch.
    content_loss = loss_mse(real_data, fake_data)

    # Adversarial loss: the fakes are scored against the *real* targets, so
    # the loss is low when disc classifies generated samples as real.
    prediction_fake = disc(fake_data)
    adv_loss = loss_bce(prediction_fake, real_data_target(fake_data.size(0)))

    total_gen_loss = content_loss + adv_weight * adv_loss
    total_gen_loss.backward()  # backprop
    optimizer.step()  # update weights

    # Only used for reporting from here on, so drop the graph reference.
    return total_gen_loss.detach()
# Main training loop: for every batch, update the discriminator and then the
# generator, and record the per-epoch mean of each loss.
for epoch in range(EPOCHS):
    # Running sums are plain Python floats so that no autograd history is
    # carried from one batch to the next.
    g_err_sum = 0.0
    d_err_sum = 0.0
    for real_batch in training_loader:
        real_batch = real_batch.cuda()
        real_data = real_batch
        # The generator is fed entries 0 and 1 along dim 1 of the batch.
        # Do NOT detach here: train_generator needs the graph back to gen.
        fake_data = gen(real_data[:, 0:2])  # Line 1

        # train discriminator
        # fake_data is detached so this backward pass stops at the fake
        # batch and leaves the generator's gradients untouched.
        d_error, _, _ = train_discriminator(d_opt, real_data, fake_data.detach())  # Line 2
        # The losses are scalars (backward() without arguments requires
        # that), so .item() yields a float.
        d_err_sum += d_error.item()

        # train generator
        g_error = train_generator(g_opt, real_data, fake_data)
        g_err_sum += g_error.item()

    # Per-epoch mean losses; these are floats, not tensors.
    errorsd.append(d_err_sum / num_batches)
    errorsg.append(g_err_sum / num_batches)
    if epoch % 10 == 0:
        print("Epoch no : {} Discriminator error: {} Generator error: {}".format(epoch, d_err_sum / num_batches, g_err_sum / num_batches))
In the main loop, changing gen(real_data[:,0:2]) to gen(real_data[:,0:2]).detach()
causes the generator to stop training. The original snippet above trains the GAN successfully.
Could someone please explain what difference .detach() makes, especially in the context of the train_generator() function written above?
Adding a .detach() basically breaks the gradient connection.
This means that any gradient flowing back towards fake_data won’t be propagated to the generator. So no gradient will be populated.
In Line 2, fake_data is detached because you want the discriminator loss to compute gradients only for the discriminator and not for the generator.
The .detach() there makes sure this happens.
If you do it for the generator loss, then the generator loss won’t contribute to the generator gradient which is not what you want.
In this example, it seems that he zeros out the gradients of the generator before updating the weights. In this context, does it still make a difference to the training process if we detach or not, except for the extra unnecessary computation?
I am printing the gradients of one layer of the Generator, with and without using .detach(). As I understand it, the gradients of the weights should not change when discriminator_loss.backward() is called while using .detach() (since .detach() ensures the gradients are not backpropagated to the generator), but I am observing the opposite behavior. Irrespective of whether .detach() is used, the before and after gradient values are different when discriminator_loss.backward() is called. Can anyone point out where I am wrong?
# Discriminator step with debug prints that check whether errD_fake.backward()
# touches the generator's gradients.
netD.zero_grad()
# netG.zero_grad() is left disabled, so gradients already stored on netG's
# parameters are NOT cleared here; whatever is printed "before" below is
# whatever an earlier backward pass left behind.
#netG.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()

## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
print(netG.main[0].weight[0][0][1:3], "generator before netG(noise)")
fake = netG(noise)
# Reuse the label tensor in place, now filled with the fake label.
label.fill_(fake_label)
# Classify all fake batch with D. fake is detached, so the backward pass
# below stops at fake and cannot reach netG's parameters.
output = netD(fake.detach()).view(-1)
# Alternative without detach (gradients would flow into netG):
#output = netD(fake).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)

# Compare like with like: print the same slice of the GRADIENT before and
# after backward(). (Printing the weights before and the gradients after can
# never match.) .grad is None until some backward pass has reached netG, so
# guard against indexing None.
_g_grad = netG.main[0].weight.grad
print(None if _g_grad is None else _g_grad[0][0][1:3], "generator grad before errD_fake.backward() step")
# Calculate the gradients for this batch
errD_fake.backward()
_g_grad = netG.main[0].weight.grad
print(None if _g_grad is None else _g_grad[0][0][1:3], "generator grad after errD_fake.backward() step")