Quick Detach() Question

If I call detach() on a variable after calling backward(retain_variables=True) on a variable further along the graph, will that free the buffers of the graph that led to the detach()'d variable?

Example: net1 and net2 are modules, x is an input variable. I want to retain the state of net2 for processing later down the line, but I just want to accumulate gradients in net1 for the time being, no need to retain the graph.

out1 = net1(x)
out2 = net2(out1)
loss = some_loss_function(out2)
loss.backward(retain_variables=True)
out1.detach_()

Will this clear the buffers of net1?

Thanks,

Andy


Hi,

From my understanding it won’t.
One way to do this (maybe not the best) would be to separate these into two graphs, backpropagate through the second one with retain_variables, and then backpropagate through the first one without it:

out1 = net1(x)

in2 = Variable(out1.data, requires_grad=True)
out2 = net2(in2)

loss = some_loss_function(out2)

loss.backward(retain_variables=True)
out1.backward(in2.grad.data)

# You can call loss.backward again here
# But not out1.backward !
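A minimal sketch of the same two-graph idea in current PyTorch, assuming `net1` and `net2` are small `nn.Linear` stand-ins and using a hypothetical mean-squared loss in place of `some_loss_function`. In later releases `retain_variables` is spelled `retain_graph`, and `out1.detach().requires_grad_()` plays the role of `Variable(out1.data, requires_grad=True)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net1 = nn.Linear(4, 3)  # hypothetical stand-in for net1
net2 = nn.Linear(3, 1)  # hypothetical stand-in for net2
x = torch.randn(2, 4)

out1 = net1(x)
# Start a second graph: in2 holds out1's values but has no history
in2 = out1.detach().requires_grad_()
out2 = net2(in2)
loss = out2.pow(2).mean()  # placeholder for some_loss_function

# Keep net2's part of the graph so loss.backward() can run again later
loss.backward(retain_graph=True)
# Push the gradient that reached in2 back through net1;
# without retain_graph, net1's buffers are freed after this call
out1.backward(in2.grad)

# loss.backward() can be called again here, but not out1.backward()
```

The gradients that end up in `net1` match a single end-to-end backward; the split only changes which part of the graph stays alive.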

Hmm, assuming detach() doesn’t free the graph, your way makes a lot of sense, and I think I could wrap that up into this chunk of code neatly. Thanks!

Edit: a slightly more elegant way to do this (still following your same idea) might be to copy out1 to in2 with the out-of-place detach:

in2 = out1.detach()

I’ll try out both and report back.

If in2 is just detached, it does not require a gradient, so loss.backward() will never compute one for it: in2.grad stays None and there is nothing to pass to out1.backward(). Only net2's parameters receive gradients.

When a variable is detached, backward stops there: gradients do not flow through it into the operations that produced it.

So as @albanD suggested, you need to create a new graph (hence the new variable requiring a gradient) if you want to compute a gradient with respect to the second network (net2) independently of the previous operations (net1).
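A small sketch of what happens with a plain `detach()`, again assuming hypothetical `nn.Linear` stand-ins for `net1` and `net2`. Because net2's parameters require gradients, `loss.backward()` still runs, but nothing flows into `in2` or past it into `net1`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net1 = nn.Linear(4, 3)  # hypothetical stand-in for net1
net2 = nn.Linear(3, 1)  # hypothetical stand-in for net2
x = torch.randn(2, 4)

out1 = net1(x)
in2 = out1.detach()  # plain detach: no history and no requires_grad
print(in2.requires_grad, in2.grad_fn)  # False None

loss = net2(in2).pow(2).mean()
loss.backward()  # works because net2's parameters require grad

print(in2.grad)  # None: no gradient is stored for in2
print(net1.weight.grad)  # None: backward stopped at in2
print(net2.weight.grad is not None)  # True
```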

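Both detach() and detach_() remove the reference to the creator: detach() returns a new tensor without one, while detach_() drops it from the variable itself. That frees the graph and the buffers contained in it only if nothing else (such as loss) still references them.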
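A minimal sketch contrasting the two calls on a toy tensor (the names `x`, `out1` and `d` are illustrative only). The out-of-place version leaves `out1` untouched and shares its storage; the in-place version strips `out1`'s own `grad_fn`.

```python
import torch

x = torch.randn(3, requires_grad=True)
out1 = x * 2

d = out1.detach()  # new tensor without history; out1 keeps its creator
print(d.grad_fn, out1.grad_fn is not None)  # None True

out1.detach_()  # out1 itself loses its creator reference
print(out1.grad_fn, out1.requires_grad)  # None False
```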

import torch
x = torch.randn(10, 5, requires_grad=True)
w = x * x
a = w * w

a.sum().backward()
print(x.grad.sum())

w = x * x
a = w * w
w.detach_()

a.sum().backward()
print(x.grad.sum())

In this case, the output is:

tensor(27.7943)
tensor(55.5887)

Why can the second backward pass propagate the gradient to x? detach_() is an in-place operation and it sets the grad_fn of w to None, so from my understanding, x should not receive the second gradient. Am I getting something wrong?

detach_() changes w in place, but it does not change past uses of it. a = w * w was recorded before w was detached, so backward still flows through it to x.
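The past-versus-future distinction can be seen without running backward at all. In this sketch (toy tensors, names taken from the example above), the use of `w` recorded before `detach_()` keeps its history, while a new use afterwards is not tracked.

```python
import torch

x = torch.randn(10, 5, requires_grad=True)
w = x * x
a = w * w  # recorded while w still had history

w.detach_()
b = w * w  # recorded after the in-place detach

print(a.requires_grad, b.requires_grad)  # True False
```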

Thanks for the reply. But why do in-place operations sometimes cause an error in backward and sometimes not?

Which error do you sometimes see and sometimes not?
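On the in-place question: autograd tracks modifications with a version counter and raises an error only when an in-place operation overwrites a tensor that an earlier operation saved for its backward pass. The pairing below (`+` versus `exp`) is just one illustrative sketch of each case.

```python
import torch

x = torch.randn(3, requires_grad=True)

# add's backward needs no saved tensors, so modifying its output is fine
y = x + 1
y.add_(1)
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])
grad_after_add = x.grad.clone()

# exp's backward reuses its output, so overwriting that output breaks backward
z = x.exp()
z.add_(1)
raised = False
try:
    z.sum().backward()
except RuntimeError as err:
    raised = True
    print("backward failed:", err)
```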

Oh, I mean that some in-place operations cause an error in backward and others don't. Back to the original question: is there a way to detach a subgraph after the forward pass?

For example, there is a graph:
x -> a -> y -> loss
and I call the first backward:
loss -> y -> a -> x.
After that I call
y.backward(gradient=some_grad)
and I only want to backpropagate to a:
y -> a.

I know I can solve this by inserting a detach:
x -> x.detach() -> a -> y -> loss
and then manually backpropagating the gradient that reaches x.detach() into x during the first backward. But is there a better method?

No, there isn't, I'm afraid. We don't allow modifying the graph.

If you just want the backward between two given Tensors, you should use the grad_input = autograd.grad(output, input) API. This will only run backward between the given Tensors.
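A minimal sketch of that suggestion on the x -> a -> y -> loss graph from the question, with toy operations (`a = x * 2`, `y = a ** 2`) and a hypothetical all-ones `some_grad`. `torch.autograd.grad` returns the gradient instead of accumulating it, so `x.grad` is left alone.

```python
import torch

x = torch.randn(3, requires_grad=True)
a = x * 2
y = a ** 2
loss = y.sum()

# First backward through the whole chain x -> a -> y -> loss;
# retain_graph keeps y's saved tensors for the second call
loss.backward(retain_graph=True)
x_grad_before = x.grad.clone()

# Backward only between y and a, with a custom upstream gradient
some_grad = torch.ones_like(y)  # hypothetical upstream gradient
(grad_a,) = torch.autograd.grad(y, a, grad_outputs=some_grad)

print(torch.allclose(grad_a, 2 * a))  # True: dy/da = 2a
print(torch.equal(x.grad, x_grad_before))  # True: x.grad untouched
```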
