Double indexing/slicing using tensors as indices

Hello, I’m having a hard time understanding double indexing with tensors as indices. For example:

```python
import torch

points = torch.randint(0, 4, (4, 5, 2))
batch_indices = torch.randint(0, 4, (4, 4, 3))
idx = torch.randint(0, 4, (4, 4, 3))

new_points = points[batch_indices, idx, :]

print('points shape: {}'.format(points.shape))
print('batch_indices shape: {}'.format(batch_indices.shape))
print('idx shape: {}'.format(idx.shape))
print('new_points shape: {}'.format(new_points.shape))
```

The printed shapes are:

```
points shape: torch.Size([4, 5, 2])
batch_indices shape: torch.Size([4, 4, 3])
idx shape: torch.Size([4, 4, 3])
new_points shape: torch.Size([4, 4, 3, 2])
```

Can someone help me find an equivalent set of simpler operations?

Hi Minh!

Let’s start by looking at an example of so-called advanced indexing
where we use two index tensors to index into a two-dimensional
tensor. This is similar to your example, except we simplify it by leaving
out the part where you use slicing (the colon, “:”) on the third dimension.

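The two index tensors have the same shape, but their shape is not
related to the shape of the tensor into which you are indexing (only their
values matter: they must be valid indices, here less than 4 for the first
dimension and less than 5 for the second). Instead, the shape of the index
tensors becomes the shape of the resulting tensor, with each element given by
points2d[batch_indices[i, j, k], idx[i, j, k]]. (More generally, the index
tensors only need to be broadcastable to a common shape, and that broadcast
shape becomes the shape of the result.)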

This is the first part of the example given below.
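To make that rule concrete, here is a small sketch with hand-picked values (the names `t`, `rows`, and `cols` are just illustrative and don't appear in your code). Because t[r, c] equals 5 * r + c, you can check the result by hand, and an explicit double loop builds the same tensor element by element:

```python
import torch

# A 4x5 tensor whose entry at (r, c) is 5 * r + c, so results are easy to check by hand.
t = torch.arange(20).reshape(4, 5)

# Two 2x2 index tensors: rows index dim 0 (values < 4), cols index dim 1 (values < 5).
rows = torch.tensor([[0, 3], [1, 2]])
cols = torch.tensor([[4, 0], [2, 2]])

picked = t[rows, cols]
print(picked.tolist())  # [[4, 15], [7, 12]]
print(picked.shape)     # torch.Size([2, 2]), the shape of rows/cols, not of t

# The same result built element by element with plain Python loops.
looped = torch.empty_like(rows)
for i in range(rows.shape[0]):
    for j in range(rows.shape[1]):
        looped[i, j] = t[rows[i, j], cols[i, j]]

print(torch.equal(picked, looped))  # True
```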

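Then, when you put the colon in for your third dimension, you are saying
that you want all of the third dimension. Advanced indexing simply
applies your two index tensors to each (two-dimensional) slice along the
third dimension, and the sliced dimension is appended after the shape of
the index tensors, so the result has shape [4, 4, 3] followed by [2], that
is, [4, 4, 3, 2].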

This is done explicitly in the second part of the example.
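If you want something that feels simpler than advanced indexing with two tensors, one equivalent (a sketch, not the only option) is to merge the first two dimensions of points into a single "row" dimension: row b * 5 + i of the reshaped (20, 2) tensor is points[b, i, :], so a single index tensor is enough. The names `flat_points`, `row_index`, and `new_points_loop` are hypothetical, chosen just for this illustration. The triple loop at the end spells out what each output element means:

```python
import torch

torch.manual_seed(2022)

points = torch.randint(0, 4, (4, 5, 2))
batch_indices = torch.randint(0, 4, (4, 4, 3))
idx = torch.randint(0, 4, (4, 4, 3))

new_points = points[batch_indices, idx, :]

# Merge dims 0 and 1: row b * n_i + i of flat_points is points[b, i, :].
n_b, n_i, n_last = points.shape
flat_points = points.reshape(n_b * n_i, n_last)
row_index = batch_indices * n_i + idx              # shape (4, 4, 3)
new_points_flat = flat_points[row_index]            # shape (4, 4, 3, 2)
print(torch.equal(new_points, new_points_flat))     # True

# Explicit loops: each (i, j, k) picks a whole length-2 slice along the last dim.
new_points_loop = torch.empty(*batch_indices.shape, n_last, dtype=points.dtype)
for i in range(batch_indices.shape[0]):
    for j in range(batch_indices.shape[1]):
        for k in range(batch_indices.shape[2]):
            new_points_loop[i, j, k] = points[batch_indices[i, j, k], idx[i, j, k], :]
print(torch.equal(new_points, new_points_loop))     # True
```

The reshape version stays vectorized, so it should be about as fast as the original indexing, while the loop version is only meant for understanding.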

Finally, you can (not that you should) build three index tensors so that
you can use “plain” advanced indexing to index directly into your
three-dimensional tensor without slicing.

This is illustrated in the third part of the example.
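The three index tensors also don't have to be materialized at full size, because PyTorch broadcasts advanced-index tensors against one another (as NumPy does). A minimal sketch, relying on that broadcasting, adds a trailing singleton dimension to batch_indices and idx and uses torch.arange(2) for the last dimension instead of expand() and stack(). The names `last` and `new_points_bcast` are only illustrative:

```python
import torch

torch.manual_seed(2022)

points = torch.randint(0, 4, (4, 5, 2))
batch_indices = torch.randint(0, 4, (4, 4, 3))
idx = torch.randint(0, 4, (4, 4, 3))

new_points = points[batch_indices, idx, :]

# Index shapes (4, 4, 3, 1), (4, 4, 3, 1) and (2,) broadcast together to (4, 4, 3, 2).
last = torch.arange(points.shape[-1])
new_points_bcast = points[batch_indices.unsqueeze(-1), idx.unsqueeze(-1), last]

print(new_points_bcast.shape)                      # torch.Size([4, 4, 3, 2])
print(torch.equal(new_points, new_points_bcast))   # True
```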

Consider:

```
>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> points = torch.randint(0, 4, (4, 5, 2))
>>> batch_indices = torch.randint(0, 4, (4, 4, 3))
>>> idx = torch.randint(0, 4, (4, 4, 3))
>>>
>>> points2d = points[..., 0]   # make a 2d tensor
>>> points2d.shape
torch.Size([4, 5])
>>>
>>> # use two index tensors to index into 2d tensor
>>> new_points2d = points2d[batch_indices, idx]   # the index tensors are 3d so new_points2d is 3d
>>> new_points2d.shape
torch.Size([4, 4, 3])
>>>
>>> # use two index tensors plus slicing to index into 3d tensor
>>> new_points = points[batch_indices, idx, :]
>>> new_points.shape
torch.Size([4, 4, 3, 2])
>>>
>>> # build new_points by indexing into slices of points
>>> s0 = points[..., 0][batch_indices, idx]
>>> s1 = points[..., 1][batch_indices, idx]
>>> new_pointsB = torch.stack ((s0, s1), -1)
>>> torch.equal (new_points, new_pointsB)
True
>>>
>>> # use three index tensors to index into 3d tensor
>>> batch_indicesC = batch_indices.unsqueeze (-1).expand (-1, -1, -1, 2)
>>> idxC = idx.unsqueeze (-1).expand (-1, -1, -1, 2)
>>> slice_idxC = torch.stack ((torch.zeros (4, 4, 3).long(), torch.ones (4, 4, 3).long()), -1)
>>> batch_indicesC.shape
torch.Size([4, 4, 3, 2])
>>> idxC.shape
torch.Size([4, 4, 3, 2])
>>> slice_idxC.shape
torch.Size([4, 4, 3, 2])
>>> new_pointsC = points[batch_indicesC, idxC, slice_idxC]
>>> torch.equal (new_points, new_pointsC)
True
```

Best.

K. Frank


Hello KFrank,

Thank you very much for taking the time to write such a detailed answer! I will spend some time reading through it and making sure I understand all of it.