I came across a piece of code while reading through a PyTorch project, and I'm having trouble understanding parts of it. Specifically, I'm confused by the following snippet:
```python
import torch

batch_x1 = torch.randn(32, 88, 96)
bs, stock_num = batch_x1.shape[0], batch_x1.shape[1]
mask = torch.ones_like(batch_x1)
rand_indices = torch.rand(bs, stock_num).argsort(dim=-1)
mask_indices = rand_indices[:, :int(stock_num / 2)]
batch_range = torch.arange(bs)[:, None]
mask[batch_range, mask_indices, stock_num:] = 0
enc_inp = mask * batch_x1
```
In this snippet, I understand that `rand_indices` is used to build a mask (`mask`) that randomly sets half of the entries to 0. However, I'm not entirely clear on how `mask[batch_range, mask_indices, stock_num:]` indexes the tensor `mask`, because both `batch_range` and `mask_indices` are two-dimensional tensors rather than the integers, slices, or lists I'm familiar with. Could someone please explain this indexing operation? Thank you!
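To make my confusion concrete, here is a small toy example I put together (the tensor `x`, the index values, and the shapes are my own made-up values, not from the original code). I can see that indexing with two 2-D tensors produces *some* result, but I don't understand the rule that determines which elements get picked:

```python
import torch

# Toy tensor of shape (2, 3, 4): 2 "batches", 3 "stocks", 4 "features"
x = torch.arange(24).reshape(2, 3, 4)

# Same pattern as in the original snippet:
batch_range = torch.arange(2)[:, None]    # shape (2, 1)
idx = torch.tensor([[0, 2],               # shape (2, 2): per-batch stock indices
                    [1, 0]])

out = x[batch_range, idx]                 # why does this work, and what shape is it?
print(out.shape)                          # torch.Size([2, 2, 4])
```

It looks like `out[0, 0]` equals `x[0, 0]` and `out[1, 1]` equals `x[1, 0]`, so the two index tensors seem to be paired up element-wise somehow, but I'd like to understand the general rule.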