Understanding garbage collection of computation graphs in PyTorch

Hi, I am interested in how unused computation graphs are automatically garbage collected in PyTorch. I used to use DyNet (another dynamic NN library), where I could call renew_cg() to remove all previously created nodes every time I started building a new graph for the current training example. In PyTorch, however, everything seems to be handled automatically, regardless of whether I call __call__ on an nn.Module or call the function that implements the computation directly. Is there any source code/documentation that I can refer to? Thanks!


The graphs are freed automatically as soon as the output Variable holding onto the graph goes out of scope. Python uses reference counting, so the freeing is immediate.

For example:

from torch.autograd import Variable

x = Variable(...)

# Example 1
try:
    y = x ** 2
    z = y * 3
except:
    pass
# graph is freed here

# Example 2
try:
    y = x ** 2
    z = y * 3
    z.backward(...)  # graph is freed here
except:
    pass

# Example 3
try:
    y = x ** 2
    z = y * 3
    z.backward(..., retain_variables=True)
except:
    pass
# graph is freed here
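For reference, here is a minimal sketch of the same behaviour with the current tensor API (Variable has since been merged into Tensor, and retain_variables was renamed retain_graph in later releases); the exact error message may differ between versions:

import torch

x = torch.ones(2, requires_grad=True)

# Like Example 2: backward() frees the graph's saved buffers, so
# backpropagating through the same graph again raises a RuntimeError.
y = x ** 2
z = (y * 3).sum()
z.backward()
try:
    z.backward()
except RuntimeError as err:
    print("second backward failed:", err)

# Like Example 3: retain_graph=True keeps the graph alive, so backward()
# can be called again; the graph is finally freed once y and z are no
# longer referenced.
y = x ** 2
z = (y * 3).sum()
z.backward(retain_graph=True)
z.backward()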

The fundamental difference is that DyNet’s graphs are global objects held by the singleton autograd engine. In PyTorch (and Chainer) the graphs are attached to the variables involved in them and (as Soumith demonstrated) are freed when those variables go out of scope.
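To make that concrete, here is a small sketch (again using the current tensor API rather than Variable) that inspects the graph hanging off the outputs; the grad_fn class names shown in the comments are illustrative and can vary by version:

import torch

x = torch.randn(3, requires_grad=True)
y = x ** 2
z = (y * 3).sum()

# The graph is not stored in any global registry: each output holds a
# reference to the Function node that produced it, and those nodes
# reference their inputs' nodes in turn.
print(z.grad_fn)                 # e.g. <SumBackward0 ...>
print(z.grad_fn.next_functions)  # points back towards the mul and pow nodes

# Dropping the last Python references to y and z drops the only references
# to these Function nodes, so CPython's reference counting frees the whole
# graph immediately; there is no renew_cg()-style call to make.
del y, z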
