When inplace operations are allowed and when not

Hi Jie (and Srishti)!

The part that can be confusing is that a given tensor used in the forward
pass may or may not be used in the backward pass. It depends on the math
of the gradient computation (and the details of how the relevant backward()
function is implemented). It’s not always obvious.

Here are some explanatory examples:

The derivative of exp (x) is exp (x). So it makes sense that pytorch saves
the output of exp (x) when it is computed during the forward pass so that it
can reuse it – rather than recompute it – during the backward pass. If you
modify the output of exp (x) (inplace), you will trigger an inplace error.

On the other hand, the input to exp (x) (that is, x itself) isn’t of any use in
computing the derivative of exp (x) (assuming you have saved exp (x)),
so pytorch doesn’t save x and you can modify it.

Conversely, the derivative of x**2 is 2 * x, so pytorch saves the input
to the x**2 computation, but not its output, for use in the backward pass.
Modifying x itself will now trigger an inplace error, but modifying the output
of the x**2 computation won’t.

Lastly, the derivative of 10 * x is just the scalar 10. Neither the (tensor)
input nor output of the 10 * x computation is useful for computing the
derivative during the backward pass, so pytorch saves neither, and you
can modify either one without triggering an inplace error.
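
If you want to see directly which tensors an operation saved for its backward
pass, recent versions of pytorch expose them through _saved_* attributes on the
grad_fn of the result. These names are internal and version-dependent (the
documentation calls them an implementation detail), so treat this as a sketch,
not a stable API:

>>> import torch
>>> x = torch.arange (3.0, requires_grad = True)
>>> y = x.exp()
>>> torch.equal (y.grad_fn._saved_result, y)   # exp() saved its output
True
>>> z = x**2
>>> torch.equal (z.grad_fn._saved_self, x)     # x**2 saved its input
True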

Consider:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t                             # (so it won't be a leaf variable)
>>> uexp = u.exp()                        # uexp is the gradient of u.exp()
>>> loss = uexp.sum()                     # (some dummy loss)
>>> u[0] = 666                            # inplace modification of u
>>> loss.backward (retain_graph = True)   # okay
>>> uexp[0] = 666                         # inplace modification of uexp
>>> loss.backward()                       # inplace modification error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of ExpBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
>>>
>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t                             # (so it won't be a leaf variable)
>>> usq = u**2                            # 2 * u is the gradient of u**2
>>> loss = usq.sum()                      # (some dummy loss)
>>> usq[0] = 666                          # inplace modification of usq
>>> loss.backward (retain_graph = True)   # okay
>>> u[0] = 666                            # inplace modification of u
>>> loss.backward()                       # inplace modification error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of struct torch::autograd::CopySlices, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
>>>
>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t                             # (so it won't be a leaf variable)
>>> uten = 10 * u                         # gradient of 10 * u is just 10
>>> loss = uten.sum()                     # (some dummy loss)
>>> u[0] = 666                            # inplace modification of u
>>> loss.backward (retain_graph = True)   # okay
>>> uten[0] = 666                         # inplace modification of uten
>>> loss.backward()                       # also okay
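
How does autograd catch this? Every tensor carries a version counter that is
bumped by each inplace operation, and a saved tensor remembers the version it
had when it was saved. backward() raises the error above when the two no longer
match, which is what the "is at version 1; expected version 0" part of the
message means. _version is an internal attribute, so the following is just an
illustrative sketch:

>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t
>>> uexp = u.exp()
>>> uexp._version                         # version when exp() saved its output
0
>>> uexp[0] = 666                         # inplace modification bumps the counter
>>> uexp._version
1

And, as the hint in the error messages says, you can enable anomaly detection
with torch.autograd.set_detect_anomaly (True) to locate the forward-pass
operation whose saved tensor was modified.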

Best.

K. Frank
