Runtime Error in .backward()

Hi! The following code:

import torch
h_ = torch.nn.Parameter(torch.ones([1,1]))
h = h_**2
optimizer = torch.optim.SGD([h_], lr=0.1)

for _ in range(2):
  loss = h.sum()
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

throws a

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Can anyone tell me what the issue is? Thanks!

Try this code:

import torch

h_ = torch.nn.Parameter(torch.ones([1,1]), requires_grad=True)
optimizer = torch.optim.SGD([h_], lr=0.01)

for _ in range(10):
    hOut=h_**2 - h_*10
    loss = hOut.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

What I have done is move the calculation of hOut inside the loop itself.

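Since you were not recalculating h within each loop iteration, the graph's buffers were freed after the first loss.backward(). On the next iteration, backward tried to go through that same freed graph again and threw an error.
Recalculating hOut inside the loop builds a fresh graph on every iteration, so each backward call has the buffers it needs. Hope this helps.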
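
To see both behaviours side by side, here is a minimal sketch (my own illustration, not the exact code from either post; the names second_backward_failed and losses are made up for the example). It first calls backward twice on a graph built once, then runs a few SGD steps with the forward pass inside the loop:

```python
import torch

h_ = torch.nn.Parameter(torch.ones([1, 1]))
optimizer = torch.optim.SGD([h_], lr=0.1)

# Case 1: the graph for h_ ** 2 is built once, outside any loop.
h = h_ ** 2
h.sum().backward()  # first backward frees the tensors saved by h_ ** 2
try:
    h.sum().backward()  # goes through the same, already freed, graph
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# Case 2: the forward pass is recomputed on every iteration,
# so every backward call gets a fresh graph.
losses = []
for _ in range(3):
    optimizer.zero_grad()
    loss = (h_ ** 2).sum()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(second_backward_failed, losses)
```

Passing retain_graph=True to the first backward call, as the error message suggests, would also silence the error, but in a training loop that is usually not what you want: h would still hold the value computed before any optimizer step, so the loss would never reflect the updated parameter.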


I see. That fixes the problem. Thanks a lot! 🙂