Backward through the graph a second time error

I stumbled upon the following error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

in a rather convoluted RNN snippet. I tried to boil the snippet down as much as possible, and I hope the error in the snippet below is triggered for the same reason:

import torch
import torch.nn as nn


class Cell(nn.Module):
    def __init__(self):
        super(Cell, self).__init__()
        self.state = torch.tensor([[0., ]])
        self.weight = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self) -> torch.Tensor:
        x = torch.zeros((1, 1))
        x[0, 0] = self.weight

        res = self.state + x
        # res = self.state + self.weight

        self.state = x

        return res


if __name__ == '__main__':
    cell = Cell()

    def _train_once(cell, optimizer):
        optimizer.zero_grad()

        y = cell()
        loss = torch.sum(y)
        loss.backward()

        optimizer.step()

    optimizer = torch.optim.SGD(cell.parameters(), lr=1e-3)
    _train_once(cell, optimizer)
    _train_once(cell, optimizer)

Can you explain to me how to fix the error (without setting retain_graph=True)? I also don’t understand why the second backpropagation works when replacing res = self.state + x with res = self.state + self.weight.

Thanks for your help!

Hi,

The problem is that your self.state is kept across iterations, and so the computational graph for the second call actually includes the one from the first call.
You most likely want to either reset your state or .detach() it to avoid that.
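
For the .detach() option, here is a minimal sketch based on your snippet (it keeps the value of the state but cuts the autograd graph at every step, so no retain_graph is needed):

import torch
import torch.nn as nn


class Cell(nn.Module):
    def __init__(self):
        super().__init__()
        self.state = torch.tensor([[0., ]])
        self.weight = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self) -> torch.Tensor:
        x = torch.zeros((1, 1))
        x[0, 0] = self.weight

        res = self.state + x

        # Keep the value of the state but detach it from the current graph,
        # so the next backward() does not traverse this call's graph again.
        self.state = x.detach()

        return res


cell = Cell()
optimizer = torch.optim.SGD(cell.parameters(), lr=1e-3)
for _ in range(2):
    optimizer.zero_grad()
    loss = torch.sum(cell())
    loss.backward()  # works on every iteration, no retain_graph=True needed
    optimizer.step()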


Great, thanks for your help! I indeed want to keep the state, since the context of this cell is an RNN. Do you know why this isn’t an issue when replacing res = self.state + x with res = self.state + self.weight (i.e., the commented-out line in the snippet)?

No, the two should give the same result.

The problem is that they don’t :see_no_evil:

Could you share a runnable code sample that shows the difference, please?


Thanks again. Just swap the commented line in my first post (inside the forward function); those are the lines I am referring to.

The difference is just that when you use the in-place op, you create a more complex graph, and so backwarding through it twice is actually an issue.

Both cases are wrong, as I mentioned above, since in both cases the second call backprops into the graph created by the first one.
It just happens that in the case where you use the weight directly, the graph of the first call is trivial and, since no buffers are saved, you can actually backprop through it again.
When you use the in-place op in the second case, the graph is not trivial anymore and it actually raises an error.
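
To make that concrete, here is a small sketch of the variant that runs (res = self.state + self.weight from your snippet), calling backward manually instead of going through the optimizer: on the second step the gradient is 2 instead of 1, because it flows both directly through the weight and through the x that was stored as state in the first call.

import torch
import torch.nn as nn


class Cell(nn.Module):
    def __init__(self):
        super().__init__()
        self.state = torch.tensor([[0., ]])
        self.weight = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self) -> torch.Tensor:
        x = torch.zeros((1, 1))
        x[0, 0] = self.weight
        res = self.state + self.weight  # the "working" variant
        self.state = x  # still carries the graph built in this call
        return res


cell = Cell()

# First step: res depends on the weight only directly -> gradient is 1
torch.sum(cell()).backward()
print(cell.weight.grad)  # tensor(1.)

# Second step: self.state is the x from the first call, so backward also
# goes through the first call's graph -> gradient is 1 + 1 = 2
cell.weight.grad = None
torch.sum(cell()).backward()
print(cell.weight.grad)  # tensor(2.)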
