Indexing a Tensor's Second Dimension

So I have a tensor t of size a x b x c.

I have another tensor u of size a. Each element of u is the index of the entry I want from t's second dimension, so for each i I want t[i, u[i]], giving a result of size a x c.
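Put another way, the desired result satisfies res[i] = t[i, u[i]] for every i in range(a). A slow but unambiguous loop sketch of that requirement is below. The sizes and values of t and u are made up for illustration.

```python
import torch

# Illustrative sizes; any a, b, c behave the same way
a, b, c = 2, 3, 4
t = torch.arange(a * b * c).view(a, b, c)
u = torch.tensor([0, 2])  # u[i] picks one row of t[i] along dim 1

# Slow loop reference: for each i along dim 0, keep t[i, u[i], :]
res = torch.stack([t[i, int(u[i])] for i in range(a)])
print(res.shape)  # torch.Size([2, 4])
```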

How can I do this? Please let me know if I can provide any clarifications.

This code should work:

import torch

a, b, c = 2, 3, 4
t = torch.arange(a*b*c).view(a, b, c)  # shape (a, b, c)
print(t)
> tensor([[[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11]],

          [[12, 13, 14, 15],
           [16, 17, 18, 19],
           [20, 21, 22, 23]]])

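u = torch.tensor([0, 2])  # one index into dim 1 per entry of dim 0, shape (a,)

# Advanced indexing: the two index tensors are paired element-wise,
# so res[i] = t[i, u[i]] and the result has shape (a, c)
res = t[torch.arange(a), u]
print(res)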
> tensor([[ 0,  1,  2,  3],
          [20, 21, 22, 23]])
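An alternative that gives the same result is torch.gather along dim 1. This is a sketch, not part of the original answer. gather requires the index to have as many dimensions as t, so u is reshaped to (a, 1, 1) and expanded to (a, 1, c) before the singleton dim is squeezed away. The check below compares it with the advanced-indexing result.

```python
import torch

a, b, c = 2, 3, 4
t = torch.arange(a * b * c).view(a, b, c)
u = torch.tensor([0, 2])

# Advanced indexing, as in the answer: res[i] = t[i, u[i]]
res_index = t[torch.arange(a), u]

# gather along dim 1: out[i][0][k] = t[i][idx[i][0][k]][k]
# idx must match t's number of dims, so reshape u to (a, 1, 1) and expand to (a, 1, c)
idx = u.view(a, 1, 1).expand(a, 1, c)
res_gather = torch.gather(t, 1, idx).squeeze(1)  # (a, 1, c) -> (a, c)

print(torch.equal(res_index, res_gather))  # True
```

Advanced indexing is usually the more readable choice for this case. gather becomes handy when you need a different index per position along the last dimension, since idx can then vary in k instead of being expanded.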