If you want to call `backward()` on `running_loss`, you should just add the losses together without calling `.item()` on them.
You would call `.item()` to get a plain Python number that is detached from the computation graph (so no `backward()` call is possible on it anymore), e.g. to store or print the loss for debugging purposes.
In your case, however, you need the computation graphs.
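A minimal sketch of the first approach, assuming a toy setup that is not from the original thread: a hypothetical `torch.nn.Linear` model, two MSE losses per step named `loss_a` and `loss_b`, and random data. The tensors are summed into `running_loss` to keep the graph, while `.item()` is used only for a separate list of logged floats.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: a linear model and two losses per step (loss_a, loss_b)
model = torch.nn.Linear(4, 1)
criterion = torch.nn.MSELoss()
batches = [(torch.randn(8, 4), torch.randn(8, 1), torch.randn(8, 1)) for _ in range(3)]

running_loss = 0.0  # becomes a tensor attached to the graph after the first addition
logged = []         # plain Python floats, for printing/debugging only

for x, ya, yb in batches:
    out = model(x)
    loss_a = criterion(out, ya)
    loss_b = criterion(out, yb)
    # Add the tensors themselves (no .item()) so the computation graph is kept
    running_loss = running_loss + loss_a + loss_b
    # .item() returns a Python float detached from the graph; use it for logging only
    logged.append((loss_a + loss_b).item())

# One backward call through all accumulated graphs
running_loss.backward()
print(running_loss.item(), logged)
```

Note that all graphs stay in memory until `running_loss.backward()` is called, which can matter for long inner loops.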
Another approach would be to call `(loss_a + loss_b).backward()` inside the inner loop. This should yield the same results, since the gradients are accumulated in the parameters' `.grad` attributes. Have a look at this post for more information.
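To illustrate the claim that both approaches give the same gradients, here is a small sketch under assumed conditions (hypothetical linear model, MSE losses and random data, none of which come from the thread). It compares summing the losses with one `backward()` call against calling `(loss_a + loss_b).backward()` per inner step on an identical copy of the model.

```python
import copy

import torch

torch.manual_seed(0)

# Hypothetical toy setup (not from the original thread)
model_sum = torch.nn.Linear(4, 1)
model_step = copy.deepcopy(model_sum)  # identical initial weights
criterion = torch.nn.MSELoss()
batches = [(torch.randn(8, 4), torch.randn(8, 1), torch.randn(8, 1)) for _ in range(3)]

# Approach 1: sum all losses as tensors, single backward at the end
running_loss = 0.0
for x, ya, yb in batches:
    out = model_sum(x)
    running_loss = running_loss + criterion(out, ya) + criterion(out, yb)
running_loss.backward()

# Approach 2: backward inside the inner loop; .grad accumulates across calls
for x, ya, yb in batches:
    out = model_step(x)
    loss_a = criterion(out, ya)
    loss_b = criterion(out, yb)
    (loss_a + loss_b).backward()

# Gradients should match up to floating-point summation order
grads_match = [
    torch.allclose(p_sum.grad, p_step.grad, atol=1e-6)
    for p_sum, p_step in zip(model_sum.parameters(), model_step.parameters())
]
print(grads_match)
```

A practical difference is memory: in the second approach each step's graph is freed after its `backward()` call, whereas the first keeps every graph alive until the final call. In either case, remember to zero the gradients (e.g. `optimizer.zero_grad()`) before the next accumulation cycle.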