I am having an issue with my GAN model when trying to train the discriminator. A portion of the code evaluates the discriminator on the real training data. This output makes up part of the loss function, but since I am doing multiple weight updates per minibatch, this data does not change between iterations. Here is a pseudo-code snippet:
```python
# Part 1: Train D on real
d_real_decision = D(d_real_data)
d_real_error = torch.mean(d_real_decision)

for d_index in range(c["d_steps"]):
    D.zero_grad()

    # Part 2: Train D on fake
    d_gen_input = Variable(torch.rand(c["samples_per_batch"], c["g_input_size"], device=device))
    d_fake_data = G(d_gen_input).detach()
    d_fake_decision = D(d_fake_data)
    d_fake_error = torch.mean(d_fake_decision)

    # Part 3: Calculate Gradient Penalty (this is a custom function)
    grad_penalty = gradient_penalty(d_real_data, d_fake_data, c, device)

    # Part 4: Backward prop and gradient step
    total_d_error = -1 * (d_real_error - d_fake_error) + grad_penalty
    total_d_error.backward(retain_graph=False)
    d_optimizer.step()
```
As you can see, the loss term `d_real_error` does not change between iterations. I want to avoid including "Part 1" in the for loop to reduce runtime. I've tried:
(1) Making cloned versions of `d_real_error` inside the for loop. If I clone and detach, the gradients can't backprop all the way back to the inputs of D. If I don't detach, the "Part 1" graph is lost after the backward pass.
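To make (1) concrete, here is a tiny standalone reproduction of both failure modes. This is not my real model, just a single linear layer standing in for D:

```python
import torch

# Tiny stand-in for the discriminator (not the real model)
D = torch.nn.Linear(2, 1)
d_real_data = torch.randn(4, 2)

# "Part 1" computed once, outside the loop
d_real_error = torch.mean(D(d_real_data))

# Clone + detach: the copy carries no graph, so gradients
# cannot reach D's weights through it
detached = d_real_error.clone().detach()
print(detached.requires_grad)  # False

# Clone without detach: still shares the Part 1 graph, and the
# first backward() frees that graph, so a second pass errors out
d_real_error.clone().backward()
err = None
try:
    d_real_error.clone().backward()
except RuntimeError as e:
    err = e
print(type(err).__name__)  # RuntimeError
```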
(2) Setting `retain_graph` to both True and False in the `backward` call, combined with both options above in (1). Also, if I retain the graph with neither of the options in (1), and just keep Part 1 inside the for loop, it breaks because other portions of the loss function change between iterations and involve in-place operations.
Long story short, I don't want to include Part 1 in the for loop because of its high computational cost, but I can't figure out a way to do this without something else breaking. Any help would be appreciated!