"RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time" while using custom loss function

I keep running into this error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I’ve searched the forum, but I still can’t figure out what’s wrong with my custom loss function.
I’m using nn.GRU, here is my Loss function:

def _loss(outputs, session, items):  # `items` is a dict containing the embeddings of all items
    def f(output, target):
        pos = torch.from_numpy(np.array([items[target["click"]]])).float()
        neg = torch.from_numpy(np.array([items[idx] for idx in target["suggest_list"] if idx != target["click"]])).float()
        if USE_CUDA:
            pos, neg = pos.cuda(), neg.cuda()
        pos, neg = Variable(pos), Variable(neg)

        pos = F.cosine_similarity(output, pos)
        if neg.size()[0] == 0:
            return torch.mean(F.logsigmoid(pos))
        neg = F.cosine_similarity(output.expand_as(neg), neg)

        return torch.mean(F.logsigmoid(pos - neg))

    loss = list(map(f, outputs, session))  # materialise the map (it is a lazy iterator in Python 3)
    return -torch.mean(torch.cat(loss))

Training code:

        # zero the parameter gradients
        model.zero_grad()

        # forward + backward + optimize
        outputs, hidden = model(inputs, hidden)
        loss = _loss(outputs, session, items)
        acc_loss += loss.data[0]

        loss.backward()
        # Add parameters' gradients to their values, multiplied by learning rate
        for p in model.parameters():
            p.data.add_(-learning_rate, p.grad.data)
3 Likes

I don’t think the error is in your loss function. I think any loss function would cause this error.

Am I right in saying that your training loop doesn’t detach or repackage the hidden state in between batches? If so, then loss.backward() is trying to back-propagate all the way through to the start of time, which works for the first batch but not for the second because the graph for the first batch has been discarded.
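
Here is a minimal way to reproduce that failure mode (not your model, just a toy nn.GRU whose hidden state is carried across iterations without being detached):

import torch
import torch.nn as nn

gru = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
hidden = torch.zeros(1, 1, 8)

for step in range(2):
    x = torch.randn(1, 5, 4)
    out, hidden = gru(x, hidden)  # hidden still references the previous iteration's graph
    out.mean().backward()         # fine on step 0, raises the RuntimeError on step 1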

If I am right then there are two possible solutions.

  1. detach/repackage the hidden state in between batches (see the sketch after this list). There are (at least) three ways to do this.

    1. hidden.detach_()
    2. hidden = hidden.detach()
    3. hidden = Variable(hidden.data, requires_grad=True)
  2. replace loss.backward() with loss.backward(retain_graph=True) but know that each successive batch will take more time than the previous one because it will have to back-propagate all the way through to the start of the first batch.
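
For reference, here is a rough sketch of how option 1 (the hidden = hidden.detach() variant) might look dropped into the training loop from the original post. training_data is a placeholder for however the batches are iterated, hidden is assumed to be initialised before the loop, and the rest comes from that post:

for inputs, session in training_data:
    model.zero_grad()
    outputs, hidden = model(inputs, hidden)
    hidden = hidden.detach()  # repackage: cut the graph between batches
    loss = _loss(outputs, session, items)
    loss.backward()           # back-propagates only through the current batch
    for p in model.parameters():
        p.data.add_(-learning_rate, p.grad.data)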

24 Likes

Thanks. I had tried the 2nd solution before, but it ran really slowly. Your 1st solution helped me fix it.

3 Likes

@jpeg729, can you elaborate a bit more:

Mathematically speaking, what would be the difference between solution 1 and solution 2? Or are they mathematically equivalent, just not equally computationally efficient?

Thanks

1 Like

I have the same question. Thank you!

@jpeg729 Thank you for your answer. I am adding this comment to further elaborate on the statement "Am I right in saying that your training loop doesn’t detach or repackage the hidden state in between batches?"

Detach is very well explained, but "repackage" is what caught my eye and needed an explanation. I think I figured it out, and my understanding is as follows:

>>> a = torch.tensor([1,2,3.], requires_grad = True)
>>> out = a.sigmoid()
>>> out.sum().backward()
>>> a.grad # This will output following
tensor([0.1966, 0.1050, 0.0452])

But when we run the following code again:

>>> out.sum().backward() # running this again raises the following error:
RuntimeError: Trying to backward through the graph a second time,
but the buffers have already been freed.
Specify retain_graph=True when calling backward the first time.

This error occurs because the graph has been cleared, as indicated, and hence we cannot backpropagate through it again. Thus, if we want to backpropagate again, we must either pass retain_graph=True or build the graph again (which is like calling the forward function of the model again), as follows:

>>> out = a.sigmoid()
>>> out.sum().backward()
>>> a.grad # This will output the following. Note that the new gradient
# has been added on top of the previous one, because we did not call
# a.grad.data.zero_() to reset it, so gradients accumulate.
tensor([0.3932, 0.2100, 0.0904])
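
For completeness, here is the retain_graph=True option mentioned above, starting again from a fresh a so the gradients below come only from these two backward calls:

>>> a = torch.tensor([1,2,3.], requires_grad = True)
>>> out = a.sigmoid()
>>> out.sum().backward(retain_graph=True) # keep the graph so we can call backward again
>>> out.sum().backward() # no error this time; gradients still accumulate in a.grad
>>> a.grad
tensor([0.3932, 0.2100, 0.0904])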

I hope this further explains things and makes them clearer.

Please comment on this if something is not fine.

8 Likes

Could you elaborate on the line hidden = Variable(hidden.data, requires_grad=True)? I don’t understand what Variable is in the code, but I have the same problem in my own network using a GRU. I think your solution will work, but I don’t understand how to apply that specific line. Thanks!

Would be great if they included this or a link to this in the GRU/RNN docs. Judging by how many have liked your answer, this is pertinent info.

Variable is a specific PyTorch class which has been deprecated. I think they just use Tensor now.
https://pytorch.org/docs/stable/autograd.html#variable-deprecated

For example:
hidden = hidden.detach().requires_grad_()

Hello, hoping for some help as I have been stuck for a long time. This is a snippet of the code:

optimiser.zero_grad()
loss.backward()
for param in agent.parameters(): param.grad.data.clamp_(-1, 1) 
optimiser.step()
for param in agent.parameters(): param = param.detach()

As you can see, I am already ‘repackaging’ the model parameters. For context, I am trying to implement DQN, following the PyTorch tutorial as closely as possible (different environment and use case).

I randomly get one of two errors on each run:

“one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!”

or

“Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.”

Any suggestions as to what the bug might be?

Also, when I print(param._version) at the very beginning of the training loop (first line), some of the parameters (weights) are already at version 1, whilst others (biases) are at version 0. Is that normal?

This helped. I was consuming too much memory on the GPU with the retain_graph=True parameter.