PyTorch seems to only aggregate (accumulate) gradients on the same layer. Roughly, my code looks like this:
# (backward() on each individual loss, with the gradients zeroed in between,
#  is assumed to happen before each read below; those calls are omitted here)
gradient_1 = get_layer_gradient(model, layer_name='common_layer')  # from loss1
gradient_2 = get_layer_gradient(model, layer_name='common_layer')  # from loss2
gradient_3 = get_layer_gradient(model, layer_name='common_layer')  # from loss3

loss = loss1 + loss2 + loss3
# backward() on the combined loss, then the gradient is read once more
gradient_4 = get_layer_gradient(model, layer_name='common_layer')
I found that gradient_1 + gradient_2 + gradient_3 != gradient_4. Why?
How can I get the exact three gradients from the three losses, which should add up to gradient_4?
Feel free to discuss!
Specifically, comparing them step by step, I saw that (gradient_1 + gradient_2 + gradient_3).norm() looks equal to gradient_4.norm() when printed, yet (gradient_1 + gradient_2 + gradient_3).norm() == gradient_4.norm() is False. This is consistent over hundreds of iterations… Is it because of numerical instability? If so, is the code above the right way to inspect the separate gradients produced by the different losses?
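Not the original code, but a minimal self-contained sketch of the comparison being described, with a toy linear layer standing in for common_layer and .grad read directly instead of get_layer_gradient:

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# toy stand-in for the real model's 'common_layer'
layer = torch.nn.Linear(100, 100)
x = torch.randn(64, 100)
losses = [F.mse_loss(layer(x), torch.randn(64, 100)) for _ in range(3)]

# per-loss gradients: backward each loss on its own, zeroing .grad in between
per_loss_grads = []
for loss_i in losses:
    layer.zero_grad()
    loss_i.backward(retain_graph=True)  # keep the graphs alive for the combined backward
    per_loss_grads.append(layer.weight.grad.clone())

# combined gradient: a single backward pass on the summed loss
layer.zero_grad()
(losses[0] + losses[1] + losses[2]).backward()
gradient_4 = layer.weight.grad.clone()

summed = per_loss_grads[0] + per_loss_grads[1] + per_loss_grads[2]
print(torch.equal(summed, gradient_4))     # may be False: bitwise equality is sensitive to summation order
print(torch.allclose(summed, gradient_4))  # True: equal within float32 tolerance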
Use torch.allclose to compare two float tensors if you think the issue might be due to numerical instability. When you print grad1 + grad2 + grad3 and grad4, are the values different?
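For instance, two float32 values can be equal within tolerance while failing an exact comparison (a toy illustration, not the gradients from this thread):

import torch

s = torch.tensor(0.0)
for _ in range(10):            # add 0.1 ten times in float32
    s = s + torch.tensor(0.1)

one = torch.tensor(1.0)
print(s == one)                # tensor(False): s is actually about 1.0000001
print(torch.allclose(s, one))  # True: equal within the default rtol=1e-05, atol=1e-08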
Hi, I tried torch.allclose(grad1 + grad2 + grad3, grad4) and it returns True, which means the two gradient matrices are element-wise equal within a tolerance. Their printed values are the same, but they do not compare exactly equal.
So, how can this be explained? Why is the sum gradient_1 + gradient_2 + gradient_3 slightly different from backpropagating loss = loss1 + loss2 + loss3 in one pass? Shouldn't both cases differentiate the same graph three times and accumulate (add) the gradients, ending up exactly equal to gradient_4?
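The accumulation part of that premise is easy to confirm with a small sketch (a toy layer, not the original model): calling backward() twice without zeroing the gradients really does add into .grad:

import torch

layer = torch.nn.Linear(4, 1)
loss = layer(torch.randn(8, 4)).sum()

loss.backward(retain_graph=True)
g_once = layer.weight.grad.clone()

loss.backward()  # the second backward adds into the existing .grad
print(torch.allclose(layer.weight.grad, 2 * g_once))  # True: gradients accumulate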
You only get about 6 or 7 significant decimal digits of precision with torch.float32, so if the gradient values agree up to that precision, what you are seeing is a limitation of float32. Maybe adding loss1 + loss2 + loss3 introduces the floating-point error. Try printing loss1, loss2, and loss3 individually to see whether they add up exactly.
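To make the floating-point argument concrete, here is a small demonstration (not from the thread) that float32 addition is not associative, so the order in which per-loss gradients are summed can change the last bits of the result:

import torch

a = torch.tensor(1.0e8, dtype=torch.float32)
b = torch.tensor(-1.0e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)

left = (a + b) + c   # (1e8 - 1e8) + 1  ->  1.0
right = a + (b + c)  # 1e8 + (-1e8 + 1) ->  0.0, the 1.0 is absorbed by rounding
print(left, right, torch.equal(left, right))  # different results from the same three numbers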
You can use torch.set_printoptions(precision=10) to print floats with 10 digits.
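For example, two values that look identical at the default print precision can turn out to differ once more digits are shown (the numbers below are made up for illustration):

import torch

g_sum = torch.tensor(0.57735026)   # hypothetical (grad1 + grad2 + grad3).norm()
g_comb = torch.tensor(0.57735032)  # hypothetical gradient_4.norm()

print(g_sum, g_comb)               # both print as tensor(0.5774) at the default precision
torch.set_printoptions(precision=10)
print(g_sum, g_comb)               # the small difference is now visible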