I stumbled upon the following error
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
in a rather convoluted RNN snippet. I boiled the snippet down as much as possible, and I hope the error in the minimal version below is triggered for the same reason:
import torch
import torch.nn as nn


class Cell(nn.Module):
    def __init__(self):
        super(Cell, self).__init__()
        self.state = torch.tensor([[0., ]])
        self.weight = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self) -> torch.Tensor:
        x = torch.zeros((1, 1))
        x[0, 0] = self.weight
        res = self.state + x
        # res = self.state + self.weight
        self.state = x
        return res


if __name__ == '__main__':
    cell = Cell()

    def _train_once(cell, optimizer):
        optimizer.zero_grad()
        y = cell()
        loss = torch.sum(y)
        loss.backward()
        optimizer.step()

    optimizer = torch.optim.SGD(cell.parameters(), lr=1e-3)
    _train_once(cell, optimizer)
    _train_once(cell, optimizer)
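For reference, one variant that does run without the error is to detach the state before storing it, so the graph of one forward pass is never revisited by the next backward call. I assume this is the intended kind of fix, but I'd still like to understand the mechanism:

```python
import torch
import torch.nn as nn


class Cell(nn.Module):
    def __init__(self):
        super().__init__()
        self.state = torch.tensor([[0.]])
        self.weight = nn.Parameter(torch.tensor(0.))

    def forward(self) -> torch.Tensor:
        x = torch.zeros((1, 1))
        x[0, 0] = self.weight
        res = self.state + x
        # Detach before storing: the saved state then carries no autograd
        # history, so the next backward() never walks back into (and
        # re-frees) this iteration's graph.
        self.state = x.detach()
        return res


cell = Cell()
optimizer = torch.optim.SGD(cell.parameters(), lr=1e-3)
for _ in range(2):
    optimizer.zero_grad()
    loss = torch.sum(cell())
    loss.backward()  # no RuntimeError on the second call
    optimizer.step()
```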
Can you explain how to fix the error (without setting retain_graph=True)? I also don't understand why the second backward pass works when res = self.state + x is replaced with res = self.state + self.weight.
Thanks for your help!