Backpropagation - Graph is reused while it shouldn't

Consider the following (made up) example snippet:

import torch

measure = torch.nn.MSELoss()

x = torch.tensor([1, 0], dtype=torch.float64)
t = torch.eye(2, dtype=torch.float64, requires_grad=True)
a = torch.ones(2, dtype=torch.float64, requires_grad=True)
y = t @ a

for __ in range(2):
    x_out = t @ (x + y)  # Raises `RuntimeError` later on.
    # x_out = t @ (x + t @ a)  # Works fine.
    loss = measure(x_out, x)
    loss.backward()

During the second iteration of the for loop, loss.backward() raises the following exception:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I am not sure why this happens, since in every iteration I create a new computational graph by redefining x_out and loss. What is reused in every iteration is the y tensor, and when I replace y with its original expression t @ a, the code runs without error.

So I suppose that by reusing y across iterations, that part of the graph is shared and reused. However, the underlying mechanism is not really clear to me. I thought that the first call to loss.backward() would detach the graph, so subsequent iterations should build a new one? But apparently the sub-graph y = t @ a is reused in the following iterations (which makes sense since I don’t provide information about that part anymore). So for the above example, which graph is actually being built and which parts of it are freed again upon calling loss.backward()?

My second question: in order to make the above example work without recomputing t @ a at every iteration, what is the preferred way of dealing with the problem? Should I specify retain_graph=True? However, I don’t want to retain the full graph, since the “tail” is rebuilt on every iteration. Also, I read in this topic that a graph will be freed when the corresponding output variables go out of scope (when their reference count drops to zero), and since for loss that happens at every iteration, I would expect the whole graph to be rebuilt on every iteration. According to the error message that doesn’t seem to be the case, though. The y variable never goes out of scope, so does that mean that the sub-graph corresponding to y = t @ a is not freed? And how can I free this (sub-)graph manually then (after the loop)?

My third question: could someone explain the details of graph creation and graph (buffer) freeing? According to the error message, torch is aware of the fact that I want to reuse (part of) the graph, but it has already freed the relevant resources. So what information does torch use to evaluate the structure of graphs (I suppose it’s .grad_fn), and what resources are allocated and then freed during backpropagation?


There are many questions here, I’ll try and answer in the order you asked them even if they overlap.

Yes, the error happens because the operation t @ a is reused multiple times. I am not sure what you mean by “loss.backward() should detach the graph”. backward does not change the graph, it just traverses it (and frees unnecessary buffers if retain_graph is not set).
In your case, the graph up to y is built once. Then each iteration extends this graph. The problem is that the first backward frees even the common part of the graph, and so the second one fails.
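A minimal way to reproduce this with the tensors from the question: the sketch below runs the same loop but catches the exception and records which iteration's backward fails. The exact wording of the error message may differ between PyTorch versions.

```python
import torch

measure = torch.nn.MSELoss()

x = torch.tensor([1, 0], dtype=torch.float64)
t = torch.eye(2, dtype=torch.float64, requires_grad=True)
a = torch.ones(2, dtype=torch.float64, requires_grad=True)
y = t @ a  # built once: every iteration below attaches to this node

errors = []  # (iteration, message) for every failed backward
for i in range(2):
    x_out = t @ (x + y)  # new nodes, linked to the existing y.grad_fn
    loss = measure(x_out, x)
    try:
        loss.backward()  # frees the buffers of every node it visits, y's included
    except RuntimeError as err:
        errors.append((i, str(err)))

print(errors)  # only iteration 1 fails
```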

“I would expect the whole graph to be rebuilt on every iteration”: keep in mind that the computational graph is built ONLY by doing forward operations in your Python script. The backend will never change or build it.
In your case you can indeed use retain_graph=True to keep the whole graph during the backward. The “lower” part (created inside the loop) is freed as soon as nothing refers to it anymore: each iteration's part goes away when x_out and loss are reassigned in the next iteration, and the last one when those names are deleted. The upper part (t @ a) will remain as long as the y variable refers to it.
To free the upper part after the loop, simply delete everything that references it: del y in your case (together with x_out and loss, whose graphs still point to it after the loop).
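Both ways out can be sketched with the tensors from the question. Option 1 keeps the buffers with retain_graph=True and drops the shared part afterwards with del; option 2 recomputes y inside the loop (the “Works fine” variant from the question). This is only a sketch; since t and a are never updated here, both options should end up with the same accumulated gradient.

```python
import torch

measure = torch.nn.MSELoss()
x = torch.tensor([1, 0], dtype=torch.float64)
t = torch.eye(2, dtype=torch.float64, requires_grad=True)
a = torch.ones(2, dtype=torch.float64, requires_grad=True)

# Option 1: compute y once and keep its buffers with retain_graph=True.
y = t @ a
for _ in range(2):
    x_out = t @ (x + y)
    loss = measure(x_out, x)
    loss.backward(retain_graph=True)  # visited buffers are kept, shared ones included
grad_retain = a.grad.clone()

# Free the shared sub-graph: drop every name that still refers to it.
del y, x_out, loss

# Option 2: rebuild the whole graph in every iteration.
a.grad = None
t.grad = None
for _ in range(2):
    y = t @ a  # fresh node with fresh buffers each time
    x_out = t @ (x + y)
    loss = measure(x_out, x)
    loss.backward()
grad_rebuild = a.grad.clone()

print(grad_retain, grad_rebuild)
```

Option 2 costs one extra t @ a per iteration but never keeps buffers longer than one step; option 1 avoids the recomputation at the price of keeping every visited buffer alive until the references are dropped.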

The graph is created every time you perform an operation on tensors (when gradient mode is enabled).
The graph contains both the functions that need to be applied to perform the backward pass and the buffers and temporary variables they need.
When you call backward, the graph is traversed in reverse order and each function is applied one after the other. After each function execution, all its buffers and temporary variables are deleted (if retain_graph=False).
If during the backward a buffer does not exist, then the above error is raised, as the only way this buffer can be gone is that you already backpropagated through it.
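To see these nodes and buffers, here is a small sketch based on x.pow(2), which saves its input for the backward pass. It relies on the _saved_self attribute that recent PyTorch versions expose on grad_fn objects (an assumption about your version; older releases do not have it), and on torch.no_grad() to show that nothing is recorded when gradient mode is disabled.

```python
import torch

x = torch.randn(5, requires_grad=True)

# Grad mode on: the forward op records a backward node plus the buffers it needs.
y = x.pow(2)
print(type(y.grad_fn).__name__)  # e.g. PowBackward0 (name may vary by version)
saved_matches = torch.equal(y.grad_fn._saved_self, x)  # pow saved its input

# Grad mode off: nothing is recorded, so there is nothing to backprop through.
with torch.no_grad():
    z = x.pow(2)
print(z.grad_fn)  # None

# backward() runs each node and then frees its saved buffers.
y.sum().backward()
freed_message = ""
try:
    y.grad_fn._saved_self  # the buffer no longer exists
except RuntimeError as err:
    freed_message = str(err)
print(freed_message)
```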



Thanks for the prompt and thorough reply, it’s a lot clearer now. I have a few follow-up questions though:

  1. So that basically means that y = t @ a creates a graph and each iteration in the loop extends this graph by attaching to it? Consequently when the variables in the loop are re-assigned (x_out and loss) then the previously attached part is removed again, since nothing refers to it anymore? Is that interpretation correct?

  2. Then the keyword argument retain_graph actually only refers to the buffers (of the graph’s vertices) that are needed for performing the backpropagation? But the actual graph (i.e. the relations between tensors) is retained anyway (no matter the value of the retain_graph argument)? In other topics I read that it used to be called retain_variables; why the change to retain_graph (retain_buffers would maybe be clearer)? At least to me it was confusing until I read your reply.

  3. The following statements are not completely clear to me:

The graph contains both the functions that need to be applied to perform the backward pass […]
[…] each function is applied one after the other

What are these functions actually (are they related to the .grad_fn attribute)? And how are the backprop buffers related to them?

  1. Yes, here is a small sample. I use y = a*2 here because that way the op has a single input and single output and it is easier to show the graph.
a = torch.rand(10, requires_grad=True)
# no graph at this point

y = a * 2
# The backward graph linked to y is: y -> a

## Iteration 0
x_out0 = 2 * y
# The backward graph linked to x_out0 is: x_out0 -> y -> a
# Note that the backward graph linked to y is still: y -> a (the arrow here is the same as the one between y and a above, meaning that deleting the buffers in one will delete them in the other)

loss0 = measure(x_out0, torch.zeros_like(x_out0))
# The backward graph linked to loss0 is: loss0 -> x_out0 -> y -> a

## Iteration 1
x_out1 = 2 * y
# The backward graph linked to x_out1 is: x_out1 -> y -> a (the arrow between y and a here is the same as the one between y and a in iteration 0, meaning that deleting the buffers in one will delete them in the other)

loss1 = measure(x_out1, torch.zeros_like(x_out1))
# The backward graph linked to loss1 is: loss1 -> x_out1 -> y -> a
# Backpropagating through this one after loss0.backward() will fail, as the buffers for the arrow between y and a are already deleted.

# If you were printing all the graphs linked to both loss0 and loss1, you would get:
# loss1 -> x_out1 -> y -> a
#    loss0 -> x_out0 ↑
# As you can see some part of the graph is shared.
  2. It refers to both buffers and intermediate results. Without these, the “graph” cannot be used for anything, so it does not really exist anymore. We only use it to detect double usage.

  3. These functions are autograd.Function objects (the nodes you reach through .grad_fn) that contain both the code to run for the backward pass and some intermediate variables needed to perform these computations. Each function handles its buffers and temporary variables the way it wants.
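The sharing described in the first point, and the difference between the graph's structure and its buffers, can be checked directly. The sketch below uses a.exp() instead of a * 2 so that the node for y certainly owns a saved buffer (exp keeps its result for the backward pass); parent_nodes is a hypothetical helper that just looks one step up through grad_fn.next_functions. After the first backward the edges are still there, but the shared buffer is gone.

```python
import torch

a = torch.rand(10, requires_grad=True)
y = a.exp()  # the exp node keeps its result: a buffer that backward frees

x_out0 = 2 * y  # tail built in "iteration 0"
x_out1 = 2 * y  # tail built in "iteration 1"


def parent_nodes(tensor):
    """Hypothetical helper: nodes one step upstream of tensor.grad_fn."""
    return [fn for fn, _ in tensor.grad_fn.next_functions if fn is not None]


# Both tails point at the very same node object for y: that part is shared.
shared = any(fn is y.grad_fn for fn in parent_nodes(x_out0)) and any(
    fn is y.grad_fn for fn in parent_nodes(x_out1)
)
print("shared:", shared)

x_out0.sum().backward()  # no retain_graph: buffers of all visited nodes are freed

# The structure survives: the edges can still be walked...
still_linked = any(fn is y.grad_fn for fn in parent_nodes(x_out0))
print("still linked:", still_linked)

# ...but the shared buffer is gone, so the second tail cannot be backpropagated.
second_error = ""
try:
    x_out1.sum().backward()
except RuntimeError as err:
    second_error = str(err)
print(second_error)
```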

Does that make things clearer?
Let me know if you have other doubts.
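As for what these functions are, a custom torch.autograd.Function makes the mechanism visible, since you write the backward code yourself and choose its buffers with ctx.save_for_backward. The Square function below is a hypothetical example, not something that ships with PyTorch; built-in ops do the equivalent in C++.

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical op computing x**2 and keeping x as its only backward buffer."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # this node's buffer
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # raises if the buffer has already been freed
        return 2 * x * grad_output


x = torch.tensor([1., 2., 3.], requires_grad=True)
y = Square.apply(x)
print(type(y.grad_fn).__name__)  # the node class PyTorch created for Square

y.sum().backward(retain_graph=True)  # buffer used and kept
y.sum().backward()  # buffer used, then freed
print(x.grad)  # two accumulated passes of 2 * x, i.e. 4, 8, 12

third_error = ""
try:
    y.sum().backward()  # same Square node, buffer gone
except RuntimeError as err:
    third_error = str(err)
print(third_error)
```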

Your point 2 is still confusing me. You say that

[…] and intermediary results. Without these, the “graph” cannot be used for anything […]

To what parts of the graph does this refer? Does it refer to the entire graph that was tracked through backpropagation? And what does “anything” mean in that context? What else than gradient computation would / could you use a graph for?

I tried the following snippet, where a graph is constructed (up to c), then backpropagated, and then the graph’s upper part (via d, or even the whole graph via e) is reused to construct another graph. Based on the statement above that .backward() frees any resources / buffers / intermediate results of the graph, I would expect the computation of d and e not to work (i.e. not only backprop), but apparently it does:

import torch

a = torch.tensor(1., requires_grad=True)  # requires_grad needs a floating-point tensor
b = a * torch.tensor(2)
c = b * torch.tensor(3)
c.backward()

d = b * torch.tensor(4)
e = c * torch.tensor(5)
print('d = ', d)
print('e = ', e)

So it still remains unclear to me what the graph (or parts of it) can still be used for after performing .backward(), and what it cannot be used for anymore (another .backward(), for example). Something like a list of dos and don’ts after a graph has been backpropagated (considering operations involving that graph).

This refers to every part of the graph that was visited during the backward pass.

Yes, only gradient computation in this case.
Be careful when playing with super simple examples: you can create simple graphs that don’t have any buffers or intermediate results, and so you will be able to backprop through them multiple times. But this is an artefact of these graphs being super simple.
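A small sketch of that artefact: addition needs no saved tensors for its backward, so nothing is lost and backprop can be repeated, while a * a saves both operands and fails on the second pass. This reflects which ops save tensors in current PyTorch versions; it is not a guarantee for every op.

```python
import torch

a = torch.rand(3, requires_grad=True)

# Addition needs no saved tensors for its backward: nothing to free.
b = a + 1
b.sum().backward()
b.sum().backward()  # fine, even without retain_graph=True
grad_after_add = a.grad.clone()
print(grad_after_add)  # two accumulated passes of ones

# a * a saves both operands, which the first backward frees.
c = a * a
c.sum().backward()
mul_error = ""
try:
    c.sum().backward()
except RuntimeError as err:
    mul_error = str(err)
print(mul_error)
```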

It does free the resources of the graph, but not the Tensors that the user created during the forward pass.
You don’t have a strong link between Tensors from the forward pass and nodes in the graph. The only thing is that computing a new tensor creates a new node in the graph. But you can remove the Tensor without impacting the graph and you can remove the graph’s buffers (by doing backward) without impacting the Tensor that created it.

In your example, b and c are not part of the graph.
So even after calling backprop, you can still use them.
In the same way, you could del b and the c.backward() call would still work without any problem.
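Here is a sketch of that: a variant of the snippet above with b deleted before backprop. It shows that del b does not affect c.backward(), that c itself is still usable afterwards, and that only a second backward through the same graph fails (multiplying by a constant tensor saves that tensor as the node's buffer).

```python
import torch

a = torch.tensor(1., requires_grad=True)
b = a * torch.tensor(2.)
c = b * torch.tensor(3.)  # this node saves the constant 3. as its buffer

del b  # only drops the Python name; the graph never needed this tensor object
c.backward()  # works, and frees the graph's buffers
print(a.grad)  # dc/da = 2 * 3 = 6

# backward does not touch the forward tensors, so c can still be used...
e = c * torch.tensor(5.)
print(e)  # value 30

# ...but the graph behind c cannot be backpropagated a second time.
again_error = ""
try:
    c.backward()
except RuntimeError as err:
    again_error = str(err)
print(again_error)
```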

In your example, b and c are not part of the graph.

Hi, I have a question: how do we know that b and c are not part of the graph?

Because the graph we were talking about above is the graph used to compute gradients, which is composed only of backward Nodes and their buffers, not the user Tensors.
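One way to see what this graph contains is to walk it from c.grad_fn through next_functions, as in this sketch using the a, b, c example from the question below. Only backward nodes show up: the multiplication node and one AccumulateGrad node per leaf, through which the leaves a and b are reached (via the .variable attribute). An intermediate tensor such as c is not a node of this graph. Exact class names like MulBackward0 may differ between PyTorch versions.

```python
import torch

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = a * b

# Depth-first walk over the backward graph, starting from c's node.
names = []
stack = [c.grad_fn]
while stack:
    node = stack.pop()
    names.append(type(node).__name__)
    stack.extend(fn for fn, _ in node.next_functions if fn is not None)
print(names)  # e.g. ['MulBackward0', 'AccumulateGrad', 'AccumulateGrad']

# Leaves are only reached through their AccumulateGrad nodes.
leaf_nodes = [fn for fn, _ in c.grad_fn.next_functions]
print(leaf_nodes[0].variable is a, leaf_nodes[1].variable is b)
```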

Because the graph we were talking about above is the graph to compute gradients.

I still have a hard time getting my head around this… So I wrote a simple example below and hope to get some feedback from you!

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = a * b

As I understand it, the graph is created by c = a*b.

for _ in range(5):
    c.backward(retain_graph=True)
    print(a.grad, b.grad)

The result:

tensor(2.) tensor(1.)
tensor(4.) tensor(2.)
tensor(6.) tensor(3.)
tensor(8.) tensor(4.)
tensor(10.) tensor(5.)

Why does print(c.backward(retain_graph=True)) print None?

As for the explanation you gave above:
If a tensor is not part of the gradient computation graph, then we can delete it and backward should work fine. So here’s what I did:

del b
for _ in range(3):
    c.backward(retain_graph=True)

And the result is:


The result doesn’t make sense to me since, as I understand it, c = a*b.
I expected c.backward(retain_graph=True) to fail because b had already been deleted. What am I missing here?

I think the problem is that I do not know which tensors or functions are considered part of the graph used to compute gradients. Maybe b here is not, but I think it should be part of it…

Those are my two questions. I am a PyTorch beginner; thanks in advance!

Hi, I read this blog here:

What I took away is:
1/ When you want a tensor to be part of the graph, you need to specify requires_grad=True.
So, for your first example, x, t, a, y and x_out (requires_grad being contagious) are all going to be in the dynamic computation graph once you do the forward pass (i.e. y = t @ a, x_out = t @ (x + y) and loss = measure(x_out, x)).
If we define a leaf node as a tensor you create without the need of a function (say add, multiply etc.), then t and a are leaf nodes.
2/ All the non-leaf nodes will be erased once you do the backward pass. So y’s and x_out’s values will be erased from the graph, as well as the other gradients, say dy/dt, dy/da, dx_out/dt, dx_out/dy etc.
3/ For there to be no error, one has to do forward() and then backward(). But since you only recalculated x_out and loss, and not y, in the loop, you didn’t do a complete forward(). Thus the error happened.
Hope this helps!
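Whether a tensor requires grad and whether it is a leaf can be checked directly on the tensors from the original question, as in this short sketch. Note that x, created without requires_grad=True, reports requires_grad=False: it enters the computation as an input, but autograd does not compute a gradient for it.

```python
import torch

measure = torch.nn.MSELoss()
x = torch.tensor([1, 0], dtype=torch.float64)
t = torch.eye(2, dtype=torch.float64, requires_grad=True)
a = torch.ones(2, dtype=torch.float64, requires_grad=True)
y = t @ a
x_out = t @ (x + y)
loss = measure(x_out, x)

# requires_grad spreads from t and a to every result computed from them;
# tensors created directly by the user are leaves, results of ops are not.
info = {}
for name, tensor in [("x", x), ("t", t), ("a", a), ("y", y), ("x_out", x_out), ("loss", loss)]:
    info[name] = (tensor.requires_grad, tensor.is_leaf)
    print(name, "requires_grad:", tensor.requires_grad, "is_leaf:", tensor.is_leaf)
```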


  • .backward() does not return anything, so print shows None. You can check the docs for more details.
  • I mentioned backward Nodes and their buffers. In your case, the content of b becomes one of the buffers needed to backward the multiplication. So even if you delete b itself, the corresponding buffer will still be there.
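Both points can be checked with the a, b, c example from the question. The sketch below assumes a recent PyTorch version for the _saved_other attribute of the multiplication node; it shows that backward() returns None, that gradients accumulate across calls when retain_graph=True, and that deleting the name b leaves the saved operand in place.

```python
import torch

a = torch.tensor(1., requires_grad=True)
b = torch.tensor(2., requires_grad=True)
c = a * b  # the multiplication node saves both a and b

result = c.backward(retain_graph=True)
print(result)  # None: backward() fills .grad instead of returning values

for _ in range(2):
    c.backward(retain_graph=True)
print(a.grad, b.grad)  # three accumulated passes: 6 and 3

# Deleting the name b does not delete the operand saved by the node.
del b
saved_b = c.grad_fn._saved_other
c.backward(retain_graph=True)  # still works
print(saved_b, a.grad)  # saved value 2, a.grad now 8
```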

OK, thanks for your clear explanation!