What is the right approach for dealing with `retain_graph=True`?

I have a simple GAN, so I do the following:

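```python
import torch

# batch_size is an assumed name; the targets must match dis_net's output shape.
real_target = torch.ones(batch_size, 1)   # label 1 for real samples
fake_target = torch.zeros(batch_size, 1)  # label 0 for fake samples

gen_net.train()
dis_net.eval()
gen_optimizer.zero_grad()

generated = gen_net(noise)
gen_loss = loss_fn(dis_net(generated), real_target)

gen_loss.backward()

gen_net.eval()
dis_net.train()
dis_optimizer.zero_grad()

# torch.cat takes a sequence of tensors
input_data = torch.cat((generated, real))
target = torch.cat((fake_target, real_target))

output = dis_net(input_data)
dis_loss = loss_fn(output, target)

dis_loss.backward()
```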

Yes, I am aware that once the second backward() is called, the framework will attempt to backpropagate through generated into gen_net a second time, and therefore raises an error suggesting retain_graph=True.

Based on the code above, my understanding is that gen_net will be updated only by the first backward and dis_net only by the second backward, since I call .eval() and .train() correctly before each forward pass.
So, is it safe to set retain_graph=True to get the training I want? Is my understanding correct?

Furthermore, I would like to deal with this correctly.
I believe that if the code is written correctly, I should not see the retain_graph=True error at all.
How should I change the above code so that this error does not appear, without explicitly setting retain_graph=True?
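For reference, here is a minimal, self-contained reproduction of the loop above. The toy `nn.Sequential` networks, their sizes, the random data, and `BCEWithLogitsLoss` are assumptions standing in for the real models; the sketch only shows that the second `backward()` fails because it reaches `gen_net`'s part of the graph, whose saved tensors the first `backward()` already freed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks: 4-dim noise -> 2-dim sample -> 1 logit.
gen_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dis_net = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))
loss_fn = nn.BCEWithLogitsLoss()

batch_size = 3
noise = torch.randn(batch_size, 4)
real = torch.randn(batch_size, 2)
real_target = torch.ones(batch_size, 1)
fake_target = torch.zeros(batch_size, 1)

# Generator step: this backward frees the saved tensors of gen_net's graph.
generated = gen_net(noise)
gen_loss = loss_fn(dis_net(generated), real_target)
gen_loss.backward()

# Discriminator step without detach(): dis_net is re-run (fresh graph),
# but backward continues through `generated` into gen_net's freed graph.
input_data = torch.cat((generated, real))
target = torch.cat((fake_target, real_target))
dis_loss = loss_fn(dis_net(input_data), target)

try:
    dis_loss.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```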

If your generator was already updated in the first step, you could detach the generated tensor before feeding it to the discriminator:

```python
input_data = torch.cat((generated.detach(), real))
```

This will detach generated from the computation graph, so that the backward pass will stop at this point.
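Putting the suggestion into a complete training step could look like the sketch below. The toy networks, Adam optimizers, random data, and the `train_step` helper are all hypothetical. Unlike the question's snippet it also calls each optimizer's `step()`, and it leaves out the `.train()`/`.eval()` switches because the toy networks contain no dropout or batch norm layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks and optimizers, only to make the sketch runnable.
gen_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dis_net = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))
loss_fn = nn.BCEWithLogitsLoss()
gen_optimizer = torch.optim.Adam(gen_net.parameters(), lr=1e-3)
dis_optimizer = torch.optim.Adam(dis_net.parameters(), lr=1e-3)


def train_step(noise, real):
    batch_size = real.size(0)
    real_target = torch.ones(batch_size, 1)
    fake_target = torch.zeros(batch_size, 1)

    # Generator update: gradients flow through dis_net into gen_net.
    gen_optimizer.zero_grad()
    generated = gen_net(noise)
    gen_loss = loss_fn(dis_net(generated), real_target)
    gen_loss.backward()
    gen_optimizer.step()

    # Discriminator update: zero_grad() clears the dis_net gradients left over
    # from gen_loss.backward(); detach() stops this backward at `generated`,
    # so it never reaches gen_net's already freed graph.
    dis_optimizer.zero_grad()
    input_data = torch.cat((generated.detach(), real))
    target = torch.cat((fake_target, real_target))
    dis_loss = loss_fn(dis_net(input_data), target)
    dis_loss.backward()
    dis_optimizer.step()

    return gen_loss.item(), dis_loss.item()


losses = [train_step(torch.randn(16, 4), torch.randn(16, 2)) for _ in range(5)]
print(losses[-1])
```

No `retain_graph=True` is needed here: each backward pass walks a graph exactly once.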


Thanks, it worked.
Out of curiosity, I have one more question.

As you can see, I am calling .eval() and .train() correctly before each loss.backward().
Therefore, even if I had just added retain_graph=True to bypass the error, the training should converge as if I had detached the generated input.
Am I correct here?


.eval() and .train() are not related to the gradient calculation and will just switch the behavior of some layers. E.g. dropout will be disabled and batch norm layers will use their internal running estimates instead of the batch stats.
The warnings/errors regarding a cleared graph should not change if you switch between eval and train.
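To see this concretely, the sketch below reruns the question's loop with `retain_graph=True` on the first backward and checks whether the discriminator loss still reaches `gen_net` even though `gen_net.eval()` was called. The toy networks (a single `nn.Linear` discriminator, for brevity) and the data are hypothetical. If `gen_optimizer.step()` were run after the discriminator's backward and before the next `zero_grad()`, that extra gradient term would be part of the generator's update, so this is not equivalent to using `detach()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks; the discriminator is one linear layer for brevity.
gen_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
dis_net = nn.Linear(2, 1)
loss_fn = nn.BCEWithLogitsLoss()

noise = torch.randn(3, 4)
real = torch.randn(3, 2)
real_target = torch.ones(3, 1)
fake_target = torch.zeros(3, 1)

# Generator step, retaining the graph so the second backward does not fail.
gen_net.train()
dis_net.eval()
generated = gen_net(noise)
gen_loss = loss_fn(dis_net(generated), real_target)
gen_loss.backward(retain_graph=True)
gen_bias_grad_before = gen_net[2].bias.grad.clone()

# Discriminator step: gen_net is now in eval mode, yet autograd still reaches it.
gen_net.eval()
dis_net.train()
input_data = torch.cat((generated, real))
target = torch.cat((fake_target, real_target))
dis_loss = loss_fn(dis_net(input_data), target)
dis_loss.backward()
gen_bias_grad_after = gen_net[2].bias.grad

# True if the discriminator loss was accumulated into gen_net's gradients.
changed = not torch.equal(gen_bias_grad_before, gen_bias_grad_after)
print(gen_net.training, changed)
```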
