Cast to long autograd

I want to make tensor.long() differentiable.

import torch

class CastToLong(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor: torch.Tensor):
        return tensor.long()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


torch.manual_seed(0)
x = torch.randn(3, 3)
idx = torch.tensor([0, 1], dtype=torch.float32, requires_grad=True)

idx = CastToLong.apply(idx)  # cast the float indices to long
y = x[idx]                   # select rows of x with the casted indices

y.sum().backward()

I receive the following traceback:

Traceback (most recent call last):
  File "/home/dizcza/PycharmProjects/EmbedderSDR/ignored/visible/idx_autograd.py", line 21, in <module>
    y.sum().backward()
  File "/home/dizcza/anaconda3/envs/embsdr/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/dizcza/anaconda3/envs/embsdr/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: range.second - range.first == t.size() INTERNAL ASSERT FAILED at /opt/conda/conda-bld/pytorch_1579022119164/work/torch/csrc/autograd/generated/Functions.cpp:55, please report a bug to PyTorch. inconsistent range for TensorList output

I was not sure about opening an issue on GitHub and decided to post the problem here first. Admittedly, this is not a typical use of PyTorch autograd. Nevertheless, it makes sense to me: the values after .long() are exactly the same (idx already holds integer values), so CastToLong should simply pass the gradient through unchanged.
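For context, the behavior I am after (identity gradient through a value-preserving op) is essentially the straight-through estimator. Below is a minimal sketch of that trick that sidesteps the dtype problem by keeping the output a float tensor; rounding instead of casting, and the (x.round() - x).detach() construction, are my own illustration, not part of the original code:

import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)

# Forward value equals x.round(); the detached correction term
# contributes no gradient, so backward sees the identity function.
x_rounded = x + (x.round() - x).detach()

x_rounded.sum().backward()
print(x.grad)  # tensor([1., 1.])

Note that x_rounded is still a float tensor, so it cannot be used as an index directly; indexing requires a long tensor, which is exactly where my CastToLong attempt comes in.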

Should I open an issue, or should we discuss it here first?

Hi,

You won’t be able to make a long Tensor require gradients.
And even if you trick PyTorch into doing it, no differentiable op is implemented for integer types, so you would have to reimplement everything.
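To illustrate the first point, on recent PyTorch versions even asking a long Tensor to require gradients fails immediately (a small hypothetical snippet, not from this thread):

import torch

# Integer tensors cannot participate in autograd:
idx = torch.tensor([0, 1], dtype=torch.long, requires_grad=True)
# RuntimeError: Only Tensors of floating point (and complex)
# dtype can require gradients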

Also, keep in mind that the indexing op is not differentiable with respect to the index: the output is piecewise constant in the index, so even if you forced a gradient computation, you would just get 0 everywhere, which is not very useful.
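If what you need is a differentiable way to select rows, the usual workaround is to replace hard indexing with a soft, weighted combination of rows, for example a softmax over learnable scores. A minimal sketch (the scores tensor and its shape are my own illustration, assuming you want two "soft indices" into the three rows of x):

import torch

torch.manual_seed(0)
x = torch.randn(3, 3)

# Learnable scores over the rows of x, instead of hard integer indices.
scores = torch.randn(2, 3, requires_grad=True)

# Each row of weights is a soft one-hot distribution over the 3 rows of x.
weights = torch.softmax(scores, dim=-1)

# Soft "indexing": a convex combination of the rows of x,
# analogous to x[idx] with idx of length 2.
y = weights @ x  # shape (2, 3)

y.sum().backward()
print(scores.grad.shape)  # torch.Size([2, 3])

This is fully differentiable, and you can anneal the softmax towards a hard selection if needed.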
