Backward() of 2 losses (explained)

Hello, this is quite a simple question but I didn’t find an answer.

When you have 2 losses, the approach I have read about is to sum the losses and then call .backward()

My question is about understanding this.

To get Loss1 I had some inputs (x1’s) and the output (y1)
To get Loss2 I had some inputs (x2’s) and the output (y2)

If I had only Loss1, then backward() would give me the gradient with respect to my parameters. But note that I used the x1’s as input, so I will have Z1 = Wx1 + b and A1 (activation) values.

In contrast, for Loss2 I would have Z2 and A2.
Z1, A1, Z2, and A2 are different because different inputs were used.
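In code, I imagine the setup to be something like this toy sketch (the sigmoid activation is just an example I picked):

import torch
import torch.nn as nn

linear = nn.Linear(1, 1)       # Z = W*x + b
x1 = torch.randn(1, 1)
x2 = torch.randn(1, 1)

Z1 = linear(x1)                # Z1 from the first input
Z2 = linear(x2)                # Z2 from the second input
A1 = torch.sigmoid(Z1)         # A1 (activation of Z1)
A2 = torch.sigmoid(Z2)         # A2 (activation of Z2)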

Now you tell me that I just need to sum up L1 + L2 and call backward(). But to compute the gradient of a particular parameter, I will have dA/dZ and/or dZ/dA terms in the chain (chain rule differentiation). So which one is used, Z1 or Z2? A1 or A2? How exactly does the differentiation take place?
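Writing out the chain rule the way I understand it (just a sketch for a single layer followed by an activation):

dL1/dW = dL1/dA1 * dA1/dZ1 * dZ1/dW   (goes through Z1 and A1)
dL2/dW = dL2/dA2 * dA2/dZ2 * dZ2/dW   (goes through Z2 and A2)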

Elsewhere I read that summing up the losses is similar to:
L1.backward()
L2.backward()

Is this correct? If so, this second case looks clearer, because L1 is backpropagated with Z1 and A1, and L2 with Z2 and A2, but then I’m back to the same question. If I manage to get dL1/dW and dL2/dW, how do I update my parameter? Would it be:
W = W_prev - learning_rate * dL1/dW - learning_rate * dL2/dW
so that in the end I sum the gradients obtained separately?

Would you explain to me how this works, because I’m confused.
Regards

Assuming both computation graphs are independent (as indicated by the Z1/A1 and Z2/A2 notation), the gradients will be calculated independently and accumulated, as seen here:

import torch
import torch.nn as nn

# setup
model = nn.Linear(1, 1, bias=False)

x1 = torch.randn(1, 1)
x2 = torch.randn(1, 1)

# case 1: sum the outputs and call backward() once
out1 = model(x1)
out2 = model(x2)

loss = out1 + out2
loss.backward()

print(model.weight.grad)
# tensor([[0.8689]])

# case 2: call backward() on each output separately; gradients accumulate in .grad
model.zero_grad()

out1 = model(x1)
out2 = model(x2)

out1.backward()
print(model.weight.grad)
# tensor([[-0.6130]])

out2.backward()
print(model.weight.grad)
# tensor([[0.8689]])
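
To also answer the update question: here is a minimal sketch, assuming plain SGD and an arbitrarily chosen learning rate. A single optimizer step applied to the accumulated gradient is exactly W = W_prev - learning_rate * (dL1/dW + dL2/dW):

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model.zero_grad()
loss = model(x1) + model(x2)   # summed objective
loss.backward()                # model.weight.grad now holds dL1/dW + dL2/dW
optimizer.step()               # applies W = W - lr * (dL1/dW + dL2/dW)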

Dear @ptrblck, thanks a lot, I finally understood it. Your example was very clear.