How to separate gradients from different losses on the same layer

PyTorch seems to only aggregate (accumulate) gradients on a layer's parameters, so I tried to capture the contribution of each loss separately:

# backpropagate each loss separately and record the gradient of the shared layer
optimizer.zero_grad()
loss1.backward(retain_graph=True)
gradient_1 = get_layer_gradient(model, layer_name='common_layer')
optimizer.zero_grad()
loss2.backward(retain_graph=True)
gradient_2 = get_layer_gradient(model, layer_name='common_layer')
optimizer.zero_grad()
loss3.backward(retain_graph=True)
gradient_3 = get_layer_gradient(model, layer_name='common_layer')

# backpropagate the summed loss in a single pass
optimizer.zero_grad()
loss = loss1 + loss2 + loss3
loss.backward()
gradient_4 = get_layer_gradient(model, layer_name='common_layer')

I found that gradient_1 + gradient_2 + gradient_3 != gradient_4. Why?
How can I get the exact gradient contributed by each of the three losses, so that the three of them add up to gradient_4?
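
For reference, here is a minimal sketch of what get_layer_gradient could look like (this is just an assumption about my helper for illustration: it returns a cloned, flattened copy of the layer's .grad so that the later optimizer.zero_grad() cannot modify the captured values):

import torch

def get_layer_gradient(model, layer_name):
    # Collect a detached, cloned copy of .grad for every parameter of the
    # named layer, flattened into one vector so the results can be added
    # and compared later.
    grads = [p.grad.detach().clone().flatten()
             for name, p in model.named_parameters()
             if name.startswith(layer_name) and p.grad is not None]
    return torch.cat(grads)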

Feel free to discuss!

Specifically, comparing them step by step, I saw that the printed value of (gradient_1 + gradient_2 + gradient_3).norm() equals gradient_4.norm(), yet (gradient_1 + gradient_2 + gradient_3).norm() == gradient_4.norm() evaluates to False. This is consistent after hundreds of iterations… Is it because of numerical instability? If so, is the above code the right way to inspect the gradients produced by the different losses?

Use torch.allclose to compare two float tensors if you think the issue might be due to numerical precision. When you print grad1 + grad2 + grad3 and grad4, are the values different?
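
For instance (made-up numbers), two float32 tensors can differ only in their last bits, so exact comparison fails while allclose with its default tolerances succeeds:

import torch

a = torch.tensor([1.0000001])   # float32 by default
b = torch.tensor([1.0000002])
print(a == b)                   # tensor([False]): element-wise exact comparison fails
print(torch.allclose(a, b))     # True: equal within rtol=1e-05, atol=1e-08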


Hi, I tried torch.allclose(grad1 + grad2 + grad3, grad4) and it returns True, which means the two gradient tensors are element-wise equal within a tolerance. Their printed values look identical, yet they are still not exactly equal.

So how do I explain this? Why is the sum gradient_1 + gradient_2 + gradient_3 slightly different from backpropagating loss = loss1 + loss2 + loss3 in a single pass? Shouldn't both cases differentiate the same graph three times and accumulate (add) the gradients into gradient_4?

torch.float32 only gives you about 6-7 significant decimal digits, so if the gradient values agree up to that precision, what you are seeing is a limitation of float32 rather than a bug. Floating-point addition is not associative, so computing loss1 + loss2 + loss3 before the backward pass changes the order in which the per-loss contributions are summed and can introduce rounding differences in the last bits. Try printing loss as well as loss1, loss2, loss3 individually to see whether they add up exactly.
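
As a quick illustration (a minimal sketch, not your exact setup: a single Linear layer with three made-up losses), separate backward passes and one combined backward pass can disagree in the last bits while still passing allclose:

import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

out = layer(x)
loss1 = (out - y).pow(2).mean()
loss2 = (out - y).abs().mean()
loss3 = out.pow(2).mean()

# Backward each loss separately, cloning the accumulated gradient each time
# so that zero_grad() cannot touch the stored copy.
grads = []
for loss in (loss1, loss2, loss3):
    layer.zero_grad()
    loss.backward(retain_graph=True)
    grads.append(layer.weight.grad.clone())

# Backward the summed loss in one pass.
layer.zero_grad()
(loss1 + loss2 + loss3).backward()
grad_sum = grads[0] + grads[1] + grads[2]

print(torch.equal(grad_sum, layer.weight.grad))     # may print False (last-bit differences)
print(torch.allclose(grad_sum, layer.weight.grad))  # True: equal within tolerance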

You can use torch.set_printoptions(precision=10) to print floats with 10 digits.
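
For example (an arbitrary value, just to show the effect):

import torch

torch.set_printoptions(precision=10)
print(torch.tensor([1 / 3]))   # tensor([0.3333333433]) - the float32 rounding becomes visible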