How to get the gradients of parameters for different losses when multiple losses are used

# setup (simplified)
net = Net()
data, label = next(iter(DataLoader(dataset)))
opt = Optimizer(net.parameters())
criterion1 = Loss1()
criterion2 = Loss2()
output = net(data)
# compute the two losses
loss1 = criterion1(output, label)
loss2 = criterion2(output, label)

# call backward()
# first type
loss1.backward() 
print(next(net.parameters()).grad)
loss2.backward()
print(next(net.parameters()).grad)
# is the above same as below
# second type
loss = loss1+loss2
loss.backward()

# step()
opt.step()

I have two questions.
First, are the two kinds of backward equivalent?
Second, if I want to see the gradient of the parameters for each loss separately, is the first approach the simplest way, or is there a simpler way to get the gradients for the different losses?

To answer your first question: you need to pass retain_graph=True to the first backward call when you compute multiple losses from one output and call backward on them separately. Otherwise the computation graph is freed during the first backward call and you will not be able to backpropagate the second loss.
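Applied to the first variant from the question, that would look roughly like this (a sketch, reusing the placeholder names from the question):

# keep the graph alive so loss2 can still backpropagate through the shared output
loss1.backward(retain_graph=True)
print(next(net.parameters()).grad)
# the last backward call may free the graph
loss2.backward()
print(next(net.parameters()).grad)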

Whether or not the two approaches are equal depends on how you use the optimizer. If you call opt.zero_grad() between the backward calls, the approaches are not equivalent; if you don't, they are, since gradients are accumulated by default in PyTorch.
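Here is a minimal, self-contained sketch of that equivalence; the tiny linear model and the two loss functions are made up purely for illustration:

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(3, 1)
data = torch.randn(4, 3)
label = torch.randn(4, 1)
criterion1 = nn.MSELoss()
criterion2 = nn.L1Loss()

# variant 1: two separate backward calls, gradients accumulate in .grad
output = net(data)
loss1 = criterion1(output, label)
loss2 = criterion2(output, label)
loss1.backward(retain_graph=True)
loss2.backward()
grad_accumulated = net.weight.grad.clone()

# variant 2: one backward call on the summed loss
net.zero_grad()
output = net(data)
loss = criterion1(output, label) + criterion2(output, label)
loss.backward()
grad_summed = net.weight.grad.clone()

print(torch.allclose(grad_accumulated, grad_summed))  # True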


Thank you very much.

This might be a very silly question, but to get the gradients of the parameters for the different losses (as in the question), don't we have to subtract the first gradient to isolate the gradient of the second loss? I'm basing this on the fact that we don't call opt.zero_grad() in between (which, as suggested above, would lead to a wrong update anyway). To be clear, shouldn't the code below be used rather than what the question suggests?

opt.zero_grad()

loss1.backward(retain_graph=True)
# clone, otherwise grad_wrt_loss1 would just alias .grad and be changed by the next backward
grad_wrt_loss1 = next(net.parameters()).grad.clone()
print(grad_wrt_loss1)
loss2.backward()
# THIS subtraction is the distinction from the code in the question:
grad_wrt_loss2 = next(net.parameters()).grad - grad_wrt_loss1
print(grad_wrt_loss2)

opt.step()
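Alternatively, a possibly simpler way to get per-loss gradients (the second question in the original post) is torch.autograd.grad, which returns the gradients directly instead of accumulating them into .grad. A sketch, again reusing the placeholder names from the question:

# gradients of each loss w.r.t. all parameters, without touching .grad
params = list(net.parameters())
grads_loss1 = torch.autograd.grad(loss1, params, retain_graph=True)
grads_loss2 = torch.autograd.grad(loss2, params, retain_graph=True)
print(grads_loss1[0])
print(grads_loss2[0])

# the usual accumulated update can still be done afterwards
opt.zero_grad()
(loss1 + loss2).backward()
opt.step()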