loss.backward(retain_graph=True) vs summing losses and calling .backward() once

Hi,

I am training an RNN and computing a loss for each step of the output sequence. Is calling loss.backward(retain_graph=True) for each step of the output more memory-efficient or faster than summing loss = loss_step1 + … + loss_stepn and then calling loss.backward() once? I have tried both, and they seem to give the same accuracy, but I'm not sure what is different behind the scenes.
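A minimal sketch of the two variants I mean (the toy RNN, head, and data here are just placeholders, not my actual model); both end up accumulating gradients into .grad, but the first launches a backward pass per step and has to keep the graph alive:

```python
import torch
import torch.nn as nn

# Toy model and data, assumed for illustration only.
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
criterion = nn.MSELoss()

x = torch.randn(4, 10, 8)        # (batch, seq_len, features)
target = torch.randn(4, 10, 1)   # per-step targets

# Variant 1: one backward call per step, keeping the graph alive between calls.
rnn.zero_grad(); head.zero_grad()
out, _ = rnn(x)                  # (batch, seq_len, hidden)
for t in range(out.size(1)):
    loss_t = criterion(head(out[:, t]), target[:, t])
    loss_t.backward(retain_graph=True)   # gradients accumulate in .grad

# Variant 2: sum the per-step losses and call backward once.
rnn.zero_grad(); head.zero_grad()
out, _ = rnn(x)
loss = sum(criterion(head(out[:, t]), target[:, t]) for t in range(out.size(1)))
loss.backward()                          # same accumulated gradients, single pass
```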

Hi,

Summing the losses and calling backward once is going to be more efficient, unless each loss comes from an independent network, in which case the two are equivalent. The accumulated gradients are the same either way, since gradients add up across backward calls; the per-step version just reruns the backward pass for every step and has to keep the intermediate buffers alive because of retain_graph=True.
But basically, you can't go wrong with doing the sum and a single backward!
