"Variable modified by an inplace operation" error despite using a non-inplace operation

Hi, the following piece of code:

import torch
x=torch.tensor([-1.2, 0.6], requires_grad=True)
y=1+x
y[0]=torch.abs(y[0])
y.sum().backward()

gives an error that says:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor []], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead.

However, if you change the fourth line to y[0]=torch.abs_(y[0]), everything works again. But according to the docs, abs_ is the inplace version of abs. How come it works while abs doesn’t?

Let’s look into that step by step.

You should know what the torch.abs operation returns.

a1 = torch.rand(2)
a1.requires_grad_(True)
b1 = a1 + 1
c1 = torch.abs(b1[0])
print(b1[0].data_ptr(), c1.data_ptr()) # They are different

a2 = torch.rand(2)
a2.requires_grad_(True)
b2 = a2 + 1
c2 = torch.abs_(b2[0])
print(b2[0].data_ptr(), c2.data_ptr()) # They are the same

Copy the above code and see how the two printed values differ.
The in-place version returns exactly the same tensor (same value, same address).
Consequently, c2 = torch.abs_(b2[0]) is effectively ignored, because both point to the same data (address).
So, c2 = ~~~ (in-place operation) doesn’t occur at all.
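A minimal sketch, reusing the a1/b1/c1 and a2/b2/c2 names from the snippet above, that checks two facts directly: torch.abs_() writes into its input's memory, and doing so bumps the version counter that b2 shares with its view b2[0]. That version counter is what autograd compares against when it raises the "modified by an inplace operation" error.

```python
import torch

a2 = torch.rand(2, requires_grad=True)
b2 = a2 + 1
print(b2._version)                          # 0: b2 has not been modified yet

view = b2[0]                                # shares storage and version counter with b2
c2 = torch.abs_(view)
print(c2.data_ptr() == view.data_ptr())     # True: abs_ wrote into the view's own memory
print(b2._version)                          # 1: modifying the view counts as modifying b2

a1 = torch.rand(2, requires_grad=True)
b1 = a1 + 1
c1 = torch.abs(b1[0])
print(c1.data_ptr() == b1[0].data_ptr())    # False: abs allocates a new tensor
print(b1._version)                          # 0: b1 is untouched
```

So the in-place call is not skipped; it runs and is recorded through the version counter.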

Thanks for your reply. Is there a way to avoid such a “variable modified by an inplace operation” error? I encountered this error while implementing YOLO v1; recall that it takes the sqrt of the width and height when calculating the loss:

import torch
x=torch.randn(2,4, requires_grad=True) * 1
x[:,2:]=torch.sqrt(torch.abs(x[:,2:])+1e-8) * torch.sign(x[:,2:]) # take sqrt of w and h
y=x.sum()
y.backward()

This results in an error. What would be a good way to solve it?

How about separating them?

a = x[:, :2]
b = x[:, 2:]
b = ~~~
c = torch.cat((a, b), dim=1)   # torch.cat takes a sequence of tensors
c = c.sum()
c.backward()

I’m not sure it works
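A runnable sketch of the "separate them" idea, assuming the ~~~ placeholder stands for the signed-sqrt expression from the YOLO snippet (the helper name sqrt_wh is hypothetical). The point is that the transformed columns go into a new tensor built with torch.cat instead of being written back into x, so the input that abs() saved for backward is never modified. The second variant, writing into a clone of x, is an alternative not proposed in this thread.

```python
import torch

def sqrt_wh(t):
    # hypothetical helper: signed sqrt applied to the w/h columns, as in the YOLO snippet
    return torch.sqrt(torch.abs(t) + 1e-8) * torch.sign(t)

x = torch.randn(2, 4, requires_grad=True)

# Variant 1: split, transform, and concatenate into a brand-new tensor.
a = x[:, :2]                      # first two columns, left unchanged
b = sqrt_wh(x[:, 2:])             # transformed w and h columns
c = torch.cat((a, b), dim=1)
c.sum().backward()                # works: x itself is never modified in place
grad_cat = x.grad.clone()

# Variant 2: write into a clone of x, reading the inputs from the unmodified x.
x.grad = None
out = x.clone()
out[:, 2:] = sqrt_wh(x[:, 2:])    # modifies out in place; abs() saved a view of x, not of out
out.sum().backward()
grad_clone = x.grad.clone()

print(torch.allclose(grad_cat, grad_clone))   # both variants give the same gradient
```

Variant 1 avoids in-place writes entirely; variant 2 keeps the slice-assignment style but moves the write onto a tensor that nothing in the graph saved.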

Hi Sam (and Suho)!

This requires a few words of explanation.

The short story is that abs() requires an unmodified copy of its input
tensor in order to compute its gradient during backward() (as does
abs_()). abs() relies on its input still being available in unmodified
form during backward, while abs_() knows that it is modifying its input
inplace so it stores an unmodified copy of its input in the computation
graph for use during backward().

When you switch from abs() to abs_(), you cause an unmodified copy
of y[0] to be stored, thus fixing the inplace-modification error.

(Reminder: Modifying a tensor inplace is not necessarily an error. It
depends on whether that tensor is needed during the backward pass.)
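To illustrate the reminder, here is a small sketch contrasting two ops applied to the same intermediate tensor z: addition does not save z for backward, so modifying z in place afterwards is harmless, while z * z does save z, so the same in-place change breaks backward().

```python
import torch

x = torch.tensor([-1.2, 0.6], requires_grad=True)

# Case 1: z + 5 does not save z, so changing z in place afterwards is harmless.
z = x + 1
w = z + 5
z.add_(1)                       # bumps z._version, but no saved tensor refers to z
w.sum().backward()
grad1 = x.grad.clone()
print(grad1)                    # tensor([1., 1.])

# Case 2: z * z saves z (its gradient is 2 * z), so the same in-place change breaks backward().
x.grad = None
z = x + 1
w = z * z
z.add_(1)
try:
    w.sum().backward()
    failed = False
except RuntimeError as err:
    failed = "modified by an inplace operation" in str(err)
print(failed)                   # True
```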

A tweaked version of your example starts to show what is going on:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> # using abs_ without assignment
>>> x = torch.tensor ([-1.2, -2.6], requires_grad = True)
>>> y = x + 1
>>> y, y._version, y.data_ptr()   # initial, unmodified version of y
(tensor([-0.2000, -1.6000], grad_fn=<AddBackward0>), 0, 1833055763904)
>>> torch.abs_ (y[0])
tensor(0.2000, grad_fn=<AsStridedBackward0>)
>>> y *= 1
>>> y, y._version, y.data_ptr()   # y has been modified inplace twice
(tensor([ 0.2000, -1.6000], grad_fn=<MulBackward0>), 2, 1833055763904)
>>> y.sum().backward()            # backward() works
>>> x.grad
tensor([-1.,  1.])
>>>
>>> # using abs_ with assignment
>>> x = torch.tensor ([-1.2, -2.6], requires_grad = True)
>>> y = x + 1
>>> y, y._version, y.data_ptr()   # initial, unmodified version of y
(tensor([-0.2000, -1.6000], grad_fn=<AddBackward0>), 0, 1833055761472)
>>> y[0] = torch.abs_ (y[0])
>>> y, y._version, y.data_ptr()   # y has been modified inplace twice
(tensor([ 0.2000, -1.6000], grad_fn=<CopySlices>), 2, 1833055761472)
>>> y.sum().backward()            # backward() works
>>> x.grad
tensor([-1.,  1.])
>>>
>>> # using abs with assignment
>>> x = torch.tensor ([-1.2, -2.6], requires_grad = True)
>>> y = x + 1
>>> y, y._version, y.data_ptr()   # initial, unmodified version of y
(tensor([-0.2000, -1.6000], grad_fn=<AddBackward0>), 0, 1833055763200)
>>> y[0] = torch.abs (y[0])
>>> y, y._version, y.data_ptr()   # y has been modified inplace once
(tensor([ 0.2000, -1.6000], grad_fn=<CopySlices>), 1, 1833055763200)
>>> y.sum().backward()            # and backward() fails
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor []], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
>>> x.grad                        # grad is None because backward() failed
>>> 

It’s easier to see what’s going on if we look at abs() and abs_() without
also assigning into an indexed view of y, as in y[0] = torch.abs (y[0]).

The following example shows how an explicit inplace modification of the
input to abs() breaks backward(), but switching to abs_() fixes things:

>>> # modify input to abs_() inplace separately from calling abs_()
>>> x = torch.tensor ([-1.2, 0.6], requires_grad = True)
>>> z = x + 1
>>> y = z.abs_()                   # modifies z inplace, so it stores a copy of z
>>> y
tensor([0.2000, 1.6000], grad_fn=<AbsBackward0>)
>>> z *= 1                         # modifies z again, but its copy remains unmodified for backward()
>>> z._version                     # z modified twice
2
>>> y.sum().backward()             # backward() works
>>> x.grad
tensor([-1.,  1.])
>>>
>>> # modify input to abs() inplace separately from calling abs()
>>> x = torch.tensor ([-1.2, 0.6], requires_grad = True)
>>> z = x + 1
>>> y = z.abs()                    # doesn't modify z -- no copy stored
>>> y
tensor([0.2000, 1.6000], grad_fn=<AbsBackward0>)
>>> z *= 1                         # modifies z, but no separate copy, so breaks backward() through abs()
>>> z._version                     # (z modified once)
1
>>> y.sum().backward()             # backward() fails
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2]], which is output 0 of MulBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Best.

K. Frank

Good explanation @KFrank, thanks

Hi @KFrank, thanks for your detailed explanation; that makes it really clear! Upon some testing, I still have some doubts. Can you please share some insights?

From what I understand, during backward it needs to know where in the original data x abs changed the sign, so that it can invert the sign of the incoming gradient at the same positions, i.e. x.grad += incoming_gradient[x<0]*-1. If this is the case, what about other functions that also require their input data for the gradient calculation, such as sqrt (df/dx=0.5/sqrt(x))?

import torch
x_init=torch.tensor([1., 4], requires_grad=True)
x=x_init+0
y=torch.sqrt(x)
x[0]=0
y.sum().backward()
print(f'x version: {x._version}')
print(f'y version: {y._version}')
print(f'x grad: {x_init.grad}')

The above code works and gives the correct gradient for x_init. However, if you change the line y=torch.sqrt(x) to y=torch.abs(x) (or y=torch.log(x), etc.), it gives the “modified by an inplace operation” error. What happens here? Thanks!

Hi Sam!

The point here is that sqrt() saves its output, rather than its input, for
use in the backward pass. (sqrt (x) could save its input, x, but then
it would have to recompute sqrt (x) from x in order to compute its
gradient. To be a little more efficient, it stores its output, sqrt (x),
instead.)
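The two backward formulas under discussion can be checked against autograd in a few lines. This sketch shows that sqrt's gradient is recoverable from its output alone (0.5 / y), whereas abs's gradient, grad * sign(x), needs the sign of the input, which cannot be read off |x|. That is why abs() has to keep its input around.

```python
import torch

g = torch.tensor([0.1, 0.2, 0.3])         # an arbitrary incoming gradient

# sqrt: d sqrt(x)/dx = 0.5 / sqrt(x) = 0.5 / y, so the output y is enough.
x = torch.tensor([1.0, 4.0, 9.0], requires_grad=True)
y = torch.sqrt(x)
y.backward(g)
print(torch.allclose(x.grad, g * 0.5 / y.detach()))           # True

# abs: d|x|/dx = sign(x), which depends on the input, not the output.
x2 = torch.tensor([-1.2, 0.6, 2.0], requires_grad=True)
y2 = torch.abs(x2)
y2.backward(g)
print(torch.allclose(x2.grad, g * torch.sign(x2.detach())))   # True
```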

Try modifying sqrt (x)'s output, y, inplace, and see what happens.
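Here is a sketch of that experiment, adapted from Sam's snippet (the helper names backward_ok, change_input, and change_output are hypothetical). Modifying sqrt's input x is harmless, but modifying its saved output y breaks backward(); with abs() it is modifying the input that breaks it.

```python
import torch

def backward_ok(fn, modify):
    # Build x -> fn(x), apply an in-place change, and report whether backward() succeeds.
    x_init = torch.tensor([1.0, 4.0], requires_grad=True)
    x = x_init + 0
    y = fn(x)
    modify(x, y)
    try:
        y.sum().backward()
        return True
    except RuntimeError:
        return False

def change_input(x, y):
    x[0] = 0          # in-place write into the function's input

def change_output(x, y):
    y[0] = 0          # in-place write into the function's output

print(backward_ok(torch.sqrt, change_input))    # True: sqrt saved its output, not x
print(backward_ok(torch.sqrt, change_output))   # False: the saved output was modified
print(backward_ok(torch.abs, change_input))     # False: abs saved its input x
```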

Here is a post that goes through a few more examples of this sort of
thing:

Best.

K. Frank


Hi Frank, one last question: if abs_() stores a copy of its input in the computation graph, doesn’t that make it no different from its out-of-place version memory-wise? For example, for y = abs(x) there will be two tensors, x and y, in the computation graph, whereas for abs_(x) there is only one tensor, x, but you still need to save the version-0 value of x for backward. So, what’s the point of such inplace ops?

From my understanding, there are only a few functions that can really save memory by using their inplace versions, such as sqrt, exp, tan, tanh, sigmoid, ReLU, etc., because their derivatives can be expressed using the output only. The only drawback is that you can’t do further inplace ops on that output, because that would change the saved output. Am I correct?

Thanks again for your time!
Regards,
Sam
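Sam's premise that these derivatives depend only on the output can be checked numerically. The sketch below compares autograd's gradient for each listed function with a formula written purely in terms of the output y; it checks the math only and says nothing about which tensors PyTorch actually chooses to save internally. Inputs are kept positive and below pi/2 so sqrt, tan, and ReLU are all well defined and differentiable there.

```python
import torch

# function -> (op, derivative expressed only in terms of the output y)
cases = {
    "sqrt":    (torch.sqrt,    lambda y: 0.5 / y),
    "exp":     (torch.exp,     lambda y: y),
    "tan":     (torch.tan,     lambda y: 1 + y * y),
    "tanh":    (torch.tanh,    lambda y: 1 - y * y),
    "sigmoid": (torch.sigmoid, lambda y: y * (1 - y)),
    "relu":    (torch.relu,    lambda y: (y > 0).to(y.dtype)),
}

results = {}
for name, (fn, grad_from_output) in cases.items():
    x = torch.tensor([0.3, 0.7, 1.1], requires_grad=True)
    y = fn(x)
    y.sum().backward()
    results[name] = torch.allclose(x.grad, grad_from_output(y.detach()))
    print(name, results[name])
```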

Indeed, the memory savings from in-place operations are less pronounced when intermediate results are needed for backward anyway. This also applies, to some extent, to the case where an in-place operation saves its output instead, since that prevents you from doing another in-place operation on that output.

In general, though, in-place operations are still useful in two ways: (1) memory savings for inference, and (2) supporting the common pattern of modifying a small slice of a large tensor. (There should be memory savings here, since in-place only needs to clone the small slice, whereas an out-of-place approach would need to copy the entire tensor.)
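Both points can be illustrated with a short sketch; the shapes and the choice of relu_() and exp_() are illustrative assumptions. Under torch.no_grad() (the inference case), an in-place op reuses the existing storage while its out-of-place counterpart allocates a new tensor. With autograd on, the slice pattern applies an in-place op to one column of a larger intermediate tensor; exp_() is used because it saves its output, which is not modified again, so backward() still works.

```python
import torch

# (1) Inference: in-place ops write into existing memory instead of allocating.
with torch.no_grad():
    t = torch.randn(1000)
    ptr = t.data_ptr()
    t.relu_()
    same_storage = t.data_ptr() == ptr     # in-place result lives in the same memory
    u = t.relu()
    new_storage = u.data_ptr() != ptr      # out-of-place result lives elsewhere

# (2) Training: modify one small slice of a larger intermediate tensor in place.
x = torch.randn(4, 3, requires_grad=True)
y = x.clone()                   # non-leaf intermediate, safe to modify in place
y[:, :1].exp_()                 # exp_ saves its output, which is not modified again
y.sum().backward()

expected = torch.ones(4, 3)
expected[:, :1] = torch.exp(x.detach()[:, :1])   # d exp(x)/dx = exp(x) on the slice
print(same_storage, new_storage, torch.allclose(x.grad, expected))
```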
