# "Broken" computation graph appears to use stale gradient

Hello Forum!

I tried a test where I expected to get “RuntimeError: Trying to backward
through the graph a second time,” but instead I got results that make no
sense to me.

The basic question: What should happen if you move part of the building
of the computation graph out of the optimization loop?

Here is the test script (with a couple of alternatives commented out):

```python
import torch
print(torch.__version__)

a = torch.arange(9.0).reshape(3, 3).requires_grad_()

opt = torch.optim.SGD((a,), lr=0.1)

b = a[:, 1].contiguous()   # a is optimized (with stale gradient) but b remains unchanged

# b = a[:, 1]              # works for first opt.step(), but then fails silently

for i in range(3):
    # b = a[:, 1]          # works as expected inside the loop
    loss = (b[1] - 10)**2
    opt.zero_grad()
    loss.backward()
    opt.step()
    print('a[1, 1]: ', a[1, 1])
    print('b[1]:    ', b[1])
```

And here is its output:

```
2.0.1
a[1, 1]:  tensor(5.2000, grad_fn=<SelectBackward0>)
a[1, 1]:  tensor(6.4000, grad_fn=<SelectBackward0>)
a[1, 1]:  tensor(7.6000, grad_fn=<SelectBackward0>)
```

The first question is why, when `.backward()` is called the second time
through the loop, does autograd not complain about part of the computation
graph having been freed? I would have thought that the first `.backward()`
would have freed the connection between `b` and `a` (which is not rebuilt inside
of the optimization loop).

Why do we get the results we do? It appears that `opt.step()` keeps using
the gradient produced by the first `.backward()` and the subsequent calls
don’t update it (but also don’t issue an error).

(The second – commented out – alternative without the `.contiguous()`
fails a little differently, but also without an error message. In this case, `a`
is updated just once (and that change is reflected in `b`’s view into `a`) and
then stays constant.)
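To make the difference between the two variants concrete, here is a small sketch (separate from the script above) contrasting the plain slice, which is a view sharing storage with `a`, with its `.contiguous()` copy:

```python
import torch

a = torch.arange(9.0).reshape(3, 3)
view = a[:, 1]               # a view: shares storage with a
copy = a[:, 1].contiguous()  # an independent copy of that column

with torch.no_grad():
    a[1, 1] = 100.0          # mimic the in-place update that opt.step() performs

print(view[1])               # tensor(100.) - the view sees the change
print(copy[1])               # tensor(4.)   - the copy does not
```

This is why the `.contiguous()` variant keeps optimizing against a stale `b`, while the view variant sees each update to `a`.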

I see this in versions 2.0.1 and 1.11.0, so it’s not an obvious one-off regression.

[Edit: Adding another example script.]

I’ve added a modified version of the above script that highlights a particular
facet of the issue.

After `loss` is first computed, it no longer changes when `a` changes. This
makes sense because `b`, the link between `a` and `loss`, is not updated
within the optimization loop; the line `b = a[:, 1].contiguous()` is only
executed once before the loop begins.

Nonetheless, even though `loss` no longer depends on `a` inside of the loop, a
nonzero gradient flows back from `loss` to `a`. (Note that a constant loss with
a nonzero gradient is clearly mathematically incorrect.) It’s as if the part of
the computation graph that connects `b` to `a` is still in operation, even though
it should have been freed by the call to `loss.backward()`.

Here is a script that prints out `a.grad`:

```python
import torch
print(torch.__version__)

a = torch.arange(9.0).reshape(3, 3).requires_grad_()

opt = torch.optim.SGD((a,), lr=0.1)

b = a[:, 1].contiguous()   # a is optimized (with stale gradient) but b remains unchanged

for i in range(2):
    loss = (b[1] - 10)**2
    print('i:', i, ' loss =', loss)          # loss doesn't change (after first computed)
    print('i:', i, 'before zero_grad()')
    print('a = ...')
    print(a)
    print('a.grad = ...')
    print(a.grad)
    opt.zero_grad()
    print('i:', i, 'after zero_grad(), before backward()')
    print('a = ...')
    print(a)
    print('a.grad = ...')
    print(a.grad)
    loss.backward()
    print('i:', i, 'after backward(), before step()')
    print('a = ...')
    print(a)
    print('a.grad = ...')
    print(a.grad)                            # grad is nonzero even though loss no longer depends on a
    opt.step()
    print('i:', i, 'after step()')
    print('a = ...')
    print(a)                                 # a changes with every step
    print('a.grad = ...')
    print(a.grad)
```

And here is its output:

```
2.0.1
i: 0  loss = tensor(36., grad_fn=<PowBackward0>)
i: 0 before zero_grad()
a = ...
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], requires_grad=True)
a.grad = ...
None
i: 0 after zero_grad(), before backward()
a = ...
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], requires_grad=True)
a.grad = ...
None
i: 0 after backward(), before step()
a = ...
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], requires_grad=True)
a.grad = ...
tensor([[  0.,   0.,   0.],
        [  0., -12.,   0.],
        [  0.,   0.,   0.]])
i: 0 after step()
a = ...
tensor([[0.0000, 1.0000, 2.0000],
        [3.0000, 5.2000, 5.0000],
        [6.0000, 7.0000, 8.0000]], requires_grad=True)
a.grad = ...
tensor([[  0.,   0.,   0.],
        [  0., -12.,   0.],
        [  0.,   0.,   0.]])
i: 1  loss = tensor(36., grad_fn=<PowBackward0>)
i: 1 before zero_grad()
a = ...
tensor([[0.0000, 1.0000, 2.0000],
        [3.0000, 5.2000, 5.0000],
        [6.0000, 7.0000, 8.0000]], requires_grad=True)
a.grad = ...
tensor([[  0.,   0.,   0.],
        [  0., -12.,   0.],
        [  0.,   0.,   0.]])
i: 1 after zero_grad(), before backward()
a = ...
tensor([[0.0000, 1.0000, 2.0000],
        [3.0000, 5.2000, 5.0000],
        [6.0000, 7.0000, 8.0000]], requires_grad=True)
a.grad = ...
None
i: 1 after backward(), before step()
a = ...
tensor([[0.0000, 1.0000, 2.0000],
        [3.0000, 5.2000, 5.0000],
        [6.0000, 7.0000, 8.0000]], requires_grad=True)
a.grad = ...
tensor([[  0.,   0.,   0.],
        [  0., -12.,   0.],
        [  0.,   0.,   0.]])
i: 1 after step()
a = ...
tensor([[0.0000, 1.0000, 2.0000],
        [3.0000, 6.4000, 5.0000],
        [6.0000, 7.0000, 8.0000]], requires_grad=True)
a.grad = ...
tensor([[  0.,   0.,   0.],
        [  0., -12.,   0.],
        [  0.,   0.,   0.]])
```

Context: I got confused / tripped up by this when exploring what I thought
would be a simple explanation / fix for the issue in the thread:

Thanks for any insight!

K. Frank


I also stumbled across the issue and couldn’t explain it.
CC @albanD for visibility, as I would expect to see an error instead of silent numerical errors.

> The first question is why, when `.backward()` is called the second time
> through the loop, does autograd not complain about part of the computation
> graph having been freed?

This is because the part of the graph being shared is two ops: one indexing and one contiguous. Neither of these saves any Tensor for backward, so there is nothing to be freed during the backward pass (and so we don’t detect the multiple backward calls).
Note that what is freed during the backward pass is only the saved Tensors, not the structure of the graph itself. So in practice you can backward as many times as you want through a given graph that doesn’t save any Tensor (but since what gets saved is not guaranteed, you shouldn’t rely on this).
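For example (a minimal sketch; exactly which ops save tensors for backward is an implementation detail, so don’t rely on it):

```python
import torch

x = torch.ones(3, requires_grad=True)

# Indexing (and contiguous) save no tensors for backward, so backward()
# can run repeatedly through this graph without complaint.
y = x[1].contiguous()
y.backward()
y.backward()                 # no error: there was nothing to free

# Multiplication does save its inputs, so a second backward() fails.
z = (x * x).sum()
z.backward()
raised = False
try:
    z.backward()
except RuntimeError:
    raised = True            # "Trying to backward through the graph a second time"
print('second backward raised:', raised)
```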

> Why do we get the results we do?

`b` doesn’t share memory with anything that is modified in-place, and it is not recomputed, so `loss` remains constant in the loop.
Since `b` doesn’t change, the loss doesn’t change, and so you get the same gradient at every iteration; it is thus expected that `a` moves the way it does.
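For reference, the intended pattern is to rebuild `b` (and hence the graph) inside the loop, so each `backward()` reflects the current `a`. A minimal sketch using the same tensors as the original script:

```python
import torch

a = torch.arange(9.0).reshape(3, 3).requires_grad_()
opt = torch.optim.SGD((a,), lr=0.1)

losses = []
for i in range(3):
    b = a[:, 1]              # a fresh view (and a fresh graph) every iteration
    loss = (b[1] - 10) ** 2
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(losses)                # the loss now decreases every iteration
```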

Hi Alban!

I believe I understand your explanation of the behavior.

A couple of comments:

The example code I posted is, of course, logically incorrect from the
perspective of backpropagation.

So one point of view is: You write broken code, you get bad results.

However, in part because backpropagation errors can be easy to make, hard
to debug, and sometimes hard to even notice, pytorch goes to considerable
lengths to detect and warn about them.

This error is not exactly the same as a “backward through the graph a second
time” error, but it’s at least a kissing cousin. Would it be practical for pytorch
to also detect this subclass of errors?

Thanks for the explanation!

K. Frank

Hey!

Indeed, using `no_grad()` (which is what is done in `optimizer.step()`) will hide things from autograd and might get you surprising results.
I don’t think we can make this an error, as it would be pretty badly BC-breaking. In particular, we don’t want to break existing code that does work today (for example, when you take only a view outside of the main loop). But we can definitely add some detection (via a boolean on the Node) and warn once if you backward again, even when no state is present. This will need to be a brand-new mechanism, though, as the current one for errors will not work (it relies on saved state being unpacked, but you don’t have state here).
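To illustrate the detection that does exist today: in-place edits made under `no_grad()` are still caught by the version counter, but only when the edited tensor was actually saved for backward. A small sketch (again, which ops save which tensors is an implementation detail):

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()                  # exp saves its output y for the backward pass

with torch.no_grad():
    y.add_(1.0)              # hidden from the graph, but bumps y's version counter

caught = False
try:
    y.sum().backward()       # unpacking the saved y detects the version mismatch
except RuntimeError:
    caught = True
print('in-place modification caught:', caught)
```

In the thread’s example nothing is saved along the `b`-to-`a` path, so there is no saved tensor whose version check could fire.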

Hi Alban!

Thanks for the follow up.

From where I sit, I don’t really see this as any kind of priority. (Maybe it falls
into the “might be nice” category.)

In fairness, in my example, I was purposely writing unusual “wrong” code.
I haven’t ever tripped over this issue “in real life.”

Thanks again.

K. Frank
