Ha, I recently did exactly this. Not sure if it's the best way, but here's what I did:
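- Detach `yy` before feeding it into whatever produces the `zz`s, e.g. `yyy = yy.detach().requires_grad_()` (the `requires_grad_()` matters: a plain detached tensor doesn't require grad, so `autograd.grad` can't differentiate w.r.t. it).
- Manually call `autograd.grad` to get each `zz`'s grad w.r.t. `yyy`.
- Save the one you want.
- Call `yy.backward(grad_to_yyy_1 + grad_to_yyy_2)`.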
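Here's a minimal PyTorch sketch of those steps. It assumes a made-up toy graph: `x`, `zz1`, `zz2` and their formulas are hypothetical and only there for illustration. The names `yy`, `yyy`, `grad_to_yyy_1` and `grad_to_yyy_2` are the ones used in the steps.

```python
import torch

# Hypothetical toy graph: x is a leaf, yy an intermediate,
# zz1 and zz2 two scalar outputs that both depend on yy.
x = torch.randn(3, requires_grad=True)
yy = x * 2

# Step 1: detach yy and make the copy a new leaf that requires grad.
yyy = yy.detach().requires_grad_()
zz1 = (yyy ** 2).sum()
zz2 = (3 * yyy).sum()

# Step 2: get each zz's gradient w.r.t. yyy separately
# (autograd.grad returns a tuple, one entry per input).
grad_to_yyy_1, = torch.autograd.grad(zz1, yyy)
grad_to_yyy_2, = torch.autograd.grad(zz2, yyy)

# Step 3: keep the individual gradient you care about.
saved = grad_to_yyy_1

# Step 4: push the summed gradient through the original yy -> x graph.
yy.backward(grad_to_yyy_1 + grad_to_yyy_2)
print(x.grad)
```

The resulting `x.grad` is the same as what `(zz1 + zz2).backward()` would give without the detach; the detour just lets you look at each branch's gradient on its own before combining them.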