Hello,
I am trying to run the code below, but I get an in-place operation error during the loss computation. Is there a way to resolve this without cloning the "target" tensor? Maybe there is a PyTorch function to directly subtract the specific elements at the x, y, z indices?
import torch

batch_size = 64
data = (torch.rand(batch_size, 50, 100) < 0.01).float().to_sparse()
target = torch.rand(batch_size, 50, 100) - 0.5
target.requires_grad = True

x, y, z = data.indices()
losses = torch.exp(target)
losses[x, y, z] = losses[x, y, z] - data.values() * target[x, y, z]
loss = losses.sum().backward()
Error:
RuntimeError Traceback (most recent call last)
Cell In[18], line 4
2 losses = torch.exp(target)
3 losses[x, y, z] = losses[x,y,z] - data.values() * target[x, y, z]
----> 4 loss = losses.sum().backward()
File ~/miniconda3/envs/sbtt-demo/lib/python3.9/site-packages/torch/_tensor.py:488, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
478 if has_torch_function_unary(self):
479 return handle_torch_function(
480 Tensor.backward,
481 (self,),
(…)
486 inputs=inputs,
487 )
--> 488 torch.autograd.backward(
489 self, gradient, retain_graph, create_graph, inputs=inputs
490 )
File ~/miniconda3/envs/sbtt-demo/lib/python3.9/site-packages/torch/autograd/__init__.py:197, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
192 retain_graph = create_graph
194 # The reason we repeat same the comment below is that
195 # some Python versions print out the first line of a multi-line function
196 # calls in the traceback and some print out the last line
…
--> 197 Variable.execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
198 tensors, grad_tensors, retain_graph, create_graph, inputs,
199 allow_unreachable=True, accumulate_grad=True)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 50, 100]], which is output 0 of ExpBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
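For context, here is a minimal sketch of one way around the error, assuming the goal is only the summed loss. The in-place write fails because autograd saves the output of torch.exp for its backward pass; building the subtraction in a fresh tensor (which has no saved history) sidesteps that without cloning target:

```python
import torch

batch_size = 64
data = (torch.rand(batch_size, 50, 100) < 0.01).float().to_sparse()
target = torch.rand(batch_size, 50, 100) - 0.5
target.requires_grad = True

x, y, z = data.indices()

# Scatter the sparse correction into a brand-new zero tensor instead of
# mutating the output of torch.exp. Index-assigning into this fresh tensor
# is safe: no earlier operation saved it for backward.
correction = torch.zeros_like(target)
correction[x, y, z] = data.values() * target[x, y, z]

loss = (torch.exp(target) - correction).sum()
loss.backward()
```

Since the loss is just a sum here, the scatter could even be skipped entirely with loss = torch.exp(target).sum() - (data.values() * target[x, y, z]).sum(). The hint in the error message also applies: wrapping the run in torch.autograd.set_detect_anomaly(True) points at the exact in-place operation that broke the graph.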