- you need to do `A.grad.data.zero_()` before the second pass so that gradients from the first backward are not accumulated (see here). If you are doing this in a loop, then in the second iteration you will need `A.grad.data.zero_()` before getting `A_task1_grad` as well.
- in the current code, if you print `A_task1_grad` and `A_task2_grad` after the second `backward`, you will find they are the same because `backward` keeps writing to the same memory location. You need to do `A_task1_grad = A.grad.clone()` to copy the grad to a new location that will not be modified by the next `backward` call. Both fixes are shown in the sketch below.
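
Putting both fixes together, here is a minimal sketch; the parameter `A` and the two task losses are hypothetical stand-ins for the code in the thread:

```python
import torch

# Hypothetical setup: one shared parameter and two task losses.
A = torch.randn(3, 3, requires_grad=True)
task1_loss = (A * 2).sum()
task2_loss = (A ** 2).sum()

# First backward pass.
task1_loss.backward()
A_task1_grad = A.grad.clone()  # clone() copies to new memory, so later backwards won't overwrite it

# Zero the accumulated gradient before the second pass.
A.grad.data.zero_()

# Second backward pass.
task2_loss.backward()
A_task2_grad = A.grad.clone()

print(torch.equal(A_task1_grad, A_task2_grad))  # False: the two task gradients now differ
```

Without the `zero_()` call, `A.grad` after the second backward would be the sum of both task gradients; without the `clone()`, both saved variables would point at the same tensor.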