How to get gradient of a tensor wrt different losses

Hi,

I just switched to PyTorch and have a question. Suppose a multi-task setting:

task1_preds, task2_preds = self.model(input)
task1_loss = self.crit(task1_preds, task1_labels)
task2_loss = self.crit(task2_preds, task2_labels)

I want to get the gradient of a tensor A with respect to each of these two losses separately, i.e. d(task1_loss)/dA and d(task2_loss)/dA:

task1_loss.backward(retain_graph=True)
A_task1_grad = A.grad
task2_loss.backward(retain_graph=True)
A_task2_grad = A.grad

I'm wondering whether this works. Any ideas? Thanks.
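What the snippet above actually computes can be checked on a small self-contained example. This is a sketch under stated assumptions: the two-head model, the shapes, and the MSE losses are hypothetical stand-ins for `self.model` and `self.crit`, and `A` is taken to be the weight of the shared layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-head model: the shared layer's weight plays the role of A.
shared = nn.Linear(4, 3)
head1 = nn.Linear(3, 1)
head2 = nn.Linear(3, 1)
crit = nn.MSELoss()

x = torch.randn(8, 4)
task1_labels = torch.randn(8, 1)
task2_labels = torch.randn(8, 1)

h = shared(x)
task1_loss = crit(head1(h), task1_labels)
task2_loss = crit(head2(h), task2_labels)
A = shared.weight

# The pattern from the question, unchanged.
task1_loss.backward(retain_graph=True)
A_task1_grad = A.grad
task2_loss.backward(retain_graph=True)
A_task2_grad = A.grad

# Both names refer to the same tensor, which now holds the SUM of both gradients.
print(torch.equal(A_task1_grad, A_task2_grad))  # True
```

Because backward adds into `A.grad` in place, after the second call both names hold d(task1_loss)/dA + d(task2_loss)/dA rather than the two separate gradients.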

  1. You need to call A.grad.data.zero_() before the second pass; backward adds into A.grad rather than overwriting it, so otherwise the gradient from the first backward is accumulated into the second. If you are doing this in a loop, then from the second iteration on you will also need A.grad.data.zero_() before computing A_task1_grad.

  2. In the current code, if you print A_task1_grad and A_task2_grad after the second backward, you will find they are the same: A_task1_grad = A.grad only stores a reference to the gradient tensor, and backward keeps modifying that same memory in place. You need A_task1_grad = A.grad.clone() to copy the gradient into new memory that the next backward will not modify.
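The two fixes above (zero between passes, clone before the next backward) can be sketched as follows, again with a hypothetical toy model standing in for `self.model`. The helper name `per_task_grads` is illustrative. It uses `A.grad.zero_()`, which works here without the `.data` detour, and guards against `A.grad` being `None` before the very first backward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-head setup; A is the shared layer's weight.
shared = nn.Linear(4, 3)
head1 = nn.Linear(3, 1)
head2 = nn.Linear(3, 1)
crit = nn.MSELoss()
A = shared.weight


def per_task_grads(x, y1, y2):
    """Return (d task1_loss / dA, d task2_loss / dA) as independent copies."""
    h = shared(x)
    task1_loss = crit(head1(h), y1)
    task2_loss = crit(head2(h), y2)

    if A.grad is not None:      # A.grad is None until the first backward
        A.grad.zero_()          # drop gradients left over from earlier iterations
    task1_loss.backward(retain_graph=True)  # keep the graph for the task-2 pass
    A_task1_grad = A.grad.clone()           # copy: backward writes into A.grad in place

    A.grad.zero_()                          # clear the task-1 gradient
    task2_loss.backward()
    A_task2_grad = A.grad.clone()
    return A_task1_grad, A_task2_grad


for step in range(3):
    g1, g2 = per_task_grads(torch.randn(8, 4), torch.randn(8, 1), torch.randn(8, 1))
    print(step, g1.norm().item(), g2.norm().item())
```

Note that the other parameters (the heads and `shared.bias`) still accumulate gradients across calls in this sketch; only `A.grad` is managed explicitly.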


Thanks a lot for your answer.

I made these changes according to your reply.

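if A.grad is not None:  # A.grad is None until the first backward call
    A.grad.data.zero_()
task1_loss.backward(retain_graph=True)  # keep the graph for the second backward
A_task1_grad = A.grad.clone()  # copy, since backward modifies A.grad in place
A.grad.data.zero_()  # clear the task-1 gradient before the second pass
task2_loss.backward(retain_graph=True)
A_task2_grad = A.grad.clone()  # copy, so the next zero_() does not wipe it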
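A different technique from the `.grad`-based approach in this thread, and one that avoids the zeroing and cloning entirely: `torch.autograd.grad` returns the requested gradients directly instead of accumulating them into `A.grad`. A minimal sketch, again assuming a hypothetical toy two-head model where `A` is the shared layer's weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-head setup; A is the shared layer's weight.
shared = nn.Linear(4, 3)
head1 = nn.Linear(3, 1)
head2 = nn.Linear(3, 1)
crit = nn.MSELoss()
A = shared.weight

x = torch.randn(8, 4)
h = shared(x)
task1_loss = crit(head1(h), torch.randn(8, 1))
task2_loss = crit(head2(h), torch.randn(8, 1))

# autograd.grad returns a tuple (one gradient per input) and leaves A.grad untouched.
(A_task1_grad,) = torch.autograd.grad(task1_loss, A, retain_graph=True)
(A_task2_grad,) = torch.autograd.grad(task2_loss, A, retain_graph=True)

print(A.grad)  # None: nothing was accumulated
```

Keeping `retain_graph=True` on the second call leaves the graph alive, so a combined `(task1_loss + task2_loss).backward()` can still be run afterwards for the usual optimizer step.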