I understand that if I have two losses computed in different parts of the network, I have to use retain_graph=True. What if I add the two losses together and call total_loss.backward() once instead?
For example, rather than:
```python
loss1.backward(retain_graph=True)
loss2.backward()
opt.step()
```
I would just do:
```python
total_loss = loss1 + loss2
total_loss.backward()
opt.step()
```
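A minimal sketch to check the two variants against each other, assuming a hypothetical toy model in which a shared `trunk` feeds two heads (`head1`, `head2`), each with its own MSE loss. All names here (`trunk`, `head1`, `head2`, `compute_losses`, the tensor shapes) are illustrative, not from the original question. Because `.grad` accumulates across `backward()` calls, backpropagating `loss1 + loss2` once should give the same parameter gradients as two separate calls; the sketch also shows that the two-call version fails without `retain_graph=True` when the losses share part of the graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy model: a shared trunk feeding two heads, each with its own loss.
trunk = nn.Linear(4, 8)
head1 = nn.Linear(8, 1)
head2 = nn.Linear(8, 1)
params = list(trunk.parameters()) + list(head1.parameters()) + list(head2.parameters())

x = torch.randn(16, 4)
y1 = torch.randn(16, 1)
y2 = torch.randn(16, 1)


def compute_losses():
    # Build a fresh graph each time; the shared tanh node saves its output,
    # and that saved tensor is freed by a backward() without retain_graph.
    h = torch.tanh(trunk(x))
    loss1 = F.mse_loss(head1(h), y1)
    loss2 = F.mse_loss(head2(h), y2)
    return loss1, loss2


def zero_grads():
    for p in params:
        p.grad = None


def snapshot_grads():
    return [p.grad.clone() for p in params]


# Variant 1: two backward calls; keep the shared graph alive for the second one.
zero_grads()
loss1, loss2 = compute_losses()
loss1.backward(retain_graph=True)
loss2.backward()  # gradients accumulate into each .grad
grads_separate = snapshot_grads()

# Variant 2: sum the losses and backpropagate once.
zero_grads()
loss1, loss2 = compute_losses()
total_loss = loss1 + loss2
total_loss.backward()
grads_summed = snapshot_grads()

# Small atol absorbs floating-point differences in summation order.
same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(grads_separate, grads_summed))
print(same)  # True

# Variant 1 without retain_graph: the second call hits the already-freed shared nodes.
zero_grads()
loss1, loss2 = compute_losses()
loss1.backward()
try:
    loss2.backward()
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```

In this sketch the summed version traverses the shared trunk only once, so it needs no `retain_graph`. If the two losses did not share any part of the graph, the separate calls would not need `retain_graph=True` either.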