RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [16, 50176]], which is output 0 of AsStridedBackward0, is at version 8; expected version 7 instead. Hint: the backtrace furthe

I encountered the error in the title when trying to do multi-label classification.

The model is:

    def __init__(self, n_outputs=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=96, kernel_size=11, stride=4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten(start_dim=1, end_dim=-1)
        )
        self.patient_overall = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c1 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c2 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c3 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c4 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c5 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c6 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        self.c7 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=True),
            nn.Linear(in_features=14*14*256, out_features=n_outputs),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.net(x)
        
        return {
            'patient_overall': self.patient_overall(x),
            'c1': self.c1(x),
            'c2': self.c2(x),
            'c3': self.c3(x),
            'c4': self.c4(x),
            'c5': self.c5(x),
            'c6': self.c6(x),
            'c7': self.c7(x)
        }

I have also defined the criterion and the training loop:

def criterion(loss_func, outputs, targets):
    c_output = torch.cat(tuple([val for val in outputs.values()]), 1)
    loss = loss_func(c_output, targets)
    return loss

def training(model, device, lr_rate, epochs, train_loader):
    num_epochs = epochs
    losses = []
    checkpoint_losses = []
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate)
    n_total_steps = len(train_loader)
    model.train()
    loss_func = nn.BCELoss()
    for epoch in range(num_epochs):
        print('----- Epoch: {epoch} -----')
        for i, sampled_batch in enumerate(train_loader):
            print('--- {i}')
            inputs = sampled_batch[0]
            targets = sampled_batch[1]
            
            outputs = model(inputs)
            
            optimizer.zero_grad()
            loss = criterion(loss_func, outputs, targets)
            losses.append(loss.item())

            loss.backward()
            optimizer.step()
            
            if (i+1) % (int(n_total_steps)) == 0:
                checkpoint_loss = torch.tensor(losses).mean().item()
                checkpoint_losses.append(checkpoint_loss)
                print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {checkpoint_loss:.4f}')
                
    return checkpoint_losses

I’m also joining the full error message after setting up torch.autograd.set_detect_anomaly to True.

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py:175: UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
  File "/opt/conda/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/opt/conda/lib/python3.7/site-packages/traitlets/config/application.py", line 976, in launch_instance
    app.start()
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/opt/conda/lib/python3.7/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/opt/conda/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.7/asyncio/events.py", line 88, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
    await self.process_one()
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 499, in process_one
    await dispatch(*args)
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
    await result
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
    reply_content = await reply_content
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/ipkernel.py", line 387, in do_execute
    cell_id=cell_id,
  File "/opt/conda/lib/python3.7/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2975, in run_cell
    raw_cell, store_history, silent, shell_futures, cell_id
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3029, in _run_cell
    return runner(coro)
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner
    coro.send(None)
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3472, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/opt/conda/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3552, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_17/2261024119.py", line 1, in <module>
    checkpoint_losses = training(model, device, 0.0001, 10, rsna_dataloader)
  File "/tmp/ipykernel_17/1401717521.py", line 17, in training
    outputs = model(inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_17/3761865426.py", line 76, in forward
    'c6': self.c6(x),
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
 (Triggered internally at  ../torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_17/2261024119.py in <module>
----> 1 checkpoint_losses = training(model, device, 0.0001, 10, rsna_dataloader)

/tmp/ipykernel_17/1401717521.py in training(model, device, lr_rate, epochs, train_loader)
     21             losses.append(loss.item())
     22 
---> 23             loss.backward()
     24             optimizer.step()
     25 

/opt/conda/lib/python3.7/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
    361                 create_graph=create_graph,
    362                 inputs=inputs)
--> 363         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
    364 
    365     def register_hook(self, hook):

/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
--> 175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    176 
    177 def grad(

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [16, 50176]], which is output 0 of AsStridedBackward0, is at version 8; expected version 7 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

It seems like the error occurs only when the forward pass executes the c6 part but I don’t understand why.

I hope you can help me solve this issue.

Hi @TAHTAH98,

Can replace all of the inplace=True statements with inplace=False for all your nn.Dropout() layers?

1 Like

It works! Thank you for your help.