How to get gradient of a tensor wrt different losses

  1. You need to call A.grad.data.zero_() before the second backward pass so that gradients from the first backward don't accumulate (see here). If you are doing this in a loop, then from the second iteration onwards you will also need A.grad.data.zero_() before getting A_task1_grad.

  2. In the current code, if you print A_task1_grad and A_task2_grad after the second backward you will find they are the same, because backward keeps writing into the same memory location. You need to do A_task1_grad = A.grad.clone() to copy the grad into a new location that backward will not modify the next time (see the sketch below this list).
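
Putting both points together, here is a minimal sketch. The tensor A and the two toy losses are made up for illustration; substitute your own losses:

```python
import torch

# Hypothetical setup: one tensor A feeding two "task" losses.
A = torch.randn(3, requires_grad=True)
loss_task1 = (A ** 2).sum()   # stand-in for the first task's loss
loss_task2 = (A ** 3).sum()   # stand-in for the second task's loss

# First backward: A.grad now holds d(loss_task1)/dA.
loss_task1.backward()
A_task1_grad = A.grad.clone()   # point 2: clone so the next backward can't overwrite it

# Point 1: zero the accumulated gradient before the second backward.
A.grad.data.zero_()

# Second backward: A.grad now holds only d(loss_task2)/dA.
# (If both losses came from one shared forward pass, the first backward
# would also need retain_graph=True; here the graphs are separate.)
loss_task2.backward()
A_task2_grad = A.grad.clone()

print(A_task1_grad)   # 2 * A
print(A_task2_grad)   # 3 * A**2
```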
