How to use the backward functions for multiple losses?

Hello @Nabarun_Goswami,

To try to clear this up: in the DCGAN example you have the following loss (think of mathematical functions here; I left out everything not relevant):

$$\text{loss} = \text{criterion}\big(\text{netD}(\text{real}, \text{params})\big) + \text{criterion}\big(\text{netD}(\text{fake}, \text{params})\big)$$

Spelling out the chain rule for the gradient of the loss w.r.t. the params:

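$$\nabla_{\text{params}}\,\text{loss} = \nabla_{\text{params}}\,\text{netD}(\text{real}, \text{params}) \cdot \nabla_{\text{netD}}\,\text{criterion}\big(\text{netD}(\text{real}, \text{params})\big) + \nabla_{\text{params}}\,\text{netD}(\text{fake}, \text{params}) \cdot \nabla_{\text{netD}}\,\text{criterion}\big(\text{netD}(\text{fake}, \text{params})\big),$$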

note how $\nabla_{\text{params}}\,\text{netD}$ is evaluated at two different points, namely (real, params) and (fake, params).
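To make the two evaluation points concrete, here is a minimal sketch with a hypothetical one-parameter "discriminator" netD(x, w) = w * x and a squared-output criterion (both are assumptions chosen so the derivatives are easy by hand, not the DCGAN modules). The chain rule above then gives ∇_w loss = real · 2w·real + fake · 2w·fake, and autograd should agree:

```python
import torch

# Hypothetical stand-ins for netD and criterion, simple enough to differentiate by hand.
def netD(x, w):
    return w * x

def criterion(out):
    return out ** 2

w = torch.tensor(3.0, requires_grad=True)  # plays the role of "params"
real = torch.tensor(2.0)
fake = torch.tensor(-1.0)

loss = criterion(netD(real, w)) + criterion(netD(fake, w))
loss.backward()

# Chain rule by hand: d netD/dw = x and d criterion/d out = 2 * out,
# each evaluated at its own point (real or fake).
manual = real * 2 * (w * real) + fake * 2 * (w * fake)
print(w.grad.item(), manual.item())  # 30.0 30.0
```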

Backpropagation evaluates each gradient at the point where the corresponding forward computation happened, using the intermediate values saved during that forward pass.
In theory, you could also copy the network, share the parameters between the copies, and then simply add the losses to achieve the same. The backprop at real would then go through one copy and the one at fake through the other.
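The shared-copy idea can be sketched as follows, assuming a hypothetical one-layer netD built from nn.Linear and nn.BCEWithLogitsLoss as the criterion (neither is taken from the DCGAN code). netD_copy is a deep copy whose weight and bias are reassigned to netD's own Parameter objects, so both copies use the same tensors; the script then compares the result with simply calling netD twice.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
netD = nn.Linear(4, 1)  # hypothetical tiny discriminator
criterion = nn.BCEWithLogitsLoss()
real, fake = torch.randn(8, 4), torch.randn(8, 4)
ones, zeros = torch.ones(8, 1), torch.zeros(8, 1)

# A second copy of the network whose parameters are the *same* Parameter objects.
netD_copy = copy.deepcopy(netD)
netD_copy.weight = netD.weight
netD_copy.bias = netD.bias

# One summed loss: real goes through netD, fake through netD_copy.
loss = criterion(netD(real), ones) + criterion(netD_copy(fake), zeros)
loss.backward()
grad_shared = netD.weight.grad.clone()

# The same computation, calling the original network twice.
netD.zero_grad()
loss = criterion(netD(real), ones) + criterion(netD(fake), zeros)
loss.backward()

same = torch.allclose(grad_shared, netD.weight.grad)
print(same)  # True
```

Calling one module twice behaves like the two shared copies, because each call records its own part of the graph while pointing at the same parameter tensors.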

Now, this is exactly why (I imagine; I didn’t design it) PyTorch adds to .grad on backward instead of overwriting it, which allows the following:

  1. You zero the gradients. (Ha, I have forgotten that step often enough myself.)
  2. You evaluate netD and criterion at the point real.
  3. You backprop to compute derivatives at the point real (=the last evaluated point). The .grads are added to zero from step 1.
  4. You evaluate netD and criterion at the point fake.
  5. You backprop to compute derivatives at the point fake (=the last evaluated point). The .grads are added to the .grads you had from step 3.

You have now computed the gradient of the loss, but manually split it into the two summands.
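The five steps can be written out as a rough sketch; the small network, the BCEWithLogitsLoss criterion, the shapes and the labels are all assumptions for illustration. As a cross-check it also runs a single backward on the summed loss; with PyTorch's autograd, which keeps the saved values of each forward call separate, both routes should give the same .grad.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny netD; the real DCGAN discriminator is convolutional.
netD = nn.Sequential(nn.Linear(4, 8), nn.LeakyReLU(0.2), nn.Linear(8, 1))
criterion = nn.BCEWithLogitsLoss()
real, fake = torch.randn(16, 4), torch.randn(16, 4)
label_real, label_fake = torch.ones(16, 1), torch.zeros(16, 1)

# Step 1: zero the gradients.
netD.zero_grad()
# Steps 2-3: evaluate at "real" and backprop; .grad now holds the real-term gradient.
criterion(netD(real), label_real).backward()
# Steps 4-5: evaluate at "fake" and backprop; the new gradient is added to .grad.
criterion(netD(fake), label_fake).backward()
split_grads = [p.grad.clone() for p in netD.parameters()]

# For comparison: one backward on the summed loss.
netD.zero_grad()
loss = criterion(netD(real), label_real) + criterion(netD(fake), label_fake)
loss.backward()
summed_grads = [p.grad.clone() for p in netD.parameters()]

match = all(torch.allclose(a, b) for a, b in zip(split_grads, summed_grads))
print(match)  # True
```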

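If you just added the two parts to the loss and called backward once, you would get the same .grads: autograd records a separate piece of graph for each call of netD, and the values from step 2 are saved separately from those of step 4, so the second forward pass does not overwrite what the first one needs. One practical effect of splitting is that the saved values of the real pass can be freed before the fake pass is computed.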

As seen in the Wasserstein GAN code and friends, you can also pass a tensor containing -1 as the gradient argument of .backward() to emulate terms subtracted from the loss.
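A minimal sketch of the -1 trick, assuming a hypothetical linear critic whose output is averaged over the batch (loosely in the spirit of the WGAN code, not copied from it). Here one and mone are 0-dim tensors because the averaged outputs are scalars; the gradient passed to .backward() must have the same shape as the tensor it is called on.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netD = nn.Linear(4, 1)  # hypothetical critic
real, fake = torch.randn(16, 4), torch.randn(16, 4)
one = torch.tensor(1.0)
mone = torch.tensor(-1.0)

# Split version: +1 for the real term, -1 for the fake term.
netD.zero_grad()
netD(real).mean().backward(one)
netD(fake).mean().backward(mone)
split = netD.weight.grad.clone()

# Reference: a single backward on the difference of the two terms.
netD.zero_grad()
(netD(real).mean() - netD(fake).mean()).backward()

same = torch.allclose(split, netD.weight.grad)
print(same)  # True
```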

Now, if you just call .backward() twice on the same loss, there are two possibilities:

  • with retain_graph=True (or retain_variables=True in PyTorch <= 0.1.12) in the first call, you do the same as in steps 3 and 5, except that both backward passes compute derivatives at the same evaluated point. The .grads are added to the .grads you already had, so you end up with twice the gradient at that point.
  • without retain_graph=True in the first call, PyTorch frees the intermediate values saved during the forward pass once the first backward has used them, so the second call raises a RuntimeError that essentially means “forward info is gone, you used it and didn’t tell me to keep it”.
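Both cases can be reproduced with a small sketch, assuming a hypothetical nn.Linear netD, BCEWithLogitsLoss as the criterion, and the current retain_graph spelling. Case 1 checks that the accumulated gradient is twice the single-call gradient; case 2 catches the error raised by the second call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
netD = nn.Linear(4, 1)  # hypothetical discriminator
criterion = nn.BCEWithLogitsLoss()
x, target = torch.randn(16, 4), torch.ones(16, 1)

# Reference gradient from a single backward.
netD.zero_grad()
criterion(netD(x), target).backward()
once = netD.weight.grad.clone()

# Case 1: keep the graph on the first call, then backward again through the same graph.
netD.zero_grad()
loss = criterion(netD(x), target)
loss.backward(retain_graph=True)
loss.backward()
doubled = torch.allclose(netD.weight.grad, 2 * once)
print(doubled)  # True: the same gradient, accumulated twice

# Case 2: without retain_graph, the saved forward values are freed after the first call.
loss = criterion(netD(x), target)
loss.backward()
raised = False
try:
    loss.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", str(err).splitlines()[0])
```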

I hope this helps you decide what fits your project.

Best regards

Thomas
