Consider the following (made up) example snippet:
```python
import torch

measure = torch.nn.MSELoss()

x = torch.tensor([1, 0], dtype=torch.float64)
t = torch.eye(2, dtype=torch.float64, requires_grad=True)
a = torch.ones(2, dtype=torch.float64, requires_grad=True)
y = t @ a

for __ in range(2):
    x_out = t @ (x + y)        # Raises `RuntimeError` later on.
    # x_out = t @ (x + t @ a)  # Works fine.
    loss = measure(x_out, x)
    loss.backward()
```
During the second iteration of the `for` loop, `loss.backward()` raises the following exception:

```
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
```
I am not sure why this happens, since in every iteration I create a new computational graph by redefining `x_out` and `loss`. What is reused in every iteration is the `y` tensor, and when I replace `y` with its original expression `t @ a`, the code runs without error. So I suppose that by reusing `y` across iterations, that part of the graph is shared between them. However, the underlying mechanism is not really clear to me. I thought that the first call to `loss.backward()` should detach the graph, so that subsequent iterations would build a new one? But apparently the sub-graph for `y = t @ a` is reused in the following iterations (which makes sense, since I never redefine that part). So for the above example: which graph is actually being built, and what parts of it are freed again upon calling `loss.backward()`?
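To illustrate what I mean by the sub-graph being shared: if I recompute `y` inside the loop (essentially the commented-out variant in the snippet above, just written out), every iteration builds a completely fresh graph and `backward()` runs without complaint:

```python
# Variant that runs without error: y is recomputed in every iteration,
# so no part of the graph is shared between the two backward() calls.
for __ in range(2):
    y = t @ a                  # fresh sub-graph each time
    x_out = t @ (x + y)
    loss = measure(x_out, x)
    loss.backward()
```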
My second question is: in order to make the above example work without recomputing `t @ a` at every iteration, what is the preferred way of dealing with the problem? Should I specify `retain_graph=True`? However, I don't want to retain the full graph, since the "tail" is rebuilt on every iteration anyway. Also, I read in this topic that a graph is freed when the corresponding output variable goes out of scope (when its reference count drops to zero), and since that happens to `loss` at every iteration, I would expect the whole graph to be rebuilt on every iteration. According to the error message, that doesn't seem to be the case though. The `y` variable never goes out of scope, so does that mean the sub-graph corresponding to `y = t @ a` is not freed? And how can I free this (sub-)graph manually then (after the loop)?
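For concreteness, this is the kind of workaround I have in mind; it makes the loop run, but it retains more of the graph than I would like, and I am only guessing that dropping the last reference to `y` afterwards is what frees its sub-graph:

```python
for __ in range(2):
    x_out = t @ (x + y)
    loss = measure(x_out, x)
    # Keep the whole graph alive (including the shared y = t @ a part)
    # so the next iteration can backpropagate through it again.
    loss.backward(retain_graph=True)

# My guess at freeing the shared sub-graph afterwards: drop the only
# reference to it (assuming nothing else still holds on to y or its grad_fn).
del y
```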
My third question is whether someone could explain the details of graph creation and graph (buffer) freeing. According to the error message, torch is aware of the fact that I want to reuse (part of) a graph, yet it has already freed the relevant resources. So what information does torch use in order to evaluate the structure of graphs (I suppose it's `.grad_fn`), and what resources are allocated and then freed during backpropagation?
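For what it's worth, this is how I have been looking at the graph structure myself, by walking `.grad_fn` and its `next_functions` from the output tensor; I just don't see how this relates to the buffers that get freed:

```python
def walk(fn, depth=0):
    """Recursively print the backward-graph nodes reachable from `fn`."""
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1)

loss = measure(t @ (x + y), x)
walk(loss.grad_fn)  # one backward node per line, indented by depth
```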