- You need to call `A.grad.data.zero_()` before the second pass so that the gradients from the first backward are not accumulated (see here). If you are doing this in a loop, then in the second iteration you will also need `A.grad.data.zero_()` before getting `A_task1_grad`.
- In the current code, if you print `A_task1_grad` and `A_task2_grad` after the second `backward`, you will find they are the same, because `backward` keeps writing into the same memory location (`A.grad`). You need to do `A_task1_grad = A.grad.clone()` to copy the grad to a new location that will not be modified by the next `backward`. A sketch combining both points is below.
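Here is a minimal sketch of the pattern, assuming a shared parameter `A` and two task losses (`loss_task1`, `loss_task2` and the loop are hypothetical placeholders, not from the original code):

```python
import torch

# Hypothetical setup: one shared parameter A and two task losses.
A = torch.randn(3, 3, requires_grad=True)

for _ in range(2):  # e.g. a training loop
    # --- task 1 ---
    if A.grad is not None:
        A.grad.data.zero_()          # clear leftover gradients from the previous iteration
    loss_task1 = (A ** 2).sum()      # placeholder loss for task 1
    loss_task1.backward()
    A_task1_grad = A.grad.clone()    # clone() copies the values; A.grad itself will be reused

    # --- task 2 ---
    A.grad.data.zero_()              # zero before the second backward so gradients don't accumulate
    loss_task2 = (A ** 3).sum()      # placeholder loss for task 2
    loss_task2.backward()
    A_task2_grad = A.grad.clone()

    # A_task1_grad and A_task2_grad now hold different values,
    # because each was cloned before backward wrote into A.grad again.
```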