Result of indexing depends on datatypes of indexes?

I just came upon this:

$ idx_long = torch.tensor(2).type('torch.LongTensor')
$ idx_byte = torch.tensor(2).type('torch.ByteTensor')
$ torch.eye(4)[idx_long,:]
tensor([0., 0., 1., 0.])

while

$ torch.eye(4)[idx_byte,:]
tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]])

So the result of the indexing depends on the datatype of the indexing tensor. Why?

Hi Sascha!

The short story is that indexing with a ByteTensor (dtype = torch.uint8) is
(although deprecated) boolean indexing, which works differently than indexing
with positional integer indices.

Three things are going on: your ByteTensor gets cast (at least in effect) to
bool; boolean indexing treats the index as a mask rather than as a position;
and your index tensors are zero-dimensional, which makes the shapes of the
result tensors less obvious.

PyTorch did not always have boolean tensors, so byte tensors were used for
logical operations. Such use has now been deprecated.

PyTorch doesn’t document its advanced indexing directly but instead refers
to NumPy’s documentation.
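
As a rough illustration of the NumPy rules being referred to, here is a minimal sketch using zero-dimensional index arrays like the ones in your post. The variable names are only for illustration. One difference worth noting: NumPy treats a uint8 index as an ordinary integer index, so the mask interpretation of byte indices is specific to PyTorch's legacy behavior.

```python
import numpy as np

eye = np.eye(4)

# Zero-dimensional integer index: selects row 2 and drops that dimension.
row = eye[np.array(2), :]

# Zero-dimensional boolean index: acts as a mask and adds a leading
# dimension of size 1 (it would be size 0 for False).
masked = eye[np.array(True), :]

# In NumPy, uint8 is just another integer dtype, so this is positional.
as_uint8 = eye[np.array(2, dtype=np.uint8), :]

print(row.shape)       # (4,)
print(masked.shape)    # (1, 4, 4)
print(as_uint8.shape)  # (4,)
```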

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> t_byte = torch.tensor ([1, 0, 2, 0, 3], dtype = torch.uint8)
>>> t_bool = t_byte.bool()
>>> t_long = t_byte.long()
>>>
>>> t_byte
tensor([1, 0, 2, 0, 3], dtype=torch.uint8)
>>> t_bool
tensor([ True, False,  True, False,  True])
>>> t_long
tensor([1, 0, 2, 0, 3])
>>>
>>> t = torch.tensor ([11.0, 12.0, 13.0, 14.0, 15.0])
>>> t
tensor([11., 12., 13., 14., 15.])
>>>
>>> t[t_bool]   # boolean indexing -- note shape is [3]
tensor([11., 13., 15.])
>>> t[t_long]   # regular indexing -- shape is [5]
tensor([12., 11., 13., 11., 14.])
>>>
>>> t[t_byte]   # byte gets cast to bool (but byte indexing is deprecated)
<stdin>:1: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at C:\cb\pytorch_1000000000000\work\aten\src\ATen/native/IndexingUtils.h:28.)
tensor([11., 13., 15.])
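
The session above uses one-dimensional index tensors. For the zero-dimensional index tensors in your post, the following sketch (written with an explicit bool index so that it avoids the deprecated byte path) shows where the extra leading dimension comes from. I am assuming a reasonably recent PyTorch here, and the variable names are just for illustration.

```python
import torch

eye = torch.eye(4)

idx_long = torch.tensor(2)          # zero-dimensional, dtype torch.int64
idx_bool = torch.tensor(2).bool()   # zero-dimensional, value True

# Integer scalar index: picks row 2 and removes that dimension.
row = eye[idx_long, :]

# Boolean scalar index: treated as a mask, which adds a leading
# dimension of size 1 when the value is True ...
masked = eye[idx_bool, :]

# ... and a leading dimension of size 0 when the value is False.
empty = eye[torch.tensor(False), :]

print(row.shape)     # torch.Size([4])
print(masked.shape)  # torch.Size([1, 4, 4])
print(empty.shape)   # torch.Size([0, 4, 4])
```

So if you want positional indexing, convert the index with .long(); if you want a mask, convert it with .bool().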

Best.

K. Frank