RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time

I encountered this situation too.
My training code is as follows:


import torch
import torch.nn as nn


class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.layers = nn.ModuleList([])
        self.posemb = nn.Linear(2, 40)
        for i in range(3):
            self.layers.append(nn.Linear(2, 40))

        # cached intermediate result, reused across forward calls
        self.bias = None

    def forward(self, x, idx):
        # only computed the first time forward is ever called
        if self.bias is None:
            self.bias = self.posemb(x).reshape(-1, 20, 2)
        pred = self.layers[idx](x)
        pred = pred.reshape(-1, 20, 2) + self.bias
        return pred


if __name__ == '__main__':

    net = testNet()
    for b in range(3):
        x = torch.rand(4, 2)
        label = torch.rand(4, 20, 2)

        ll = []
        for i in range(3):
            pred = net(x, i)
            x = pred[:, 0, :]
            ll.append(torch.norm(pred - label, p=-1).mean())

        loss = torch.stack(ll).mean()
        loss.backward()
        print('batch: %d | loss: %f' % (b, loss.item()))

In the first training loop everything is fine, but the second loop gives me this error.
After checking for a long time, I finally found out that the problem is in my network.
In my forward function, I cache an intermediate variable (self.bias) to avoid recomputing it.
In the second loop this variable should be computed again, but my condition skips the recomputation, so the cached tensor still points at the graph built in the first loop, and that graph has already been freed by the first backward().
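The same failure can be reproduced in a few lines, independent of my network (a minimal sketch): once backward() has freed the graph behind a cached tensor, any later loss built from that tensor hits the error.

import torch

w = torch.randn(3, requires_grad=True)
cached = (w * 2).sum()    # tensor that keeps a reference to its graph

loss1 = cached + 1
loss1.backward()          # frees the buffers of the graph behind `cached`

loss2 = cached + 2        # reuses the stale cached tensor
loss2.backward()          # RuntimeError: Trying to backward through the graph a second time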
After I changed

if self.bias is None:
    self.bias = self.posemb(x).reshape(-1, 20, 2)

to

if idx == 0:
    self.bias = self.posemb(x).reshape(-1, 20, 2)

the problem was solved!
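If anyone hits the same pattern but does not need gradients to flow through the cached value, another option (not what I needed here, just a sketch) is to detach the tensor when caching it, so it no longer references the old graph:

if self.bias is None:
    # detach() drops the autograd history of the cached tensor, so later
    # backward() calls never try to walk through the already-freed graph
    self.bias = self.posemb(x).reshape(-1, 20, 2).detach()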