So I have this tensor t of size a x b x c.
I have another tensor u of size a. Each element of u is the index of the entry I want from t’s second dimension, so I want to use u to select along that dimension and get something that’s a x c.
How can I do this? Please let me know if I can provide any clarifications.
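To pin down the desired result, here is a plain loop that spells out the selection: for each i, take the row t[i, u[i], :]. It is only a slow reference definition for checking a vectorized answer, and the helper name `select_rows_loop` is hypothetical.

```python
import torch

def select_rows_loop(t: torch.Tensor, u: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper: for each i in the first dimension,
    # pick the length-c row t[i, u[i], :] and stack the rows into (a, c).
    return torch.stack([t[i, int(u[i])] for i in range(t.shape[0])])

a, b, c = 2, 3, 4
t = torch.arange(a * b * c).view(a, b, c)
u = torch.tensor([0, 2])
res = select_rows_loop(t, u)
print(res.shape)  # torch.Size([2, 4])
```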
This code should work:
import torch

a, b, c = 2, 3, 4
t = torch.arange(a * b * c).view(a, b, c)
print(t)
# tensor([[[ 0,  1,  2,  3],
#          [ 4,  5,  6,  7],
#          [ 8,  9, 10, 11]],
#
#         [[12, 13, 14, 15],
#          [16, 17, 18, 19],
#          [20, 21, 22, 23]]])

u = torch.tensor([0, 2])
# Pair index i of the first dimension with index u[i] of the second dimension
res = t[torch.arange(a), u]
print(res)
# tensor([[ 0,  1,  2,  3],
#         [20, 21, 22, 23]])
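The indexing works because the two index tensors torch.arange(a) and u are broadcast together, so result[i] = t[i, u[i]]. If you prefer an explicit op, a sketch using torch.gather should give the same result: u is expanded to shape (a, 1, c) so the same index is picked along dim 1 for every column, and the singleton dimension is then squeezed away. This is an alternative formulation, not something from the original answer.

```python
import torch

a, b, c = 2, 3, 4
t = torch.arange(a * b * c).view(a, b, c)
u = torch.tensor([0, 2])

# Advanced indexing: result[i] = t[i, u[i]], shape (a, c)
res_index = t[torch.arange(a), u]

# torch.gather along dim 1: out[i][0][k] = t[i][idx[i][0][k]][k] = t[i][u[i]][k]
idx = u.view(a, 1, 1).expand(a, 1, c)
res_gather = torch.gather(t, 1, idx).squeeze(1)

print(torch.equal(res_index, res_gather))  # True
```

The gather form generalizes if you later need to pick several indices per row (an index of shape (a, k, c)), while plain indexing is the shorter choice for one index per row.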