# Confused by autodiff behaviour

I thought I had a good grasp of automatic differentiation but I cannot make sense of the following snippet I wrote:

``````
import torch

def f1(t):
    return t ** 2

def f2(t):
    return t ** 3

def get_t():
    # (reconstructed: the original body was cut off)
    # create a fresh leaf tensor and report its initial grad
    t = torch.randn(2, requires_grad=True)
    print("grad after initialization", t.grad)
    return t

# case 1
t = get_t()

y = (f1(t)).norm()
y.backward()
print("case 1 grad ", t.grad)

# case 2
t = get_t()

y = (f1(t) + f2(t).detach()).norm()
y.backward()
print("case 2 grad ", t.grad)

# case 3
t = get_t()

y = (f1(t) + f2(t)).norm()
y.backward()
print("case 3 grad ", t.grad)
``````

With the following output:

``````
grad after initialization None
case 3 grad  tensor([ 0.8220, 15.7823])
``````

My understanding is that case 1 and case 2 should produce the same gradients, since I detach the output of `f2` before adding it to `f1`. I tried different variants of this code (`with torch.no_grad()` and deleting the intermediate variables) but the values didn't change. Can someone help me understand this?

It depends on the operations that follow: you would see the expected (identical) gradients if you reduced the values via `.sum()` or `.mean()`, since these reductions are linear and the detached constant drops out of the gradient, while `.norm()` is nonlinear.
To see the influence of a constant on the `norm` you could write down the calculation and check the gradients manually:

``````
t = get_t()
tmp = f1(t)
# norm without the constant:
# loss = (tmp**2).sum().sqrt()
# norm with a constant added inside, mirroring f1(t) + f2(t).detach():
loss = ((tmp + 100.)**2).sum().sqrt()
loss.backward()
``````
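To make the comparison concrete, here is a minimal sketch (assuming, as in the snippet above, that `t` is a 2-element leaf tensor with `requires_grad=True`): with `.sum()` the detached term drops out of the gradient entirely, while with `.norm()` it changes the gradient, and a manual chain-rule computation reproduces what autograd returns:

``````
import torch

t = torch.randn(2, requires_grad=True)
c = (t ** 3).detach()  # plays the role of f2(t).detach()

# linear reduction: the detached constant does not affect the gradient
(t ** 2 + c).sum().backward()
g_with_c = t.grad.clone()
t.grad = None
(t ** 2).sum().backward()
print(torch.allclose(g_with_c, t.grad))  # True

# nonlinear reduction: the constant shifts where the norm is differentiated
t.grad = None
v = t ** 2 + c
v.norm().backward()
# chain rule: d||v||/dv = v / ||v||, and dv/dt = 2t elementwise
manual = (v / v.norm()).detach() * 2 * t.detach()
print(torch.allclose(t.grad, manual))  # True
``````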

I was confused about this and I see your point @ptrblck, thanks! As an analogy, this is like expecting the derivatives of `(x + 1)^2` and `x^2` with respect to `x` to be the same (`1` plays the role of the detached tensor and `^2` plays the role of the norm operation).

Not quite, since in your example the constant won't have any effect on the gradient of `x`:

``````
x = torch.randn(1, requires_grad=True)

y = (x**2 + 1)
y.backward()
print(x.grad)
# tensor([0.0637])

x.grad = None  # reset, otherwise the next backward would accumulate into x.grad
y = (x**2)
y.backward()
print(x.grad)
# tensor([0.0637])
``````

However, if you use `(x + a)**2` it will, since that expands to `x**2 + 2*a*x + a**2` and the cross term `2*a*x` contributes a constant `2*a` to the gradient.
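Writing out both situations as derivatives makes the difference explicit (`a` stands in for the detached constant):

``````
\frac{d}{dx}\left(x^2 + a\right) = 2x
\qquad\text{whereas}\qquad
\frac{d}{dx}\,(x + a)^2 = 2(x + a) = 2x + 2a
``````

With `a = 1` in code: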

``````
x.grad = None
y = ((x + 1)**2)
y.backward()
print(x.grad)  # differs from the gradient of x**2 by the constant 2
``````

Sorry, the example was correct in my head but I wrote it wrong! It was meant to be `(x + 1)^2`.
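For completeness, a quick check of the corrected analogy (a minimal sketch): at the same `x`, the gradient of `(x + 1)**2` differs from the gradient of `x**2` by exactly the cross-term constant `2*a` with `a = 1`:

``````
import torch

x = torch.randn(1, requires_grad=True)

((x + 1) ** 2).backward()
g_shifted = x.grad.clone()

x.grad = None
(x ** 2).backward()

# difference is the 2*a cross-term constant, with a = 1
print(g_shifted - x.grad)  # ~ tensor([2.])
``````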