Inplace operation breaks computation graph

Hi! I am trying to create a model that makes predictions and updates the current state based on them. However, when I try to keep track of the state and update it, the computation graph seems to break. How do I address this issue, or is there something I’m doing incorrectly in my code?

import torch
import torch.nn as nn
from copy import deepcopy
import math
from torch.nn import Linear, ModuleList, Sequential, ReLU
torch.manual_seed(0)
import numpy as np

class FeedForward(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()  # required so nn.Module can register self.mlp
        self.input_size = input_size
        self.hidden_size  = hidden_size
        self.mlp = Sequential(
            torch.nn.Linear(self.input_size, self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        output = self.mlp(x)
        return output

model = FeedForward(8,32)

z = torch.tensor([math.pi/2.3], dtype=torch.float64)
curr_state = torch.tensor([[0, 1, z]]) #state variables, (for ex pos, vel, accel)
curr_state = curr_state.repeat(5,1) #batch size of 5, all with same curr_state

opt = torch.optim.Adam(model.mlp.parameters(), lr=1e-3)

next_pos = curr_state[:, 0:1] #assume this is current_pos; shape (5, 1) so it matches delta_pos instead of broadcasting to (5, 5)
for ep in range(10):
    with torch.autograd.set_detect_anomaly(True):
        print(ep)
        opt.zero_grad()
        x = torch.rand(5, 8) 
        y = torch.rand(5, 1)

        delta_pos = model.mlp(x)
        #next_pos = next_pos.detach() --- note adding this fixes this issue. why is this needed? will this solve the issue?
        next_pos = next_pos + delta_pos

        loss = torch.sum((next_pos-y)**2)
        print(loss)
        loss.backward(retain_graph=True) 
        opt.step()

Thanks!

The code you provided works fine on my machine. Make sure you have the newest version of PyTorch.

There is an error if you change retain_graph to False in the backward() call and in that case I think this thread will be helpful.

In PyTorch, once backward() has been called, you cannot call backward() again through the same graph (or any part of it), because the backward pass frees the buffers stored by the autograd engine. Only leaf nodes of that graph are unaffected. next_pos is part of the graph that is built in the first iteration. By detaching it (so it’s no longer part of that graph), next_pos becomes a leaf node and everything works fine. Otherwise, the error is raised, since next_pos was already used in the previous backward call.
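As a rough sketch of what the detach fix looks like in a loop like the one in the question: the state keeps its value from one iteration to the next, but its gradient history is cut, so each backward() only walks the graph built in the current iteration. The `mlp` below is a hypothetical stand-in for the model in the question, and starting the state at zeros with shape (5, 1) is my assumption, not something taken from the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the model from the question.
mlp = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid())
opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)

next_pos = torch.zeros(5, 1)  # running state, shape (batch, 1) (assumed)
losses = []

for ep in range(10):
    opt.zero_grad()
    x = torch.rand(5, 8)
    y = torch.rand(5, 1)

    # Keep the value of the state, but drop its link to earlier iterations.
    next_pos = next_pos.detach()
    next_pos = next_pos + mlp(x)

    loss = torch.sum((next_pos - y) ** 2)
    loss.backward()  # retain_graph is not needed: the graph is one iteration deep
    opt.step()
    losses.append(loss.item())
```

The trade-off is that gradients no longer flow through earlier steps, so each update only sees the effect of the current delta on the loss.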

Here is a snippet that makes the error more obvious:

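import torch

# dummy forward pass (the autograd engine records it)
a = torch.tensor(10., requires_grad=True)
b = a * 3
c = b - 1.5

# the first backward pass works
c.backward()

# this would work, because b.detach() is no longer part of the old graph
# (a + b.detach()).backward()

# error: b belongs to the graph that the first backward() already freed
(a + b).backward()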
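If you want to see both outcomes in one run without the script stopping, here is a small variant that catches the error. Note that I swapped `a * 3` for `torch.exp(a * 0.1)`; that is my own change, made because exp is an op I am sure saves a tensor for its backward pass, which is the buffer that gets freed. The flag and variable names are only illustrative.

```python
import torch

a = torch.tensor(10.0, requires_grad=True)
b = torch.exp(a * 0.1)  # exp keeps its output around for the backward pass
c = b - 1.5

# First backward pass: works, and frees the saved buffers of this graph.
c.backward()

# Second backward pass through the same (freed) part of the graph: fails.
second_backward_failed = False
try:
    (a + b).backward()
except RuntimeError:
    second_backward_failed = True

# With b detached, only the fresh `a + constant` graph is traversed.
a.grad = None  # reset whatever was accumulated so far
(a + b.detach()).backward()
grad_after_detach = a.grad.item()  # d(a + const)/da
```

Passing retain_graph=True to the first backward() would also avoid the error here, at the cost of keeping those buffers in memory.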