Which is freed, which is not?

Here is your code with some comments on what is in the graph and what is not; let me know if it helps:

import torch

x = torch.nn.Linear(3,3)
# x is an nn.Module
# It contains two "leaf" Variables (the weight and the bias), i.e.,
# Variables that require gradients

a = torch.autograd.Variable(torch.randn(2, 3))
# a is a Variable that does not require gradients and has no graph
# associated with it, since no operation has been performed on it.

for i in range(100):
    x_out = x(a)
    # x_out is a `Variable` with the associated graph of all the computations
    # performed in the forward function of x.
    # Note that all the objects of the graph (Functions and intermediary
    # Variables) are only reachable from the Python object x_out.
    y = torch.sum(x_out)
    # y has a graph containing the sum operation and the whole graph of x_out
    y.backward()
    # backward() goes through the whole graph associated with y and computes
    # the gradients for all the leaf Variables.
    # To reduce memory usage, all the intermediary Variables are freed
    # along the way.
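
You can check this structure directly from the output Variable. Here is a minimal sketch; as an assumption it uses the current attribute names `grad_fn` and `next_functions` (very old releases exposed `grad_fn` as `creator`):

import torch

x = torch.nn.Linear(3, 3)
a = torch.autograd.Variable(torch.randn(2, 3))

x_out = x(a)
y = torch.sum(x_out)

# The graph hangs off the output: each node points back to the
# Functions that produced its inputs.
print(y.grad_fn)                 # the backward Function of the sum
print(y.grad_fn.next_functions)  # the Function(s) that produced x_out

y.backward()
# Gradients are only accumulated on the leaf Variables that require them:
print(x.weight.grad.size())  # (3, 3)
print(x.bias.grad.size())    # (3,)
print(a.grad)                # None: a does not require gradients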

Now compare with this second version, where the graph is built only once, outside the loop:

import torch

x = torch.nn.Linear(3,3)
# nn.Module with two leaf Variables
a = torch.autograd.Variable(torch.randn(2, 3))
# a Variable that does not require gradients
y = torch.sum(x(a))
# y is a Variable whose associated graph corresponds to the
# forward function of x followed by the sum Function

for i in range(100):
    y.backward()
    # First iteration:
    # backward() traverses the graph and frees the intermediary Variables.
    # Second iteration:
    # backward() tries to traverse the graph associated with y again, but its
    # buffers have already been freed, so it raises a RuntimeError.
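
If you really need to call backward several times on the same graph, you have to ask autograd to keep the intermediary buffers alive. A minimal sketch, assuming a recent PyTorch where the keyword is `retain_graph=True` (very old releases spelled it `retain_variables=True`):

import torch

x = torch.nn.Linear(3, 3)
a = torch.autograd.Variable(torch.randn(2, 3))
y = torch.sum(x(a))

for i in range(100):
    # Keep the graph's buffers alive so it can be traversed again.
    y.backward(retain_graph=True)

# Gradients accumulate across calls: x.weight.grad now holds the sum
# of the 100 individual gradients.

That said, the usual fix is simply to rebuild the graph at every iteration, as in the first example above.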
    