When are in-place operations allowed and when not?

I’m trying to understand how shortcuts work in neural network models (e.g., ResNet models) implemented using PyTorch.

I’m trying the following ResNet block, found in a GitHub repository:

import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1,
                 shortcut_enabled=True, device='cuda'):
        super(BasicBlock, self).__init__()

        self.shortcut_enabled = shortcut_enabled

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False,
                               device=device)
        self.bn1 = nn.BatchNorm2d(planes, device=device)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False,
                               device=device)
        self.bn2 = nn.BatchNorm2d(planes, device=device)

        if self.shortcut_enabled:
            # identity shortcut by default; 1x1 conv + BN when shapes differ
            self.shortcut = nn.Sequential()
            if stride != 1 or in_planes != self.expansion * planes:
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes,
                              kernel_size=1, stride=stride, bias=False,
                              device=device),
                    nn.BatchNorm2d(self.expansion * planes, device=device),
                )

    def forward(self, x, verbose=0):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.shortcut_enabled:
            out += self.shortcut(x)   # in-place addition of the shortcut branch
        out = F.relu(out)
        return out

This block of code runs (in both training and evaluation mode) without any problem. However, since it performs an in-place operation on a tensor, i.e. out += self.shortcut(x), I expected an error like one of the variables needed for gradient computation has been modified by an inplace operation, as in other similar models that I have tested in the past.
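For reference, here is a minimal sanity check (a sketch with arbitrary shapes, run on the CPU) showing that both the forward and the backward pass go through:

import torch

block = BasicBlock(in_planes=16, planes=32, stride=2, device='cpu')
x = torch.randn(4, 16, 8, 8)
out = block(x)
out.sum().backward()   # no in-place error despite "out += self.shortcut(x)"
print(out.shape)       # torch.Size([4, 32, 4, 4])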

I would like to understand when an in-place operation is allowed in PyTorch and when it is not.
Why can this ResNet block use an in-place operation, and why does it work during training?

Hi,
In-place operations produce the error you mentioned when a variable that is used in the computation graph (and hence will be involved in the backward call for gradient computation) is modified in place, so that by the time backward is called its value has changed, i.e. it is at a different version from when the variable was used in the graph during the forward pass.

Since the value has changed, the gradients would be computed incorrectly, so PyTorch flags this as an error once the backward pass is started.

TL;DR: if there is a mismatch between the tensor’s version when it was used to construct the computation graph and its version when backward is called, an error is produced.

In your case, if there’s no such error, there must be no conflicting tensor versions and you should be fine using the in-place ops.
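To see the version check directly, here is a minimal sketch using the internal _version attribute (the same attribute used in the examples further down this thread; the names a, b, loss are just for illustration):

import torch

a = torch.ones(3, requires_grad=True)
b = a.exp()                  # exp's backward needs its output, so autograd saves b
loss = b.sum()
print(b._version)            # 0 -- the version recorded when b was saved for backward
b.add_(1.0)                  # in-place op on b bumps its version counter
print(b._version)            # 1 -- no longer matches the saved version
try:
    loss.backward()
except RuntimeError as e:
    print(e)                 # "... modified by an inplace operation ..."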

Hi Jie (and Srishti)!

The part that can be confusing is that a given tensor used in the forward
pass may or may not be used in the backward pass. It depends on the math
of the gradient computation (and the details of how the relevant backward()
function is implemented). It’s not always obvious.

Here are some explanatory examples:

The derivative of exp (x) is exp (x). So it makes sense that pytorch saves
the output of exp (x) when it is computed during the forward pass so that it
can reuse it – rather than recompute it – during the backward pass. If you
modify the output of exp (x) (inplace), you will trigger an inplace error.

On the other hand, the input to exp (x) (that is, x itself) isn’t of any use in
computing the derivative of exp (x) (assuming you have saved exp (x)),
so pytorch doesn’t save x and you can modify it.

Conversely, the derivative of x**2 is 2 * x, so pytorch saves the input
to the x**2 computation, but not its output, for use in the backward pass.
Modifying x itself will now trigger an inplace error, but modifying the output
of the x**2 computation won’t.

Lastly, the derivative of 10 * x is just the scalar 10. Neither the (tensor)
input nor output of the 10 * x computation is useful for computing the
derivative during the backward pass, so pytorch saves neither, and you
can modify either one without triggering an inplace error.

Consider:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t                             # (so it won't be a leaf variable)
>>> uexp = u.exp()                        # uexp is the gradient of u.exp()
>>> loss = uexp.sum()                     # (some dummy loss)
>>> u[0] = 666                            # inplace modification of u
>>> loss.backward (retain_graph = True)   # okay
>>> uexp[0] = 666                         # inplace modification of uexp
>>> loss.backward()                       # inplace modification error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of ExpBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
>>>
>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t                             # (so it won't be a leaf variable)
>>> usq = u**2                            # 2 * u is the gradient of u**2
>>> loss = usq.sum()                      # (some dummy loss)
>>> usq[0] = 666                          # inplace modification of usq
>>> loss.backward (retain_graph = True)   # okay
>>> u[0] = 666                            # inplace modification of u
>>> loss.backward()                       # inplace modification error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of struct torch::autograd::CopySlices, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
>>>
>>> t = torch.arange (3.0, requires_grad = True)
>>> u = 2 * t                             # (so it won't be a leaf variable)
>>> uten = 10 * u                         # gradient of 10 * u is just 10
>>> loss = uten.sum()                     # (some dummy loss)
>>> u[0] = 666                            # inplace modification of u
>>> loss.backward (retain_graph = True)   # okay
>>> uten[0] = 666                         # inplace modification of uten
>>> loss.backward()                       # also okay

Best.

K. Frank

Very clear explanation. Thank you!
In my case, out += self.shortcut(x) is like 10 * x: the gradient of an addition is constant, so PyTorch saves neither its input nor its output. Also, since BatchNorm saves its input rather than its output, I can change its output in place.
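To check that reasoning, here is a small sketch (not from the original thread): modifying the output of a BatchNorm layer in place, as the residual addition does, backpropagates without an error, because BatchNorm's backward uses its saved input and batch statistics rather than its output.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
x = torch.randn(2, 4, 8, 8, requires_grad=True)
identity = torch.randn(2, 4, 8, 8)

out = bn(x)           # BatchNorm saves its input and batch statistics for backward
out += identity       # in-place add on bn's output, like the residual shortcut
out.sum().backward()  # no in-place error -- bn's backward doesn't need its output
print(x.grad.shape)   # torch.Size([2, 4, 8, 8])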

In the pixelCNN code (see for example pixelcnn-pytorch/masked_cnn_layer.py at master in the axeloh/pixelcnn-pytorch repository on GitHub), one can multiply a network parameter (which is a leaf node with requires_grad=True) by a mask tensor and still not get any error. How is this possible?

Are you referring to

class MaskedConv2d(...):
    def __init__(self, ...):
        ...
    def forward(self, x):
        self.weight.data *= self.mask

I remember reading somewhere that Tensor.data returns a new tensor that is detached from the autograd graph (while sharing the same storage), so the in-place operation is performed on that detached tensor; that’s why you don’t get RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
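To illustrate the difference, here is a short sketch (the names w and mask are just for illustration): the same in-place multiplication raises an error when applied to the parameter directly, but not when applied through .data:

import torch
import torch.nn as nn

w = nn.Parameter(torch.tensor([1., 2., 3.]))
mask = torch.tensor([True, False, True])

try:
    w *= mask        # in-place op directly on a leaf that requires grad
except RuntimeError as e:
    print(e)         # "a leaf Variable that requires grad is being used in an in-place operation."

w.data *= mask       # same op through .data -- autograd doesn't see it, so no error
print(w)             # Parameter containing: tensor([1., 0., 3.], requires_grad=True)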

Also, the version counter doesn’t increase after an in-place operation on Tensor.data:

import torch
import torch.nn as nn

x=nn.Parameter(torch.tensor([1.,2.,3.]))
print(f'version before inplace operation: {x._version}')
x.data*=torch.tensor([True,False,True])
print(f'version after inplace operation: {x._version}')
x.sum().backward()
print(f'gradient at x: {x.grad}')

#version before inplace operation: 0
#version after inplace operation: 0
#gradient at x: tensor([1., 1., 1.])

As shown above, the version counter doesn’t increase. What’s more, the mask has no effect during backpropagation, which may not be what they intended. Normally, if you apply a mask, it should also act on the gradient:

import torch
import torch.nn as nn

x=nn.Parameter(torch.tensor([1.,2.,3.]))
y=x*1 # make it non-leaf
print(f'version before inplace operation: {y._version}')
y*=torch.tensor([True,False,True])
print(f'version after inplace operation: {y._version}')
y.sum().backward()
print(f'gradient at x: {x.grad}')

#version before inplace operation: 0
#version after inplace operation: 1
#gradient at x: tensor([1., 0., 1.])
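If the goal is for the mask to also act on the gradient, one common alternative (a sketch of the general pattern, not the code from the linked repository) is to apply the mask out of place inside forward(), so autograd sees the multiplication:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # placeholder mask -- replace with the actual causal mask
        self.register_buffer('mask', torch.ones_like(self.weight))

    def forward(self, x):
        # out-of-place multiplication: the mask now zeroes the corresponding weight gradients too
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)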