Exclude subgraph from gradient accumulation

(Piotr Dabkowski) #1

Assume you have two modules, G and D, applied in sequence just as in a standard GAN setup. I would like to train them using a SINGLE optimizer that optimizes the parameters of BOTH modules. It is easy to exclude the first module (G) from the optimisation of the discriminator loss by simply using detach. However, is it possible to prevent gradients from accumulating in the discriminator's parameters when calculating gradients for the generator? For example, something like:

fakes = G(x)
with torch.no_backward_accumulation():  # proposed API, not part of PyTorch
    scores = D(fakes)
generator_loss = some_criterion(scores)

The GAN is just an example; in my case I have multiple modules, and sometimes I want to exclude a middle one from the optimisation. Thanks!

(Simon Wang) #2
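from torch import autograd
fake_d = fake.detach().requires_grad_()  # cut the graph here, but still track grads w.r.t. fake_d
loss = some_criterion(D(fake_d))         # autograd.grad needs a scalar (or grad_outputs)
grad_fake, = autograd.grad(loss, fake_d) # only d(loss)/d(fake_d); D's .grad is untouched
fake.backward(grad_fake)                 # pass that gradient on into G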
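A self-contained sketch of the `autograd.grad` approach is below. The toy `nn.Linear` modules standing in for G and D and the mean-score generator loss are assumptions for illustration; the point is that D's `.grad` stays empty while G still receives gradients, even with one optimizer over both modules:

```python
import torch
import torch.nn as nn
from torch import autograd

torch.manual_seed(0)
G = nn.Linear(4, 3)  # toy generator (assumption)
D = nn.Linear(3, 1)  # toy discriminator (assumption)
# A single optimizer over the parameters of both modules.
opt = torch.optim.SGD(list(G.parameters()) + list(D.parameters()), lr=0.1)

z = torch.randn(8, 4)
fake = G(z)

# Cut the graph at `fake`, but keep tracking gradients w.r.t. the cut point.
fake_d = fake.detach().requires_grad_()
gen_loss = -D(fake_d).mean()  # any scalar generator criterion works here

# autograd.grad returns d(gen_loss)/d(fake_d) without accumulating into D's .grad.
(grad_fake,) = autograd.grad(gen_loss, fake_d)

# Feed that gradient into G's part of the graph.
fake.backward(grad_fake)

d_before = D.weight.detach().clone()
opt.step()  # SGD skips parameters whose .grad is None, so D is left unchanged
```

Because SGD (like the other built-in optimizers) skips parameters whose `.grad` is `None`, the single `opt.step()` only moves G here.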

What is beyond me is why you can’t use two optimizers.

(Piotr Dabkowski) #3

Thanks, this could be one way to solve it: passing the gradients to backward manually. Using two optimisers in a simple GAN case is fine. However, if for example you have N modules in sequence and want to optimize only a randomly selected subset of them, then creating an optimizer for every new subset or passing the gradients manually (as you suggested) would be a pain. My use case is even more complicated.

I think such a no_backward_accumulation context would be extremely useful, as it would provide a very simple (single-line) way of excluding ANY subgraph from optimization. I am currently reading through the autograd code to find out whether it's possible.

(Simon Wang) #4

I understand your use case now. 🙂

You can also do this:

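fake = G(z)
for p in D.parameters():
    p.requires_grad = False  # freeze D while the generator graph is recorded
loss = criterion(D(fake))    # no gradient path into D's params, but still into G
for p in D.parameters():
    p.requires_grad = True   # unfreeze; the graph built above is unaffected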

And you can write a function instead of spelling out the two for loops.
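One way to package the two loops is a context manager that freezes the given modules and restores each parameter's previous `requires_grad` flag on exit. The name `frozen` is hypothetical (not a PyTorch API), and the toy `nn.Linear` modules are assumptions standing in for G and D:

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def frozen(*modules):
    """Hypothetical helper: temporarily set requires_grad=False on the
    parameters of `modules`, restoring each flag's previous value on exit."""
    params = [p for m in modules for p in m.parameters()]
    saved = [p.requires_grad for p in params]  # remember flags before changing any
    for p in params:
        p.requires_grad_(False)
    try:
        yield
    finally:
        for p, flag in zip(params, saved):
            p.requires_grad_(flag)


# Usage with toy stand-ins for G and D (assumption).
G = nn.Linear(4, 3)
D = nn.Linear(3, 1)
z = torch.randn(8, 4)

fake = G(z)
with frozen(D):
    loss = -D(fake).mean()  # D's parameters are not recorded as needing grad
loss.backward()             # flags are restored, but this graph has no path into D
```

Restoring the saved flags, rather than unconditionally setting them to True, keeps parameters that were frozen before entering the block frozen afterwards.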

(Piotr Dabkowski) #5

Yes, thanks, I thought about that as well; it is actually the method used by a WGAN-GP PyTorch implementation. But as I said, my case is even more complex 😉. For example, what if you also want to optimize D, but with a different loss term? You would need another optimizer for that, or you would have to pass the gradients manually. With no_backward_accumulation you could optimize G and D jointly very easily:

fakes = G(x)
with torch.no_backward_accumulation():  # proposed API, not part of PyTorch
    scores = D(fakes)
generator_loss = some_criterion(scores)
discriminator_loss = another_criterion(D(fakes.detach()), D(reals))
loss = generator_loss + discriminator_loss
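For comparison, here is a sketch of how this particular joint update might be emulated with existing PyTorch APIs, assuming it is acceptable to toggle `requires_grad` between the two forward passes. The flags are read when the graph is recorded, so D's parameters get no gradient from the generator term but do get one from the discriminator term. The toy `nn.Linear` modules, the BCE-with-logits losses, and the tensor shapes are illustrative assumptions standing in for `G`, `D`, `some_criterion`, and `another_criterion`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
G = nn.Linear(4, 3)  # toy generator (assumption)
D = nn.Linear(3, 1)  # toy discriminator producing logits (assumption)
# A single optimizer over the parameters of BOTH modules.
opt = torch.optim.SGD(list(G.parameters()) + list(D.parameters()), lr=0.1)

x = torch.randn(8, 4)
reals = torch.randn(8, 3)
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

fakes = G(x)

# Generator term: D's parameters are frozen only while this forward pass is
# recorded, so the graph has no path into them (it still flows back into G).
for p in D.parameters():
    p.requires_grad_(False)
generator_loss = F.binary_cross_entropy_with_logits(D(fakes), ones)
for p in D.parameters():
    p.requires_grad_(True)

# Discriminator term: D is trainable again, and detach() keeps G out of it.
discriminator_loss = (
    F.binary_cross_entropy_with_logits(D(fakes.detach()), zeros)
    + F.binary_cross_entropy_with_logits(D(reals), ones)
)

# Optional sanity check: D's gradient should come from discriminator_loss alone.
expected_d_grads = torch.autograd.grad(
    discriminator_loss, list(D.parameters()), retain_graph=True
)

loss = generator_loss + discriminator_loss
opt.zero_grad()
loss.backward()  # one backward pass for both terms
opt.step()       # one optimizer step for both modules
```

This still needs the freeze/unfreeze bookkeeping around each excluded forward pass, which is the boilerplate the proposed context manager would hide.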

I will try to find a solution to this problem, and if you think such a feature would be useful, I can submit a pull request.