Sorting Triggers Versioning Errors

Hello,

I would like to ask why I am getting the following error message when sorting an array of tensors with gradients enabled (under torch.enable_grad()):

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [192, 1000]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I have confirmed that the error comes from the sort operation. The sort is not in-place (I use Python's sorted()). I have also tried argsort + gather, but the same error appears. A filter operation, for example, doesn't trigger the versioning error. Could you give me some pointers on this?

Backstory: I am feeding triplets into my transformer network, and I want to select the n most relevant triplets by sorting them by their loss.

I think the solution is to do two forward passes: one that selects the triplets and a second one that constructs the autograd graph. Am I right about this?
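The two-pass idea can be sketched as follows. This is a minimal sketch under stated assumptions, not the poster's actual code: the `nn.Linear` model, the random triplet batch, `n = 8`, and the helper `per_triplet_loss` are hypothetical stand-ins for the transformer and its data. The first pass scores every triplet under `torch.no_grad()`, so no graph is built; the second pass recomputes the loss only for the selected triplets with autograd enabled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the transformer: any module mapping inputs to embeddings
model = nn.Linear(16, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical batch of 32 triplets (anchor, positive, negative)
anchors, positives, negatives = (torch.randn(32, 16) for _ in range(3))
n = 8  # hypothetical number of "relevant" (highest-loss) triplets to keep

def per_triplet_loss(a, p, neg):
    # reduction="none" keeps one loss value per triplet
    return F.triplet_margin_loss(model(a), model(p), model(neg),
                                 margin=1.0, reduction="none")

# Pass 1: score all triplets without building an autograd graph
with torch.no_grad():
    scores = per_triplet_loss(anchors, positives, negatives)
idx = torch.argsort(scores, descending=True)[:n]  # sort by loss, keep the top n

# Pass 2: recompute the loss on the selected triplets with autograd enabled
loss = per_triplet_loss(anchors[idx], positives[idx], negatives[idx]).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Selecting with argsort and indexing inside a single pass is also out-of-place, so it should work too, provided nothing modifies a tensor saved for backward (for example, the model's parameters) in place between the forward pass and backward().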

Thank you for your help. :slight_smile:

Hi Nimu!

You have to find / deduce the inplace operation that is causing your
problem and work around it.

You can – at the cost of additional storage – make a copy of the tensor
in question which you can then modify inplace, e.g.:

x = x.clone()   # assume that the original x will be used by backward()
x.copy_ (y)     # modify the cloned x inplace -- okay
...
z.backward()    # won't break

Some comments:

An inplace operation does not necessarily cause an error – it depends on
whether the tensor being modified has been stored in the computation
graph for use during backward(). Some operations make use of their
input tensors during backward(); others do not.
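The "is at version 3; expected version 2" wording in the original error refers to a per-tensor version counter: autograd records a saved tensor's version when it saves it, every inplace operation increments the counter, and backward() raises if the two no longer match. A minimal sketch, assuming a recent PyTorch; note that `_version` is an internal attribute, read here only for illustration:

```python
import torch

w = torch.ones(3, requires_grad=True)
x = 2 * w              # fresh non-leaf tensor, version counter starts at 0
v0 = x._version        # internal attribute, used only to illustrate

y = x ** 2             # pow saves its input x for backward()
x.add_(1)              # inplace op increments x's version counter
v1 = x._version

try:
    y.sum().backward() # x was saved at version 0 but is now at version 1
    failed = False
except RuntimeError as err:
    failed = "modified by an inplace operation" in str(err)
```

Had `y` been computed with an operation that does not save `x` (for example `x + 1`), the same `add_()` would not have broken backward().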

Consider:

>>> import torch
>>> print (torch.__version__)
1.10.2
>>>
>>> x0 = torch.ones (1, requires_grad = True)
>>> x1 = 2 * x0
>>> x2 = torch.exp (x1)   # backward() works after inplace operation
>>> x1.copy_ (torch.ones (1))
tensor([1.], grad_fn=<CopyBackwards>)
>>> x3 = x1 + x2
>>> x3.backward()
>>> x0.grad
tensor([14.7781])
>>>
>>> y0 = torch.ones (1, requires_grad = True)
>>> y1 = 2 * y0
>>> y2 = y1 ** 2   # inplace operation would break backward()
>>> y1 = y1.clone()   # but clone() saves the day
>>> y1.copy_ (torch.ones (1))
tensor([1.], grad_fn=<CopyBackwards>)
>>> y3 = y1 + y2
>>> y3.backward()
>>> y0.grad
tensor([8.])
>>>
>>> z0 = torch.ones (1, requires_grad = True)
>>> z1 = 2 * z0
>>> z2 = z1 ** 2   # inplace operation breaks backward()
>>> z1.copy_ (torch.ones (1))
tensor([1.], grad_fn=<CopyBackwards>)
>>> z3 = z1 + z2
>>> z3.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "<path_to_pytorch_install>\torch\autograd\__init__.py", line 156, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]], which is output 0 of struct torch::autograd::CopyBackwards, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
>>> z0.grad   # grad not populated because backward() broke

Best.

K. Frank

Thank you very much for the help and the example. :slight_smile:

I couldn’t figure out why Python's sorted() triggered the error. However, I was able to solve it by doing two forward passes.
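For anyone hitting the same puzzle: sorting by itself should not modify any tensor in place, which suggests the inplace operation happened somewhere else in the training step. A minimal check, using hypothetical per-triplet losses built from a small tensor `w`: neither Python's sorted() nor argsort + gather changes the losses' version counters (the internal `_version` attribute), and backward() through both selections succeeds.

```python
import torch

# Hypothetical per-triplet losses that belong to an autograd graph
w = torch.tensor([0.5, -2.0, 1.0, 3.0, -0.1], requires_grad=True)
losses = [(2 * w[i]) ** 2 for i in range(len(w))]
versions_before = [t._version for t in losses]

# Python's sorted() builds a new list of references; no tensor is written to
top_py = sorted(losses, key=lambda t: t.item(), reverse=True)[:2]

# argsort + gather is out-of-place as well
stacked = torch.stack(losses)
idx = torch.argsort(stacked.detach(), descending=True)[:2]
top_t = stacked.gather(0, idx)

versions_after = [t._version for t in losses]

# Backward through both selections works because no saved tensor was modified
(sum(top_py) + top_t.sum()).backward()
```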