Hi Minh!

Let’s start by looking at an example of so-called *advanced indexing* where we use two index tensors to index into a two-dimensional tensor. This is similar to your example, except we simplify it by leaving out the part where you use slicing (the colon, “`:`”) on the third dimension.

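The two index tensors have the same shape (more generally, they only need to be broadcastable to a common shape), but their shape is not related to the shape of the tensor into which you are indexing. Instead, the shape of the index tensors becomes the shape of the resulting tensor, and each element of the result is picked out by the corresponding pair of index values. This is the first part of the example given below.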
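To make the shape rule concrete, here is a minimal sketch with tiny tensors. The names `t`, `rows`, `cols`, and `expected` are made up for this illustration and are not from your code. It checks that `out[i, j]` is just `t[rows[i, j], cols[i, j]]`:

```python
import torch

# Hypothetical small tensors, chosen only for illustration.
t = torch.arange(12).reshape(3, 4)      # the 2d tensor being indexed, shape (3, 4)
rows = torch.tensor([[0, 2], [1, 1]])   # indices into dim 0, shape (2, 2)
cols = torch.tensor([[3, 0], [2, 2]])   # indices into dim 1, shape (2, 2)

out = t[rows, cols]                     # result takes the shape of the index tensors

# Build the same result element by element: out[i, j] = t[rows[i, j], cols[i, j]]
expected = torch.empty_like(rows)
for i in range(rows.shape[0]):
    for j in range(rows.shape[1]):
        expected[i, j] = t[rows[i, j], cols[i, j]]

print(out.shape)                   # torch.Size([2, 2])
print(torch.equal(out, expected))  # True
```

Note that the result shape `(2, 2)` comes from the index tensors, not from the `(3, 4)` shape of `t`.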

Then when you put the colon in for your third dimension, you are saying that you want all of the third dimension, and advanced indexing simply applies your two index tensors to each (two-dimensional) slice along the third dimension. This is done explicitly in the second part of the example.
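Another way to picture this, as a small sketch with made-up tensors (`t3`, `rows`, and `cols` are illustrative names, not from your code): each pair of index values selects a whole vector along the sliced dimension, so the sliced dimension is appended to the shape of the index tensors.

```python
import torch

# Hypothetical small tensors, chosen only for illustration.
t3 = torch.arange(24).reshape(3, 4, 2)  # the 3d tensor being indexed, shape (3, 4, 2)
rows = torch.tensor([[0, 2], [1, 1]])   # indices into dim 0, shape (2, 2)
cols = torch.tensor([[3, 0], [2, 2]])   # indices into dim 1, shape (2, 2)

out = t3[rows, cols, :]                 # index shape (2, 2) followed by the sliced dim (2,)
print(out.shape)                        # torch.Size([2, 2, 2])

# Each out[i, j] is the full length-2 vector t3[rows[i, j], cols[i, j]]
for i in range(rows.shape[0]):
    for j in range(rows.shape[1]):
        assert torch.equal(out[i, j], t3[rows[i, j], cols[i, j]])
```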

Finally, you can (not that you should) build three index tensors so that you can use “plain” advanced indexing to index directly into your three-dimensional tensor without slicing. This is illustrated in the third part of the example.

Consider:

```
>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> points = torch.randint(0, 4, (4, 5, 2))
>>> batch_indices = torch.randint(0, 4, (4, 4, 3))
>>> idx = torch.randint(0, 4, (4, 4, 3))
>>>
>>> points2d = points[..., 0] # make a 2d tensor
>>> points2d.shape
torch.Size([4, 5])
>>>
>>> # use two index tensors to index into 2d tensor
>>> new_points2d = points2d[batch_indices, idx] # the index tensors are 3d so new_points2d is 3d
>>> new_points2d.shape
torch.Size([4, 4, 3])
>>>
>>> # use two index tensors plus slicing to index into 3d tensor
>>> new_points = points[batch_indices, idx, :]
>>> new_points.shape
torch.Size([4, 4, 3, 2])
>>>
>>> # build new_points by indexing into slices of points
>>> s0 = points[..., 0][batch_indices, idx]
>>> s1 = points[..., 1][batch_indices, idx]
>>> new_pointsB = torch.stack ((s0, s1), -1)
>>> torch.equal (new_points, new_pointsB)
True
>>>
>>> # use three index tensors to index into 3d tensor
>>> batch_indicesC = batch_indices.unsqueeze (-1).expand (-1, -1, -1, 2)
>>> idxC = idx.unsqueeze (-1).expand (-1, -1, -1, 2)
>>> slice_idxC = torch.stack ((torch.zeros (4, 4, 3).long(), torch.ones (4, 4, 3).long()), -1)
>>> batch_indicesC.shape
torch.Size([4, 4, 3, 2])
>>> idxC.shape
torch.Size([4, 4, 3, 2])
>>> slice_idxC.shape
torch.Size([4, 4, 3, 2])
>>> new_pointsC = points[batch_indicesC, idxC, slice_idxC]
>>> torch.equal (new_points, new_pointsC)
True
```
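As a side note, the three index tensors do not have to be expanded to the full shape by hand, because index tensors are broadcast against one another. The following is only a sketch of that variant; `last` and `new_pointsD` are names I made up here, and the setup lines repeat the ones from the session above so that the block runs on its own:

```python
import torch

torch.manual_seed(2022)
points = torch.randint(0, 4, (4, 5, 2))
batch_indices = torch.randint(0, 4, (4, 4, 3))
idx = torch.randint(0, 4, (4, 4, 3))

# reference result: two index tensors plus slicing
new_points = points[batch_indices, idx, :]

# three index tensors of shapes (4, 4, 3, 1), (4, 4, 3, 1) and (2,)
# broadcast to a common shape of (4, 4, 3, 2)
last = torch.arange(points.shape[-1])
new_pointsD = points[batch_indices.unsqueeze(-1), idx.unsqueeze(-1), last]

print(new_pointsD.shape)                     # torch.Size([4, 4, 3, 2])
print(torch.equal(new_points, new_pointsD))  # True
```

This gives the same result as the slicing version, so in practice the colon is still the simpler choice.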

Best.

K. Frank