The code you provided works fine on my machine. Make sure you have the newest version of PyTorch.
There is an error if you change `retain_graph` to `False` in the `backward()` call; in that case, I think this thread will be helpful.
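Here is a minimal sketch of what `retain_graph` controls, assuming a recent PyTorch (`x`, `y` and `z` are illustrative names, not from your code). `exp()` saves its output for the backward pass, so its graph can only be traversed twice if those buffers are kept with `retain_graph=True`:

```python
import torch

x = torch.tensor(0.0, requires_grad=True)  # leaf tensor
y = x.exp()  # ExpBackward0 saves the output for the backward pass

# Keep the saved buffers so the graph can be traversed again.
y.backward(retain_graph=True)
y.backward()  # second pass works; the default retain_graph=False frees the buffers now
after_two = x.grad.item()  # d/dx exp(x) at 0 is 1, accumulated twice
print(after_two)

# Without retain_graph=True, the second pass through the same graph fails.
z = x.exp()
z.backward()  # frees the buffers saved by this exp()
raised = False
try:
    z.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

Note that gradients accumulate in `x.grad` across calls, so the value after two passes is 2, not 1.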
In PyTorch, once `backward()` has been called, you cannot backpropagate through the same graph (or any part of it) a second time, because the backward pass frees the intermediate buffers that the autograd engine saved during the forward pass. The only exception is a leaf node, which has no graph history behind it. `next_pos` is part of the graph that was built in the first iteration. Detaching it (so it is no longer part of that graph) turns `next_pos` into a leaf node, and everything works fine. Otherwise, the error is raised: `next_pos` was used in the previous `backward()` call, so the next call tries to traverse the part of the graph behind it that has already been freed.
Here is a snippet that makes the error more obvious:
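```python
import torch

# dummy forward pass (the engine records it)
a = torch.tensor(10., requires_grad=True)
b = a * 3
c = b - 1.5

# the first backward pass works, then frees the graph's buffers
c.backward()

# this will work, because b.detach() is a leaf with no history
# (a + b.detach()).backward()

# error: backpropagating through b reaches the already-freed graph
(a + b).backward()
```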
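The `next_pos` fix in an iterative loop can be sketched like this. It is a hypothetical setup, not the code from the question: `velocity`, `pos`, `run` and the loss are made up for illustration, assuming a recent PyTorch. Each step builds a new graph on top of `pos`, and detaching `next_pos` keeps the next step from reaching back into the graph that the previous `backward()` call freed:

```python
import torch


def run(detach_between_steps: bool, steps: int = 3) -> torch.Tensor:
    # Hypothetical setup: a learnable parameter moves a position over several steps.
    velocity = torch.tensor(1.0, requires_grad=True)
    pos = torch.tensor(0.0)  # plain tensor, not part of any graph yet
    for _ in range(steps):
        # Multiplying two tensors makes MulBackward0 save its inputs.
        next_pos = pos + velocity * velocity
        loss = (next_pos - 5.0) ** 2
        loss.backward()  # frees the buffers saved in this step's graph
        # detach() returns a leaf that shares the data but has no grad_fn,
        # so the next step does not reach back into the freed graph.
        pos = next_pos.detach() if detach_between_steps else next_pos
    return velocity.grad


print(run(detach_between_steps=True))

try:
    run(detach_between_steps=False)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Without the detach, the second step raises the same RuntimeError. Passing `retain_graph=True` to `loss.backward()` would also avoid the error, but it keeps every step's graph alive and backpropagates through all of them, which is usually not what an iterative update wants.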