 # How to separate gradients from different losses on the same layer

PyTorch seems to only accumulate (sum) the gradients from different losses into the same `.grad` buffers of a layer.

```
optimizer.zero_grad()
loss1.backward(retain_graph=True)   # gradient_1 is accumulated into .grad
loss2.backward(retain_graph=True)   # ... plus gradient_2
loss3.backward(retain_graph=True)   # ... plus gradient_3

# versus backpropagating the summed loss once:
optimizer.zero_grad()
loss = loss1 + loss2 + loss3
loss.backward()                     # gradient_4
```
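
Concretely, the four gradients are captured roughly like this (a minimal sketch, assuming a single parameter tensor `param`; the variable names are just local snapshots):

```
# Sketch: snapshot the gradient produced by each loss separately,
# zeroing in between so that .grad holds the contribution of a single loss.
optimizer.zero_grad()
loss1.backward(retain_graph=True)
gradient_1 = param.grad.clone()

optimizer.zero_grad()
loss2.backward(retain_graph=True)
gradient_2 = param.grad.clone()

optimizer.zero_grad()
loss3.backward(retain_graph=True)
gradient_3 = param.grad.clone()

# Backpropagate the summed loss once to get the combined gradient.
optimizer.zero_grad()
(loss1 + loss2 + loss3).backward()
gradient_4 = param.grad.clone()
```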

I found that `gradient_1 + gradient_2 + gradient_3 != gradient_4`. Why?
How can I get the exact three gradients from the three losses, which should add up to `gradient_4`?

## Feel free to discuss!

Specifically, comparing them step by step, I saw that the printed value of `(gradient_1 + gradient_2 + gradient_3).norm()` is the same as that of `gradient_4.norm()`, yet `(gradient_1 + gradient_2 + gradient_3).norm() == gradient_4.norm()` evaluates to `False`. This is consistent over hundreds of iterations… Is it because of numerical instability? If so, is the code above the right way to inspect the gradients contributed by each loss?

Use `torch.allclose` to compare two float tensors if you suspect the issue is numerical. When you print `grad1 + grad2 + grad3` and `grad4`, do the values actually look different?
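
For example, something along these lines (using your `grad1` … `grad4` names; `torch.equal` checks exact element-wise equality, `torch.allclose` allows a small tolerance):

```
summed = grad1 + grad2 + grad3

print(torch.equal(summed, grad4))     # exact element-wise equality; may be False
print(torch.allclose(summed, grad4))  # equality within rtol/atol; expected True
print((summed - grad4).abs().max())   # size of the actual discrepancy
```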


Hi, I tried `torch.allclose(grad1 + grad2 + grad3, grad4)` and it returns `True`, which means the two gradient tensors are element-wise equal within a tolerance. Their printed values are the same, but they still do not compare as exactly equal.

So how can this be explained? Why is the sum `gradient_1 + gradient_2 + gradient_3` slightly different from the gradient obtained by backpropagating `loss = loss1 + loss2 + loss3` all at once? Shouldn't both differentiate the same graph three times and accumulate (add) the gradients into `gradient_4`?

`torch.float32` only gives you roughly 6–7 significant decimal digits of precision, so if the gradient values agree to that many digits, what you are seeing is a limitation of `float32`. Adding `loss1 + loss2 + loss3` before the backward pass changes the order of the floating-point operations, which can introduce a tiny rounding difference. Try printing `loss` and `loss1`, `loss2`, `loss3` individually to see whether they add up exactly.
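
As a standalone illustration (toy numbers, unrelated to your model): summing the same `float32` values in a different order can already change the last bits of the result, which is exactly the kind of discrepancy you are seeing:

```
import torch

torch.manual_seed(0)
vals = torch.randn(1000)      # float32 by default

a = vals.sum()                # one summation order
b = vals.flip(0).sum()        # same numbers, reversed order

print(a.item(), b.item())     # look identical at default print precision
print(torch.equal(a, b))      # may be False: float addition is not associative
print(torch.allclose(a, b))   # True within default tolerances
```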

You can use `torch.set_printoptions(precision=10)` to print floats with 10 digits.
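
For example (reusing the hypothetical `grad1` … `grad4` names from above):

```
torch.set_printoptions(precision=10)    # default is 4 digits

print(grad1 + grad2 + grad3)
print(grad4)
print((grad1 + grad2 + grad3) - grad4)  # the difference is at the float32 rounding level
```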