I thought I had a good grasp of automatic differentiation, but I cannot make sense of the following snippet I wrote:
import torch

def f1(t):
    return t ** 2

def f2(t):
    return t ** 3

def get_t():
    return torch.tensor([1., 2.], requires_grad=True)

# case 1
t = get_t()
print('grad after initialization', t.grad)
y = f1(t).norm()
y.backward()
print('case 1 grad ', t.grad)

# case 2
t = get_t()
print('grad after initialization', t.grad)
y = (f1(t) + f2(t).detach()).norm()
y.backward()
print('case 2 grad ', t.grad)

# case 3
t = get_t()
print('grad after initialization', t.grad)
y = (f1(t) + f2(t)).norm()
y.backward()
print('case 3 grad ', t.grad)
With the following output:
grad after initialization None
case 1 grad tensor([0.4851, 3.8806])
grad after initialization None
case 2 grad tensor([0.3288, 3.9456])
grad after initialization None
case 3 grad tensor([ 0.8220, 15.7823])
My understanding is that case 1 and case 2 should have the same gradients, since I detach the value of f2 before adding it to f1. I tried different versions of this code (using torch.no_grad() and deleting the intermediate variables), but the values didn't change. Can someone help with this?
It depends on the operations that follow: you will see the results you expect if you reduce the values via .sum() or .mean() instead of .norm().
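For example, a quick check along these lines (reusing f1, f2, and get_t from your snippet) should give identical gradients for case 1 and case 2:

t = get_t()
f1(t).sum().backward()
print('case 1 with sum ', t.grad)  # tensor([2., 4.])

t = get_t()
(f1(t) + f2(t).detach()).sum().backward()
print('case 2 with sum ', t.grad)  # also tensor([2., 4.])

Since .sum() is linear, the detached constant contributes nothing to the gradient, which is not the case for .norm().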
To see the influence of a constant on the norm, you could write out the calculation and check the gradients manually:
t = get_t()
tmp = f1(t)
# the norm written out explicitly: sqrt(sum(x ** 2))
# loss = (tmp ** 2).sum().sqrt()
# the same norm with an added constant standing in for the detached tensor
loss = ((tmp + 100.) ** 2).sum().sqrt()
loss.backward()
print(t.grad)
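Writing that gradient out makes the point concrete: for loss = sqrt(sum_j (t_j**2 + c)**2), the chain rule gives d loss / d t_i = (t_i**2 + c) * 2 * t_i / loss, so the constant c never drops out. A minimal sketch comparing the manual formula with autograd, again reusing get_t and f1 from the original snippet:

t = get_t()
c = 100.
loss = ((f1(t) + c) ** 2).sum().sqrt()
loss.backward()
print('autograd gradient', t.grad)

# manual gradient: d loss / d t_i = (t_i**2 + c) * 2 * t_i / loss
with torch.no_grad():
    manual = (t ** 2 + c) * 2 * t / loss
print('manual gradient  ', manual)  # should match t.grad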
I was confused about this, but I see your point now, @ptrblck, thanks! As an analogy, this is like expecting the derivatives of (x + 1)^2 and x^2 with respect to x to be the same (here 1 plays the role of the detached tensor and the outer square plays the role of the norm operation).
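For completeness, the analogy checks out numerically: d/dx (x + 1)^2 = 2(x + 1), while d/dx x^2 = 2x, so the constant survives whenever the outer operation is nonlinear.

import torch

x = torch.tensor(3., requires_grad=True)
((x + 1.) ** 2).backward()
print(x.grad)  # tensor(8.) == 2 * (x + 1)

x = torch.tensor(3., requires_grad=True)
(x ** 2).backward()
print(x.grad)  # tensor(6.) == 2 * x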