Inplace Dropout causes "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [256, 256, 11, 11]], which is output 0 of ReluBackward0, is at version 2; expected version 1 instead"

Hi everyone, hope you are having a great time.
I recently ran into this issue. Back in 1.8 or 1.9 I had to train a model with dropout’s inplace=False, otherwise it would crash. Now, in 1.11 and 1.13.1, I can’t seem to flip that back to True, and it seems I’m stuck with dropout inplace=False! If I set it to True, I get this error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [256, 256, 11, 11]], which is output 0 of ReluBackward0, is at version 2; expected version 1 instead.

the full stacktrace looks like this:

/home/hossein/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in ReluBackward0. Traceback of forward call that caused the error:
  File "/home/hossein/pytorch-image-models/train.py", line 806, in <module>
    main()
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/hossein/pytorch-image-models/train.py", line 603, in main
    train_metrics = train_one_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args,
  File "/home/hossein/pytorch-image-models/train.py", line 670, in train_one_epoch
    output = model(input)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hossein/pytorch-image-models/timm/models/simplenet.py", line 210, in forward
    out = self.features(x)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 102, in forward
    return F.relu(input, inplace=self.inplace)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/nn/functional.py", line 1455, in relu
    result = torch.relu_(input)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/hossein/pytorch-image-models/train.py", line 806, in <module>
    main()
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/hossein/pytorch-image-models/train.py", line 603, in main
    train_metrics = train_one_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args,
  File "/home/hossein/pytorch-image-models/train.py", line 678, in train_one_epoch
    loss_scaler(loss, optimizer,
  File "/home/hossein/pytorch-image-models/timm/utils/cuda.py", line 43, in __call__
    self._scaler.scale(loss).backward(create_graph=create_graph)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [256, 256, 11, 11]], which is output 0 of ReluBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1752876) of binary: /home/hossein/anaconda3/bin/python3
Traceback (most recent call last):
  File "/home/hossein/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/hossein/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/launch.py", line 195, in <module>
    main()
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/launch.py", line 191, in main
    launch(args)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/launch.py", line 176, in launch
    run(args)
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 753, in run
    elastic_launch(
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 246, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-02-04_13:42:08
  host      : hossein-pc
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1752876)
  error_file: /tmp/torchelastic_40eggmop/none_e2lzr3bg/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
      return f(*args, **kwargs)
    File "/home/hossein/pytorch-image-models/train.py", line 603, in main
      train_metrics = train_one_epoch(epoch, model, loader_train, optimizer, train_loss_fn, args,
    File "/home/hossein/pytorch-image-models/train.py", line 678, in train_one_epoch
      loss_scaler(loss, optimizer,
    File "/home/hossein/pytorch-image-models/timm/utils/cuda.py", line 43, in __call__
      self._scaler.scale(loss).backward(create_graph=create_graph)
    File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
      torch.autograd.backward(
    File "/home/hossein/anaconda3/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
      Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [256, 256, 11, 11]], which is output 0 of ReluBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
  
============================================================

As for the model, it’s a plain CNN (VGG-like) with nothing fancy about it, just a good old CNN!
What’s going on, and how can I get rid of this annoyance?
Thanks a lot in advance.

The error is not an annoyance and is rightfully raised since you are using disallowed inplace operations, which would otherwise create invalid gradients.
@KFrank gives a few great examples, e.g. in this post.

Thanks a lot, please kindly correct me if I’m wrong.
My issue is that dropout being inplace shouldn’t cause this, especially given that the model is instantiated anew. If I did a forward pass with inplace=False and then set inplace to True just before the backward pass, I would understand the error, but why does it occur when the model is instantiated completely fresh?
Does PyTorch not create a pristine graph at each model instantiation?
Consider the following example:

>>> import torch
>>> import torch.nn as nn
>>> c1=nn.Conv2d(3,3,kernel_size=3)
>>> r1=nn.ReLU(inplace=True)
>>> l1=nn.Linear(12,10)
>>> f1=nn.Flatten()
>>> i = torch.randn(size=(2,3,4,4))
>>> out = l1(f1(r1(c1(i))))
>>> out.mean().backward()

Why doesn’t the inplace operation raise any errors here?

No, since there is no computation graph at model instantiation; the graph is only built during the forward pass.

Most likely because the outputs of c1 are not needed for the gradient computation as explained in the linked post.

I’m unsure if you’ve read the linked post completely, but inplace operations on activation tensors are disallowed and will raise an error if that specific activation is needed for the gradient computation.
If no error were raised in that case, the gradient would be wrong.
The model instantiation is irrelevant, since the inplace operation is executed during the forward pass, which is where the computation graph is created.
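
The "is at version 2; expected version 1" part of the error message can be made concrete with a small sketch. Each tensor carries a version counter (readable through the private attribute _version) that every in-place operation increments, and autograd compares it with the version recorded when the tensor was saved for backward. The snippet below is only an illustration on CPU with made-up shapes, not the model from the question: an in-place ReLU on a conv output is fine on its own, while a later in-place dropout on the same tensor bumps the counter again and makes backward fail.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

conv = nn.Conv2d(3, 3, kernel_size=3)   # illustrative layer, not the original model
x = torch.randn(2, 3, 4, 4)

# Case 1: in-place ReLU only.
h = conv(x)
version_after_conv = h._version          # freshly created tensor
h = torch.relu_(h)                       # in-place ReLU increments the counter
version_after_relu = h._version          # ReLU's backward saves h at this version

# Nothing touches h afterwards, so the saved tensor is still valid.
h.sum().backward()
conv_only_ok = conv.weight.grad is not None

# Case 2: in-place dropout applied to the output of the in-place ReLU.
h2 = torch.relu_(conv(x))
h2 = F.dropout(h2, p=0.5, training=True, inplace=True)
version_after_dropout = h2._version      # counter moved past the saved version

try:
    h2.sum().backward()
    dropout_failed = False
except RuntimeError as err:
    dropout_failed = "modified by an inplace operation" in str(err)

print(version_after_conv, version_after_relu, version_after_dropout)
print("relu only ok:", conv_only_ok, "| relu + inplace dropout failed:", dropout_failed)
```

In the first case the convolution does not need its own output for backward, which matches the explanation given for the small conv/relu/linear example above.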

Highlighting some important parts from the link:

The derivative of exp (x) is exp (x) . So it makes sense that pytorch saves
the output of exp (x) when it is computed during the forward pass so that it
can reuse it – rather than recompute it – during the backward pass. If you
modify the output of exp (x) (inplace), you will trigger an inplace error.

As explained above, modifying the output of exp(x) in place is disallowed and will raise an error, since this tensor is needed for the gradient computation. The post also shares an executable code snippet that reproduces this behavior with multiple examples (e.g. using other operations which need the input for their gradient computation).
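
A minimal sketch of the distinction drawn in the quoted passage (this is not the snippet from the linked post, just an illustration under the same idea): sin needs its input for backward, so its output can be modified in place, whereas exp reuses its saved output, so the same modification triggers the inplace error.

```python
import torch

x = torch.randn(5, requires_grad=True)

# sin: d/dx sin(x) = cos(x), so backward needs the input x, not the output.
y = torch.sin(x)
y.mul_(2.0)                      # in-place change of the output is harmless here
y.sum().backward()
sin_grad_ok = torch.allclose(x.grad, 2.0 * torch.cos(x.detach()))

# exp: d/dx exp(x) = exp(x), so backward reuses the saved output.
x.grad = None
z = torch.exp(x)
z.mul_(2.0)                      # overwrites the tensor that backward wants to reuse
try:
    z.sum().backward()
    exp_failed = False
except RuntimeError as err:
    exp_failed = "modified by an inplace operation" in str(err)

print("sin gradient correct:", sin_grad_ok)
print("exp raised the inplace error:", exp_failed)
```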

I’ve read the whole post. The thing is, unlike exp, which relies on its output for backprop, both relu and dropout have inplace arguments, which I take to mean that their backprop does not rely on their output.

For ReLU, if I recall correctly, the backward pass returns 1 if x>0 and 0 otherwise. Likewise for dropout: we scale all nonzero elements during the forward pass, and since the zero elements are 0 in either case, the gradient should flow just fine. So it seems to me that an inplace relu or inplace dropout should work without issue. Am I wrong in this?
The following snippet also shows that an inplace relu shouldn’t pose any issues:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = self._make_layers()
        self.classifier = nn.Linear(9,10)

    def forward(self, x):
        out = self.features(x)
        out = F.max_pool2d(out, kernel_size=out.size()[2:])
        out = F.dropout2d(out, 0.5, training=self.training)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self):
        model = nn.Sequential(nn.Conv2d(3,3,kernel_size=3),
                      nn.BatchNorm2d(3),
                      nn.ReLU(inplace=True),
                      #nn.Dropout2d(inplace=True),
                      nn.MaxPool2d(2,2),
                      
                      nn.Conv2d(3,6,kernel_size=3),
                      nn.BatchNorm2d(6),
                      nn.ReLU(inplace=True),
                      #nn.Dropout2d(inplace=True),
                      nn.MaxPool2d(2,2),
                      
                      nn.Conv2d(6,9,kernel_size=3),
                      nn.BatchNorm2d(9),
                      nn.ReLU(inplace=True),
                      #nn.Dropout2d(inplace=True),
                      nn.MaxPool2d(2,2),
                      )
        for m in model.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain("relu"))
        return model
    
if __name__ == "__main__":
    model = SimpleNet()
    input_dummy = torch.randn(size=(2, 3, 32, 32))
    out = model(input_dummy)
    out.mean().backward()

Now if you add the dropout as inplace, it crashes. In that case, I wonder why dropout offers an inplace option at all if it can’t be done inplace.

Thank you very much in advance for your time and kind patience.

That’s not the case, since ReLU uses its output for the gradient computation as defined here and as shown in this code snippet:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 10, requires_grad=True)

# works
relu = nn.ReLU()
out = relu(x)
out.mean().backward()

# out-of-place dropout still works
out = relu(x)
out = F.dropout(out, inplace=False)
out.mean().backward()

# other out-of-place ops also work
out = relu(x)
out = out + 1.
out.mean().backward()

# inplace manipulation of relu's output fails
out = relu(x)
out = F.dropout(out, inplace=True)
out.mean().backward()
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 10]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

# same as here
out = relu(x)
out.add_(1.)
out.mean().backward()
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 10]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
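
One possible workaround, which is not suggested in the thread and is offered here only as a hedged sketch: since dropout multiplies by a non-negative mask, applying it before the ReLU should give the same result as applying it after, and in that order the in-place dropout no longer overwrites the output that ReLU saves for backward. The example assumes the layer in front of the dropout (a Linear layer here, chosen for illustration) does not need its own output for backward; the layer sizes and seeds are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 10)

# (1) With the same random mask, both orders produce the same values,
#     because relu(m * x) == m * relu(x) for any mask value m >= 0.
torch.manual_seed(1)
a = F.dropout(F.relu(x), p=0.5, training=True)
torch.manual_seed(1)
b = F.relu(F.dropout(x, p=0.5, training=True))
orders_match = torch.allclose(a, b)

# (2) Dropout first, ReLU second: both can run in place, because nothing
#     modifies ReLU's saved output after it has been computed.
block = nn.Sequential(
    nn.Linear(10, 10),
    nn.Dropout(p=0.5, inplace=True),
    nn.ReLU(inplace=True),
)
out = block(x)
out.mean().backward()
grads_ok = all(p.grad is not None for p in block.parameters())

print("orders give identical outputs:", orders_match)
print("backward succeeded with in-place dropout before ReLU:", grads_ok)
```

If reordering the layers is not an option, keeping the dropout out-of-place (inplace=False), as in the working cases shown above, remains the straightforward fix.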