Which is freed, which is not?


#1

Here is a minimal nontrivial example that confuses me:

Example 1: passes

import torch

x = torch.nn.Linear(3,3)
a = torch.autograd.Variable(torch.randn(2, 3))

for i in range(100):
    y = torch.sum(x(a))
    y.backward()

Example 2: fails with RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

import torch

x = torch.nn.Linear(3,3)
a = torch.autograd.Variable(torch.randn(2, 3))
y = torch.sum(x(a))

for i in range(100):
    y.backward()

What confuses me is that, in principle, the ‘graph’ should be freed after calling y.backward() in the first example, and x and a are obviously contained in the graph, so they should be freed too. I therefore did not expect the first example to pass, since during the second y.backward() call we should no longer be able to locate x and a. The second example is even more confusing to me: just moving y = torch.sum(x(a)) outside the loop creates an error.

So my questions are: in both cases, when is the ‘graph’ created? Which autograd.Variables / nn.Modules are contained in the graph? And if the graph is freed after y.backward() is called, why does the first example pass while the second one fails?


(Alban D) #2

Here is your code with some comments on what is in the graph and what is not; let me know if it helps:

import torch

x = torch.nn.Linear(3,3)
# x is an nn.Module
# It contains two "leaf" Variables (weight and bias), i.e. Variables
# that require gradients

a = torch.autograd.Variable(torch.randn(2, 3))
# a is a Variable that does not require gradients and has no graph
# associated with it, since no operation has been performed on it.

for i in range(100):
    x_out = x(a)
    # x_out is a `Variable` with the graph of all computations performed
    # in the forward function of x attached to it.
    # Note that all the objects of the graph (Functions and intermediary
    # Variables) are only accessible from the python object x_out
    y = torch.sum(x_out)
    # y has a graph containing the sum operation and all the graph of x_out
    y.backward()
    # Go through the whole graph associated with y and compute the gradients
    # for all the leaf Variables.
    # To reduce memory usage, all the intermediary Variables are freed.

import torch

x = torch.nn.Linear(3,3)
# nn.Module with two leaf Variables
a = torch.autograd.Variable(torch.randn(2, 3))
# a Variable that does not require grads
y = torch.sum(x(a))
# y is a Variable whose associated graph corresponds to
# the forward function of x plus the sum Function

for i in range(100):
    y.backward()
    # First iteration:
    # We call backward and free the intermediary Variables of the graph
    # Second iteration:
    # You try to go through the graph associated with y but it has
    # already been cleared
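
A quick way to check this explanation is to keep the buffers alive explicitly. With retain_graph=True, the second pattern runs; a minimal sketch (mirroring the code above):

```python
import torch

x = torch.nn.Linear(3, 3)
a = torch.autograd.Variable(torch.randn(2, 3))
y = torch.sum(x(a))  # the graph is built once, here

for i in range(100):
    # retain_graph=True tells autograd to keep the intermediary
    # buffers, so the same graph can be traversed again
    y.backward(retain_graph=True)

# gradients accumulated in the leaf Variables of x over all 100 passes
print(x.bias.grad)  # each pass adds 2 (the batch size) to every entry
```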
    

#3

Thanks, that’s an excellent explanation! As you explained, in the second example, when y.backward() was called the second time, the variable ‘a’ had been freed and the graph was not built again. That’s the reason why we cannot let the gradient flow through. But I’m curious: was ‘x’ also freed? x is an nn.Module, so it is a collection of two leaf Variables, and in principle it is also an ‘intermediate variable’.

Besides, your comments led me to another question that is not obvious to me. Consider the following code:

import torch

x = torch.nn.Linear(3,3)
a = torch.autograd.Variable(torch.randn(2, 3))
x_out = x(a)

for i in range(100):
    y = torch.sum(x_out)
    y.backward()

This code will fail, which is unexpected. I thought that in the second iteration the graph would be recreated, because calling y = torch.sum(x_out) recreates all the Variables associated with y. The problem disappears if I move x_out into the loop:

import torch

x = torch.nn.Linear(3,3)
a = torch.autograd.Variable(torch.randn(2, 3))

for i in range(100):
    x_out = x(a)
    y = torch.sum(x_out)
    y.backward()

So my guess is: applying an operation like torch.sum does not recreate the graph, but x_out = x(a) does. The latter is just a forward() call on the nn.Module object x. After all these experiments, my crude conclusion is: the forward() call of an nn.Module recreates the graph associated with its return value. But this feels insufficient…


#4

In your example where you call y.backward() in the for loop: when y.backward() is called, the gradient flows all the way to x, and then the input buffers in x are freed. When you call y = torch.sum(x_out) and then y.backward() a second time, the gradients again try to flow all the way to x, but because x needs its input to compute correct gradients (and that input was freed in the first backward() call), it errors out. You can tell y.backward() not to free the graph by calling y.backward(retain_graph=True).
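
To illustrate: torch.sum(x_out) does build a fresh sum node, but it hangs off the same upstream graph of x, which the first backward() frees. Retaining that shared part on the first call makes a second backward through it work; a small sketch:

```python
import torch

x = torch.nn.Linear(3, 3)
a = torch.autograd.Variable(torch.randn(2, 3))
x_out = x(a)

y1 = torch.sum(x_out)
y1.backward(retain_graph=True)  # keep the shared graph of x_out alive

y2 = torch.sum(x_out)  # a new sum node, but the same upstream graph
y2.backward()          # works: the shared part was retained above
```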


(jdhao) #5

I think the graph associated with the computation is something like this

The input to x is a. You say that “the input buffers in x are freed”. But if we inspect a after y.backward(), a still exists. Also, looking at the source of nn.Linear, the input is stored as an attribute of the module. What does it mean that the “input buffers in x are freed”?


(Alban D) #6

Any python object can be referenced from many places.
When we say “freed” here, it means that the graph will not reference this object anymore. Of course if the user keeps a reference to the object, it won’t be destroyed as it can be used by someone else. But as soon as you remove your other reference to the object, it will be destroyed.
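
A small sketch of this point: after backward(), the graph has dropped its references to the intermediary buffers, but any tensor the user still names remains perfectly usable:

```python
import torch

x = torch.nn.Linear(3, 3)
a = torch.autograd.Variable(torch.randn(2, 3))
y = torch.sum(x(a))
y.backward()

# "Freed" means the graph no longer references its intermediary buffers.
# Objects we still hold references to (x, a, y) are untouched:
print(a.shape)        # a is still alive and usable
print(x.weight.grad)  # gradients reached the leaves before the free
```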


(jdhao) #7

In your first code snippet, the computational graph is created only once, so when you use a for loop and try to backward through the graph more than once without retain_graph=True, the error occurs.

In your second code example, each time the for loop body executes, a fresh new graph is created and you can backward() through it once. If you try to backward in the loop a second time, it will also fail:

import torch

x = torch.nn.Linear(3,3)
a = torch.autograd.Variable(torch.randn(2, 3))

for i in range(100):
    x_out = x(a)
    y = torch.sum(x_out)
    y.backward()  # fine
    y.backward()  # fails: you are trying to backprop through the same graph a second time
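
If you really do need two backward passes through the same graph within one iteration, pass retain_graph=True on all but the last call; for example:

```python
import torch

x = torch.nn.Linear(3, 3)
a = torch.autograd.Variable(torch.randn(2, 3))

for i in range(100):
    x_out = x(a)                   # a fresh graph every iteration
    y = torch.sum(x_out)
    y.backward(retain_graph=True)  # keep the buffers for the next pass
    y.backward()                   # fine; the graph is freed here
```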