Indexing Tensors with a two-dimensional tensor

I came across a piece of code while going through PyTorch, and I’m having trouble understanding parts of it. Specifically, the following snippet confuses me:

    import torch

    batch_x1 = torch.randn(32, 88, 96)
    bs, stock_num = batch_x1.shape[0], batch_x1.shape[1]
    mask = torch.ones_like(batch_x1)
    rand_indices = torch.rand(bs, stock_num).argsort(dim=-1)
    mask_indices = rand_indices[:, :int(stock_num / 2)]
    batch_range = torch.arange(bs)[:, None]
    mask[batch_range, mask_indices, stock_num:] = 0
    enc_inp = mask * batch_x1

In this code snippet, I understand how rand_indices is used to build mask, which randomly zeroes out entries for half of the rows. However, I’m not entirely clear on how mask[batch_range, mask_indices, stock_num:] indexes the tensor mask, because both batch_range and mask_indices are two-dimensional tensors rather than the integers or lists I’m familiar with.
Could someone please explain the indexing operation in this code? Thank you!

Creating a smaller example and printing the intermediates might help you understand the indexing better:

import torch

batch_x1 = torch.arange(2*3*4).view(2, 3, 4)
print(batch_x1)
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],

#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])


bs, stock_num = batch_x1.shape[0], batch_x1.shape[1]
mask = torch.ones_like(batch_x1)
rand_indices = torch.rand(bs, stock_num).argsort(dim=-1)
print(rand_indices)
# tensor([[0, 2, 1],
#         [2, 0, 1]])

mask_indices = rand_indices[:, :int(stock_num/2)]
print(mask_indices)
# tensor([[0],
#         [2]])

batch_range = torch.arange(bs)[:, None]
print(batch_range)
# tensor([[0],
#         [1]])

mask[batch_range, mask_indices, stock_num:] = 0
print(mask)
# tensor([[[1, 1, 1, 0],
#          [1, 1, 1, 1],
#          [1, 1, 1, 1]],

#         [[1, 1, 1, 1],
#          [1, 1, 1, 1],
#          [1, 1, 1, 0]]])

enc_inp = mask * batch_x1

Here you are using advanced indexing, as described in the NumPy indexing docs, to avoid repeating the indices and use broadcasting instead.
The NumPy example indexing the corner values shows the same usage.
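For reference, that corner-values example translates directly to torch (a small sketch with toy values; the variable names rows, cols, and corners are just illustrative):

```python
import torch

# Pick the four corner elements of a 4x3 tensor.
x = torch.arange(12).view(4, 3)
rows = torch.tensor([[0], [3]])  # shape (2, 1), broadcasts over the column indices
cols = torch.tensor([0, 2])      # shape (2,)
corners = x[rows, cols]          # index tensors broadcast to shape (2, 2)
print(corners)
# tensor([[ 0,  2],
#         [ 9, 11]])
```

The (2, 1) and (2,) index tensors broadcast to a common (2, 2) shape, so each output element is x[rows[i, j], cols[j]] without writing out all four index pairs.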

In your example you are setting:

  • batch 0, row 0, columns 3:
  • batch 1, row 2, columns 3:

to zero.
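The broadcasting form is equivalent to an explicit loop over batches and selected rows, which may make the semantics more obvious (a sketch using the same small shapes as above; mask_loop and mask_adv are hypothetical names for comparison):

```python
import torch

bs, stock_num, feat = 2, 3, 4
mask_indices = torch.tensor([[0], [2]])  # one selected row per batch, as printed above

# Explicit-loop version: for each batch b and each selected row r,
# zero out columns stock_num: in that row.
mask_loop = torch.ones(bs, stock_num, feat)
for b in range(bs):
    for r in mask_indices[b]:
        mask_loop[b, r, stock_num:] = 0

# Advanced-indexing version from the original snippet.
mask_adv = torch.ones(bs, stock_num, feat)
batch_range = torch.arange(bs)[:, None]          # shape (2, 1)
mask_adv[batch_range, mask_indices, stock_num:] = 0

print(torch.equal(mask_loop, mask_adv))  # True
```

batch_range (shape (2, 1)) and mask_indices (shape (2, 1)) broadcast together, pairing batch b with its own row indices, while the trailing slice stock_num: applies to the last dimension as usual.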

Thanks a lot sir, your reply did help me and I understand it now!