Inplace operation error when working with lists

Hello.

While working on a project involving an RNN, I've run into the following error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Here is a minimum working example to reproduce the error:

# Setup
import torch
torch.autograd.set_detect_anomaly(True)

# Initialization
input_dim = 8
model = torch.nn.Linear(input_dim, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
values = list()
forward_input = torch.rand(1, input_dim)

# Forward
n_iters = 4
for i in range(n_iters):
    value = model(forward_input)  
    values.append(value)

# First iteration works
value = values[0].clone()
loss = value.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Second iteration fails
value = values[1].clone()
loss = value.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

(PS: I know this is a silly example, but it was the simplest way I could reproduce the error :slight_smile: )

Why does this happen? Any help would be GREATLY appreciated. Here is the full traceback:

RuntimeError                              Traceback (most recent call last)
<ipython-input-15-70245d07c321> in <module>
     28 loss = value.sum()
     29 optimizer.zero_grad()
---> 30 loss.backward()
     31 optimizer.step()

~/projects/summarization/summarization-project/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

~/projects/summarization/summarization-project/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [8, 1]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Environment details:

  • torch==1.5.1
  • torchvision==0.6.1
  • python version 3.7.7

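The error occurs because the weight of the linear layer is modified in place by optimizer.step() after the graphs for all four entries of values have already been built. When you call backward() on values[1], autograd still expects the weight exactly as it was during that forward pass, and its version counter shows that it has changed since.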
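To make the "version 2; expected version 1" part of the message concrete, here is a small sketch. It reads model.weight._version, which is an internal (underscore-prefixed) counter that autograd bumps on every in-place operation, so treat it as a debugging aid rather than a stable API. The second half shows one common way around the problem, assuming you want one optimizer step per value: run the forward pass again after each step instead of reusing outputs computed with older weights.

```python
import torch

input_dim = 8
model = torch.nn.Linear(input_dim, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
forward_input = torch.rand(1, input_dim)

# _version is internal: it increases whenever the tensor is modified in place.
version_before = model.weight._version
loss = model(forward_input).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()  # Adam updates model.weight in place, bumping its version
version_after = model.weight._version
print("weight version:", version_before, "->", version_after)

# Workaround: rebuild the graph on every iteration, so each backward()
# only sees saved tensors that match the current weights.
losses = []
for i in range(4):
    value = model(forward_input)
    loss = value.sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(losses)
```

If you instead want all four values to contribute to a single update, you could sum their losses and call backward() once before optimizer.step(); either way, no backward() should run through a graph whose weights were stepped afterwards.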
One might add that here, the gradient computation for addmm (the operation behind linear) behaves as if it also needed the derivative with respect to the input, and for that it saves the weight (the transposed [8, 1] tensor, output of TBackward, named in the error message), even though in this trivial use the input does not require a gradient. If you have multiple layers, every layer but the first does need the input gradient, so you genuinely cannot change the weight and then compute the (correct) backward, not even with retain_graph.
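A minimal sketch of the multi-layer case, using a hypothetical two-layer torch.nn.Sequential (the layer sizes are arbitrary). The second layer's input requires a gradient, so its saved weight is genuinely needed; retain_graph=True keeps the graph alive, but the second backward() still fails because optimizer.step() has modified that weight in place.

```python
import torch

torch.manual_seed(0)
# Hypothetical two-layer model: the second layer's input (the first layer's
# output) requires grad, so its backward genuinely needs the saved weight.
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.Linear(4, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x = torch.rand(1, 8)

loss = model(x).sum()
optimizer.zero_grad()
loss.backward(retain_graph=True)  # keep saved tensors for a second backward
optimizer.step()                  # modifies both weights in place

raised = False
try:
    # The graph still exists, but the weight saved by the second layer is stale.
    loss.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

Keeping the graph around only preserves references to the saved tensors; it cannot restore their old values, which is why recomputing the forward pass is the reliable fix.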

Best regards

Thomas