I have a complex loss computation over some meshes (faces, vertices) that I defined as follows:
import torch

def surface(max_m=30):
    S_tensor = torch.randn((5120, 30, 30, 30))
    ABC = torch.randn((5120, 3, 3), requires_grad=True)
    D_tensor = torch.randn((5120, 30, 30, 30), requires_grad=True)
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i + j + k) <= max_m:
                    if i == j == k == 0:
                        # S_ijk = 1
                        S_tensor[:, i, j, k] = 1
                    else:
                        S_tensor[:, i, j, k] = ABC[:,0][:,0]*S_tensor[:,i-1,j,k]+ABC[:,0][:,1]*S_tensor[:,i,j-1,k]+ABC[:,0][:,2]*S_tensor[:,i,j,k-1]+D_tensor[:,i,j,k]
    return S_tensor
s1 = surface()
s2 = surface()
loss = torch.linalg.norm(s1 - s2)
loss.backward()
I simplified the surface function and am providing only a minimal reproducible example here.
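For reference, the loop fills S_tensor with the recurrence

S[i, j, k] = A * S[i-1, j, k] + B * S[i, j-1, k] + C * S[i, j, k-1] + D[i, j, k]

for all indices with i + j + k <= max_m (and S[0, 0, 0] = 1), where A, B, C are the three entries of ABC[:, 0]; entries the loop never writes keep their random initial values.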
The code gives this error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[14], line 4
2 s2 = surface()
3 loss = torch.linalg.norm(s1 - s2)
----> 4 loss.backward()
File ~/miniconda/lib/python3.10/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
477 if has_torch_function_unary(self):
478 return handle_torch_function(
479 Tensor.backward,
480 (self,),
(...)
485 inputs=inputs,
486 )
--> 487 torch.autograd.backward(
488 self, gradient, retain_graph, create_graph, inputs=inputs
489 )
File ~/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
195 retain_graph = create_graph
197 # The reason we repeat same the comment below is that
198 # some Python versions print out the first line of a multi-line function
199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
201 tensors, grad_tensors_, retain_graph, create_graph, inputs,
202 allow_unreachable=True, accumulate_grad=True)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5120]], which is output 0 of AsStridedBackward0,
is at version 5453; expected version 5452 instead. Hint: enable anomaly detection to find the
operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
Following the hint, I enabled anomaly detection and re-ran the same cell:
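torch.autograd.set_detect_anomaly(True)  # record forward tracebacks for ops that fail in backward

# same computation as above
s1 = surface()
s2 = surface()
loss = torch.linalg.norm(s1 - s2)
loss.backward()

This time I get the following error, which points at the forward line that breaks the backward pass: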
/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
File "/miniconda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/miniconda/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/miniconda/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/miniconda/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
app.start()
File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 712, in start
self.io_loop.start()
File "/miniconda/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 199, in start
self.asyncio_loop.run_forever()
File "/miniconda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/miniconda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/miniconda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue
await self.process_one()
File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 499, in process_one
await dispatch(*args)
File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell
await result
File "/miniconda/lib/python3.10/site-packages/ipykernel/kernelbase.py", line 730, in execute_request
reply_content = await reply_content
File "/miniconda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 383, in do_execute
res = shell.run_cell(
File "/miniconda/lib/python3.10/site-packages/ipykernel/zmqshell.py", line 528, in run_cell
return super().run_cell(*args, **kwargs)
File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
result = self._run_cell(
File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
result = runner(coro)
File "/miniconda/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/miniconda/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_31266/1907077233.py", line 2, in <module>
s2 = surface()
File "/tmp/ipykernel_31266/668971176.py", line 15, in surface
S_tensor[:, i, j, k] = ABC[:,0][:,0]*S_tensor[:,i-1,j,k]+ABC[:,0][:,1]*S_tensor[:,i,j-1,k]+ABC[:,0][:,2]*S_tensor[:,i,j,k-1]+D_tensor[:,i,j,k]
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
This is followed by the same RuntimeError traceback as before, now ending with:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [5120]], which is output 0 of AsStridedBackward0, is at version 5453; expected version 5452 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
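If I read the anomaly trace right, the problem is the in-place assignment S_tensor[:, i, j, k] = ... inside the loop: each write bumps the version counter of a tensor that later iterations also read through the graph. Below is a sketch of the out-of-place rewrite I am considering (surface_no_inplace, the read helper and the S dict are my own names, not from my real code); it stores each computed slice in a dict and stacks everything at the end, falling back to the random initial values for entries the loop never writes (including the i-1/j-1/k-1 reads that hit index -1 when i, j or k is 0):

import torch

def surface_no_inplace(max_m=30):
    S_init = torch.randn((5120, max_m, max_m, max_m))  # plays the role of the randn-initialized S_tensor
    ABC = torch.randn((5120, 3, 3), requires_grad=True)
    D_tensor = torch.randn((5120, max_m, max_m, max_m), requires_grad=True)
    S = {}  # computed slices, keyed by (i, j, k)

    def read(i, j, k):
        # entries the loop never writes keep their random initial values,
        # as in the original code (this also covers negative-index reads)
        return S.get((i, j, k), S_init[:, i, j, k])

    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i + j + k) <= max_m:
                    if i == j == k == 0:
                        S[(0, 0, 0)] = torch.ones(5120)  # S_ijk = 1
                    else:
                        S[(i, j, k)] = (ABC[:, 0, 0] * read(i - 1, j, k)
                                        + ABC[:, 0, 1] * read(i, j - 1, k)
                                        + ABC[:, 0, 2] * read(i, j, k - 1)
                                        + D_tensor[:, i, j, k])

    # assemble the full tensor without any in-place writes
    slices = [read(i, j, k) for i in range(max_m)
              for j in range(max_m) for k in range(max_m)]
    return torch.stack(slices, dim=1).reshape(5120, max_m, max_m, max_m)

I am not sure this is the right fix, or whether it stays efficient for my real surface computation.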
Any help will be appreciated.
Thanks.