Calling loss.backward() multiple times vs. aggregating losses

I have some code that uses multiple loss functions and therefore calls loss.backward(retain_graph=True) multiple times.

I heard that doing only a single backward pass can be more efficient (time-wise). Should I in that case sum all the losses or average them?

And is there any difference between calling backward on each loss separately and calling it once on the aggregated loss?

Calling loss.backward() multiple times on different losses is equivalent to summing the individual losses first and calling backward a single time on the summed loss, because gradients accumulate in .grad across calls. Averaging them instead will dampen the gradients by a factor of n. So the result is the same as long as you don’t care about the gradients w.r.t. any of the losses individually, and yes, a single backward pass will be more efficient, depending on how much of the graph the losses share.
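For example, here is a minimal sketch (with a toy model and two made-up losses, purely for illustration) that checks the equivalence: per-loss backward calls accumulate into the same .grad values as a single backward on the summed loss.

```python
import torch

# Hypothetical setup: a small model with two losses that share the same graph
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

def compute_losses():
    out = model(x)
    return out.pow(2).mean(), out.abs().mean()  # two toy scalar losses

# Approach 1: backward on each loss separately
model.zero_grad()
loss1, loss2 = compute_losses()
loss1.backward(retain_graph=True)  # keep the shared graph alive for the second call
loss2.backward()
grads_separate = [p.grad.clone() for p in model.parameters()]

# Approach 2: a single backward on the summed loss
model.zero_grad()
loss1, loss2 = compute_losses()
(loss1 + loss2).backward()
grads_summed = [p.grad.clone() for p in model.parameters()]

# Gradients accumulate across backward calls, so both approaches match
for g1, g2 in zip(grads_separate, grads_summed):
    assert torch.allclose(g1, g2)
```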

A better way is to pass multiple losses to torch.autograd.backward at once (see torch.autograd.backward — PyTorch 1.9.1 documentation). It does the same thing but avoids performing the unnecessary sum.
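A minimal sketch of that variant, reusing the same toy setup as above; torch.autograd.backward accepts a sequence of scalar losses and backpropagates through all of them in one call, with the gradients accumulating just as before.

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
out = model(x)
loss1, loss2 = out.pow(2).mean(), out.abs().mean()  # two toy scalar losses

# Backpropagate both losses in a single call, without summing them first
torch.autograd.backward([loss1, loss2])

print([p.grad.shape for p in model.parameters()])
```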
