What exactly does `retain_variables=True` in `loss.backward()` do?

I understand that if I have two loss functions in different parts of the network, I'll have to use retain_graph. But what if I add both losses together and call total_loss.backward()?

For example, rather than

loss1.backward(retain_graph=True)
loss2.backward()
opt.step()

I would just do

total_loss = loss1 + loss2
total_loss.backward()
opt.step()
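The two are equivalent: autograd is linear, so the gradient of loss1 + loss2 is the sum of the two gradients, which is exactly what two successive backward() calls accumulate into .grad. A minimal sketch, using a hypothetical toy "network" (a single parameter w with a shared intermediate h, so that retain_graph=True is genuinely required in the two-call version):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

def forward():
    # h is a shared intermediate, standing in for "different parts
    # of the network" that branch off one forward pass
    h = w * x
    loss1 = h.sum() ** 2
    loss2 = (h - 1.0).pow(2).sum()
    return loss1, loss2

# Option A: two backward passes. retain_graph=True keeps the shared
# graph alive so the second backward() doesn't fail; gradients
# accumulate into w.grad.
loss1, loss2 = forward()
loss1.backward(retain_graph=True)
loss2.backward()
grads_separate = w.grad.clone()

# Option B: one backward pass on the summed loss.
w.grad.zero_()
loss1, loss2 = forward()
(loss1 + loss2).backward()
grads_summed = w.grad.clone()

print(torch.allclose(grads_separate, grads_summed))
```

The summed version is usually preferable: it walks the shared part of the graph once instead of twice, and it frees the graph immediately, so no retain_graph is needed.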