I am having an issue with my GAN model when training the discriminator. There is a portion of code that evaluates the discriminator on the real training data. This output makes up part of the loss function, but since I am doing multiple weight updates per minibatch, this data does not change between iterations. Here is a pseudo-code snippet:
# Part 1: Train D on real (evaluated once, outside the loop)
d_real_decision = D(d_real_data)
d_real_error = torch.mean(d_real_decision)

for d_index in range(c["d_steps"]):
    D.zero_grad()

    # Part 2: Train D on fake
    d_gen_input = torch.rand(c["samples_per_batch"], c["g_input_size"], device=device)
    d_fake_data = G(d_gen_input).detach()  # detach so G is not updated here
    d_fake_decision = D(d_fake_data)
    d_fake_error = torch.mean(d_fake_decision)

    # Part 3: Calculate gradient penalty (this is a custom function)
    grad_penalty = gradient_penalty(d_real_data, d_fake_data, c, device)

    # Part 4: Backward prop and gradient step (WGAN-style critic loss)
    total_d_error = -1 * (d_real_error - d_fake_error) + grad_penalty
    total_d_error.backward(retain_graph=False)
    d_optimizer.step()
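For reference, gradient_penalty is along the lines of the standard WGAN-GP penalty. A minimal sketch of that idea, assuming 2-D (batch x features) data and a hypothetical penalty weight c["gp_lambda"] (my actual implementation may differ in details):

import torch

def gradient_penalty(real_data, fake_data, c, device):
    # Random interpolation between real and fake samples
    alpha = torch.rand(real_data.size(0), 1, device=device).expand_as(real_data)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Gradient of the critic output w.r.t. the interpolated inputs;
    # create_graph=True so the penalty term is itself differentiable
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
    )[0]
    # Penalize deviation of the gradient norm from 1
    return c["gp_lambda"] * ((gradients.norm(2, dim=1) - 1) ** 2).mean()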
As you can see, the loss term total_d_error includes d_real_error, which does not change from one iteration to the next. I want to avoid including “Part 1” in the for loop to reduce runtime. I’ve tried:
(1) Making cloned versions of d_real_error inside the for loop. If I clone and detach, the gradients can’t backpropagate all the way back to the inputs of D. If I don’t detach, the “Part 1” graph is freed after the first backward pass.
(2) Setting retain_graph to True and False in the backward call, combined with both options above in (1). Also, if I retain the graph with neither of the options in (1) and just keep Part 1 outside the for loop, it breaks, because other portions of the loss function change every iteration and involve in-place operations. Both attempts are sketched in code below.
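To make those attempts concrete, here is roughly what (1) and (2) looked like inside the for loop above (illustrative fragments that drop into the snippet, not standalone code):

# Attempt (1a): clone + detach. backward() runs, but the detached copy is
# a constant, so no gradient flows back through D(d_real_data).
d_real_error_i = d_real_error.clone().detach()
total_d_error = -1 * (d_real_error_i - d_fake_error) + grad_penalty
total_d_error.backward()

# Attempt (1b): clone without detach. The clone still references the
# Part 1 graph, whose buffers are freed after the first backward pass,
# so iteration 2 raises "Trying to backward through the graph a second time".
d_real_error_i = d_real_error.clone()
total_d_error = -1 * (d_real_error_i - d_fake_error) + grad_penalty
total_d_error.backward()

# Attempt (2): keep the Part 1 graph alive across iterations. This fails
# on the next iteration because d_optimizer.step() updates the weights
# in place, invalidating tensors the retained graph saved for backward.
total_d_error = -1 * (d_real_error - d_fake_error) + grad_penalty
total_d_error.backward(retain_graph=True)
d_optimizer.step()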
Long story short: I don’t want to include Part 1 in the for loop because of the high computational cost, but I can’t figure out a way to do this properly without something else breaking. Any help would be appreciated!