Assume you have two modules, G and D, connected in sequence just like in a standard GAN setup. I would like to train them using a SINGLE optimizer that optimizes the parameters of BOTH modules. It is easy to exclude the first module (G) from the optimisation for the discriminator loss by simply detaching G's output (see the first sketch below). However, is it possible to prevent gradients from accumulating in the discriminator's parameters when computing gradients for the generator loss? For example, something like:
```python
fakes = G(x)
with torch.no_backward_accumulation():  # imaginary context manager, this is what I'm asking for
    scores = D(fakes)
generator_loss = some_criterion(scores)
```
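For reference, this is roughly how I currently keep G out of the discriminator update (same placeholder names as above):

```python
# Discriminator loss for the fake branch: detach G's output so that
# backward() does not propagate into G's parameters, while D's
# parameters still receive gradients as usual.
fakes = G(x)
scores = D(fakes.detach())
discriminator_loss = some_criterion(scores)
```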
GAN is just an example; in my case I have multiple modules, and sometimes I want to exclude a middle one from the optimisation.
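The closest workaround I can think of is to temporarily switch off requires_grad on D's parameters during the generator pass, roughly like the sketch below, but I am not sure whether this is the intended way or whether something cleaner exists:

```python
# Freeze D's parameters so backward() accumulates no gradients in them;
# gradients still flow through D's operations back into G's parameters,
# because 'fakes' itself requires grad.
for p in D.parameters():
    p.requires_grad_(False)

fakes = G(x)
scores = D(fakes)
generator_loss = some_criterion(scores)
generator_loss.backward()

# Unfreeze D for the next discriminator update.
for p in D.parameters():
    p.requires_grad_(True)
```

Thanks!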