Why does the indices tensor have to be Long dtype?

Hi there, I am curious why the indices tensor must be Long (torch.int64) dtype, such as the index parameter of torch.gather and torch.take_along_dim. I think these functions should accept any integer dtype.
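For reference, here is a minimal sketch of the behavior being asked about: on the PyTorch versions I have seen, torch.gather rejects an int32 index with a RuntimeError, while the same call works after casting the index to int64 with .long().

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx32 = torch.tensor([[0], [1]], dtype=torch.int32)

# torch.gather expects an int64 ("Long") index tensor; an int32 index
# raises a RuntimeError on the versions I have tried.
try:
    torch.gather(x, 1, idx32)
except RuntimeError as e:
    print("int32 index rejected:", e)

# Casting the index to int64 makes the same call work.
out = torch.gather(x, 1, idx32.long())
print(out)  # gathers x[0, 0] and x[1, 1] -> tensor([[1.], [4.]])
```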

There are two parts:

  • There may actually be cases with very large tensors where you need 64 bit indexing.
  • The dtypes essentially figure into the signature (i.e. the formal parameter types) of the computational kernels, so you need to compile separate kernels for all (combinations of) argument dtypes. This is already quite an issue with the many supported floating point types, but it would grow completely out of control if we allowed arbitrary index dtypes for everything.
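The second point can be illustrated with a back-of-the-envelope count. The dtype lists below are illustrative, not PyTorch's exact dispatch set: with a single supported index dtype, each indexing op needs one kernel per value dtype, but with arbitrary index dtypes the count is multiplied by the number of index dtypes.

```python
# Hypothetical kernel-instantiation count for one indexing op
# (dtype lists are illustrative assumptions, not PyTorch's actual set).
value_dtypes = ["float16", "bfloat16", "float32", "float64",
                "int8", "int16", "int32", "int64", "bool"]

index_dtypes_today = ["int64"]                          # the status quo
index_dtypes_any = ["uint8", "int8", "int16", "int32", "int64"]

kernels_today = len(value_dtypes) * len(index_dtypes_today)
kernels_any = len(value_dtypes) * len(index_dtypes_any)

print(kernels_today)  # 9 kernels: one per value dtype
print(kernels_any)    # 45 kernels: every (value, index) dtype pair
```

Multiply that factor across every op that takes an index tensor, and the binary-size cost becomes clear.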

So even though we might prefer 32-bit indexing where we can (which we do for selected GPU kernels because it has a large speed impact, though not on CPUs in general, I believe), supporting arbitrary index dtypes would significantly grow the size of PyTorch, and we're already aching from its size as is.

Best regards


Thank you very much for your detailed explanation!