RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

I have a complex loss computation over some meshes (faces, vertices) that I defined as follows:

import torch

def surface(max_m = 30):
    S_tensor = torch.randn((5120, 30, 30, 30))
    ABC = torch.randn((5120, 3, 3), requires_grad=True)
    D_tensor = torch.randn((5120, 30, 30, 30), requires_grad=True)
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i + j + k) <= max_m:
                    if i == j == k == 0:
                        # S_ijk = 1
                        S_tensor[:, i, j, k] = 1
                    else:
                        S_tensor[:, i, j, k] = ABC[:,0][:,0]*S_tensor[:,i-1,j,k]+ABC[:,0][:,1]*S_tensor[:,i,j-1,k]+ABC[:,0][:,2]*S_tensor[:,i,j,k-1]+D_tensor[:,i,j,k]
    return S_tensor

s1 = surface()
s2 = surface()
loss = torch.linalg.norm(s1 - s2)
loss.backward()

I simplified the surface function down to a minimal reproducible example.
The code gives this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 4
      2 s2 = surface()
      3 loss = torch.linalg.norm(s1 - s2)
----> 4 loss.backward()

File ~/miniconda/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File ~/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5120]], which is output 0 of AsStridedBackward0, is at version 5453; expected version 5452 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Running with torch.autograd.set_detect_anomaly(True), I get the error below.
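For reference, this is how I enabled it before rerunning the cell (same surface() as above):

import torch

torch.autograd.set_detect_anomaly(True)  # autograd now records forward-pass tracebacks

s1 = surface()
s2 = surface()
loss = torch.linalg.norm(s1 - s2)
loss.backward()

which produces: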

/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
  File "/miniconda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/miniconda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/miniconda/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/miniconda/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/miniconda/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 199, in start
    self.asyncio_loop.run_forever()
  File "/miniconda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/miniconda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/miniconda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
    await self.process_one()
  File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 499, in process_one
    await dispatch(*args)
  File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
    await result
  File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
    reply_content = await reply_content
  File "/miniconda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
    res = shell.run_cell(
  File "/miniconda/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
    result = self._run_cell(
  File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
    result = runner(coro)
  File "/miniconda/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_31266/1907077233.py", line 2, in <module>
    s2 = surface()
  File "/tmp/ipykernel_31266/668971176.py", line 15, in surface
    S_tensor[:, i, j, k] = ABC[:,0][:,0]*S_tensor[:,i-1,j,k]+ABC[:,0][:,1]*S_tensor[:,i,j-1,k]+ABC[:,0][:,2]*S_tensor[:,i,j,k-1]+D_tensor[:,i,j,k]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[16], line 4
      2 s2 = surface()
      3 loss = torch.linalg.norm(s1 - s2)
----> 4 loss.backward()

  [... same torch/_tensor.py and torch/autograd/__init__.py frames as in the first traceback ...]

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5120]], which is output 0 of AsStridedBackward0, is at version 5453; expected version 5452 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Any help will be appreciated.

Thanks.

Hi @jpainam! As the error states, autograd needs a variable’s value (S_tensor here) to compute the gradient of the loss, yet you are modifying it in place, which confuses torch. We may say PyTorch is dynamic, but not dynamic enough to reinvent math :slight_smile:
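To see the mechanism in isolation, here is a tiny standalone snippet (my own toy example, not your code) that triggers the same error:

import torch

w = torch.randn(3, requires_grad=True)
v = w.clone()          # non-leaf copy we can mutate
out = v * w            # MulBackward0 saves v to compute the gradient w.r.t. w
v[0] = 0.0             # in-place write bumps v's version counter
out.sum().backward()   # RuntimeError: ... modified by an inplace operation

The multiplication saved v for the backward pass; mutating v afterwards makes that saved value stale, so autograd refuses to use it.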

A potential fix would be to clone the tensor you are currently modifying in place and write into the copy instead. This code should run:

import torch

def surface(max_m = 30):
    S_tensor = torch.randn((5120, 30, 30, 30))
    S_tensor_2 = S_tensor.clone()
    ABC = torch.randn((5120, 3, 3), requires_grad=True)
    D_tensor = torch.randn((5120, 30, 30, 30), requires_grad=True)
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i + j + k) <= max_m:
                    if i == j == k == 0:
                        # S_ijk = 1
                        S_tensor_2[:, i, j, k] = 1
                    else:
                        S_tensor_2[:, i, j, k] = (
                            ABC[:, 0, 0] * S_tensor[:, i - 1, j, k]
                            + ABC[:, 0, 1] * S_tensor[:, i, j - 1, k]
                            + ABC[:, 0, 2] * S_tensor[:, i, j, k - 1]
                            + D_tensor[:, i, j, k]
                        )
    return S_tensor_2

s1 = surface()
s2 = surface()
loss = torch.linalg.norm(s1 - s2)
loss.backward()

However, are you sure you need all those loops for the computation? You are using torch after all :wink:

I haven’t found a way to leverage PyTorch to do these computations. Any suggestions? :slight_smile:

This operation, S_tensor[:, i, j, k] = ABC[:,0][:,0]*S_tensor[:,i-1,j,k]+ABC[:,0][:,1]*S_tensor[:,i,j-1,k]+ABC[:,0][:,2]*S_tensor[:,i,j,k-1]+D_tensor[:,i,j,k], uses previously computed values of S_tensor, so cloning it the way you did doesn’t do the same thing.

The values read from S_tensor[:, i, j-1, k] or S_tensor[:, i, j, k-1] are not the updated ones.
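A 1-D analogue of what goes wrong with the upfront clone:

import torch

x = torch.tensor([1., 0., 0., 0.])
y = x.clone()
for t in range(1, 4):
    y[t] = 2 * x[t - 1]   # reads the ORIGINAL x, never the updated y
print(y)  # tensor([1., 2., 0., 0.]), but the recurrence y[t] = 2*y[t-1] gives [1., 2., 4., 8.]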

I see, the issue is that the updated values of S_tensor were not being used inside the loops. In that case, you can clone the slices at the point where you read them, on the right-hand side of the assignment:

import torch

def surface(max_m = 30):
    S_tensor = torch.randn((100, max_m, max_m, max_m))
    ABC = torch.randn((100, 3, 3), requires_grad=True)
    D_tensor = torch.randn((100, max_m, max_m, max_m), requires_grad=True)
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i + j + k) <= max_m:
                    if i == j == k == 0:
                        # S_ijk = 1
                        S_tensor[:, i, j, k] = 1
                    else:
                        # .clone() snapshots the values read here, so the in-place
                        # write no longer invalidates autograd's saved tensors
                        S_tensor[:, i, j, k] = (
                            ABC[:, 0, 0] * S_tensor[:, i - 1, j, k].clone()
                            + ABC[:, 0, 1] * S_tensor[:, i, j - 1, k].clone()
                            + ABC[:, 0, 2] * S_tensor[:, i, j, k - 1].clone()
                            + D_tensor[:, i, j, k]
                        )
    return S_tensor

s1 = surface(20)
s2 = surface(20)
loss = torch.linalg.norm(s1 - s2)
loss.backward()

I changed some of the sizes you used to get it to run quicker, but the overall code is the same.
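As for avoiding the Python loops entirely: one idea, as an untested sketch (surface_wavefront is my own name, not from your code), is to process anti-diagonal levels s = i + j + k together, since every cell at level s only reads cells at level s - 1, and tensor-index reads return copies rather than views:

import torch

def surface_wavefront(max_m=30, batch=100):
    S = torch.randn(batch, max_m, max_m, max_m)
    ABC = torch.randn(batch, 3, 3, requires_grad=True)
    D = torch.randn(batch, max_m, max_m, max_m, requires_grad=True)
    S[:, 0, 0, 0] = 1
    # group the cells by their level s = i + j + k
    levels = {}
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if 0 < i + j + k <= max_m:
                    levels.setdefault(i + j + k, []).append((i, j, k))
    a, b, c = ABC[:, 0, 0], ABC[:, 0, 1], ABC[:, 0, 2]
    for s in sorted(levels):
        ii, jj, kk = (torch.tensor(t) for t in zip(*levels[s]))
        # tensor-index reads are gathers (copies), so the in-place write
        # below does not invalidate anything saved for backward
        S[:, ii, jj, kk] = (a[:, None] * S[:, ii - 1, jj, kk]
                            + b[:, None] * S[:, ii, jj - 1, kk]
                            + c[:, None] * S[:, ii, jj, kk - 1]
                            + D[:, ii, jj, kk])
    return S

Negative indices wrap around here exactly as they do in your loop version, so the cells should be filled in the same dependency order.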

https://nieznanm.medium.com/runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-85d0d207623

The takeaway: clone the part of the tensor that is used on the right-hand side of the assignment expression, i.e. S_tensor[...].clone().
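A minimal sketch of that pattern on a toy recurrence, for anyone landing here (my own example, not from the thread above):

import torch

w = torch.randn(4, requires_grad=True)
buf = torch.zeros(5)
buf[0] = 1.0
for t in range(1, 5):
    # .clone() snapshots the value being read, so the in-place write on
    # the left does not invalidate what autograd saved for backward
    buf[t] = w[t - 1] * buf[t - 1].clone()
buf.sum().backward()   # works; without the .clone() this raises the in-place RuntimeError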