Backward through gradient update

I am trying to figure out why this minimal example is not working as expected:

import torch

# two variables
x = torch.Tensor([1])
y = torch.Tensor([1])
x.requires_grad = True
y.requires_grad = True

# loss (think: training loss)
loss = x * y
loss.backward()  # first backward call: fills x.grad and y.grad
print("loss", loss)
print("x", x, x.grad)
print("y", y, y.grad)

# update x using gradient (which depends on the value of y)
new_x = x - 1e-1 * x.grad
print("new_x", new_x)

# different loss using updated value (think: validation loss)
new_loss = new_x ** 2
print("new_loss", new_loss)
x.grad.zero_()  # clear grads from first pass
y.grad.zero_()  # clear grads from first pass
new_loss.backward()  # compute gradient for new loss

# Why is grad of y zero? 
print("x", x, x.grad)
print("y", y, y.grad)

Output:
loss tensor([1.], grad_fn=<MulBackward0>)
x tensor([1.], requires_grad=True) tensor([1.])
y tensor([1.], requires_grad=True) tensor([1.])
new_x tensor([0.9000], grad_fn=<SubBackward0>)
new_loss tensor([0.8100], grad_fn=<PowBackward0>)
x tensor([1.], requires_grad=True) tensor([1.8000])
y tensor([1.], requires_grad=True) tensor([0.])

Changing the value of y does change the value of new_loss (through x.grad), so shouldn’t its gradient be non-zero? What am I missing?
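A minimal sketch of what is going on, assuming the same toy values as above (written with the non-legacy torch.tensor constructor). With the default create_graph=False, the gradient that backward() stores in x.grad is a plain constant with no grad_fn, so new_x = x - 0.1 * x.grad only depends on x. Asking torch.autograd.grad with allow_unused=True makes this visible: there is no path from new_loss back to y at all, and the 0 shown for y.grad above is just what is left after clearing the first-pass grads.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([1.0], requires_grad=True)

loss = x * y
loss.backward()  # default: create_graph=False

# The stored gradient is detached from the graph: no requires_grad, no grad_fn.
print(x.grad, x.grad.requires_grad, x.grad.grad_fn)  # tensor([1.]) False None

new_x = x - 1e-1 * x.grad  # x.grad enters as a constant, so y is not in this graph
new_loss = new_x ** 2

# allow_unused=True returns None for inputs that new_loss does not depend on.
gx, gy = torch.autograd.grad(new_loss, [x, y], allow_unused=True)
print(gx, gy)  # tensor([1.8000]) None
```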

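OK, passing create_graph=True to the first backward call, i.e. loss.backward(create_graph=True), fixes the problem. Without it, the gradient stored in x.grad has no history (no grad_fn), so new_x depends only on x and nothing flows back to y. From the docs: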

create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.
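A minimal sketch of the fixed version, assuming the same toy setup. The substantive change is create_graph=True on the first backward call, which keeps x.grad as a function of y (here x.grad equals y), so new_x = x - 0.1 * y and d(new_loss)/dy = 2 * new_x * (-0.1) = -0.18. As a small deviation from the code above, the first-pass grads are reset with `= None` instead of zero_(); as I read the backward() docs, resetting .grad to None is also what they suggest for breaking the reference cycle that create_graph=True creates.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([1.0], requires_grad=True)

# Training step: keep the graph of the derivative so x.grad stays a function of y.
# (PyTorch may emit a UserWarning here about create_graph=True and reference cycles.)
loss = x * y
loss.backward(create_graph=True)
print(x.grad.grad_fn is not None)  # True: x.grad now carries history

# Differentiable update: new_x = x - 0.1 * x.grad = x - 0.1 * y
new_x = x - 1e-1 * x.grad

# Drop the first-pass grads so the second backward does not accumulate into them.
x.grad = None
y.grad = None

# Validation step
new_loss = new_x ** 2
new_loss.backward()

print(x.grad)  # tensor([1.8000])
print(y.grad)  # tensor([-0.1800])
```

The same inner step can also be written as torch.autograd.grad(loss, x, create_graph=True)[0], which returns the gradient without writing it into x.grad; the docs recommend autograd.grad when creating the graph to avoid the reference cycle.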

I kind of expected a warning or error, similar to the one raised when calling backward twice without retain_graph=True, but I guess it might be hard to detect whether cutting the computational graph at that point is intended or not…
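If you want an explicit error anyway, one option is a check on your own side. The sketch below uses a hypothetical helper, differentiable_step (not a PyTorch API), and assumes that every gradient used in a differentiable update is supposed to carry history: it refuses to build the update when param.grad has no grad_fn. It also illustrates why autograd itself cannot easily warn: a gradient may legitimately have no grad_fn (for instance when it is a genuine constant), so a check like this could also fire on intended cuts.

```python
import torch


def differentiable_step(param: torch.Tensor, lr: float) -> torch.Tensor:
    """Hypothetical helper: return param - lr * param.grad, but raise if
    param.grad has no history (e.g. backward() ran without create_graph=True)."""
    if param.grad is None:
        raise RuntimeError("param.grad is None; call backward() first")
    if param.grad.grad_fn is None:
        raise RuntimeError(
            "param.grad has no grad_fn; was backward() called without create_graph=True?"
        )
    return param - lr * param.grad


# Case 1: plain backward() -> the helper refuses to build the update.
x = torch.tensor([1.0], requires_grad=True)
y = torch.tensor([1.0], requires_grad=True)
(x * y).backward()
try:
    differentiable_step(x, 1e-1)
except RuntimeError as err:
    caught = True
    print("refused:", err)
else:
    caught = False

# Case 2: backward(create_graph=True) -> the update stays connected to y.
x2 = torch.tensor([1.0], requires_grad=True)
y2 = torch.tensor([1.0], requires_grad=True)
(x2 * y2).backward(create_graph=True)
new_x2 = differentiable_step(x2, 1e-1)
print(new_x2.grad_fn is not None)  # True
```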
