How to select a specific vector from a 3D tensor elegantly?

I have a 3D tensor (batch_size * seq_len * hidden_dim), and I want to get a different hidden vector for each sequence. Is there an elegant way to do this without a for loop?

Tensor:
[[[1, 2, 3],
[4, 5, 6],
[7, 8, 9]],

[[11, 12, 13],
[14, 15, 16],
[17, 18, 19]]]

Target matrix:
[[4, 5, 6],
[17, 18, 19]]

I have the indices [1, 2] that select the specific vector in each sequence. Right now I'm using a for loop to get the vector for each sequence. It seems that torch.gather and torch.index_select may not be suitable for this.

Thanks for your help!


This should work:

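```python
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])

idx = torch.tensor([1, 2])
# Pair batch index i with idx[i]: picks x[0, 1] and x[1, 2]
print(x[torch.arange(x.size(0)), idx])
# tensor([[ 4,  5,  6],
#         [17, 18, 19]])
```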
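The question mentions that torch.gather did not seem to fit. As a sketch, it can produce the same result once idx is reshaped and expanded so it has the same number of dimensions as x (the name gather_idx is just for illustration). Advanced indexing is arguably the more readable option; gather is shown only for comparison.

```python
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])
idx = torch.tensor([1, 2])

# gather needs an index with as many dims as x: shape (batch, 1, hidden),
# where every entry along the last dim equals idx[b].
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, x.size(2))
# out[b, 0, k] = x[b, gather_idx[b, 0, k], k] = x[b, idx[b], k]
out_gather = x.gather(1, gather_idx).squeeze(1)  # shape (batch, hidden)

# Compare with the advanced-indexing answer
out_index = x[torch.arange(x.size(0)), idx]
print(torch.equal(out_gather, out_index))  # True
```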

Thanks! I had tried slicing with x[:, idx], but it gave the wrong result. Indexing the batch dimension with a tensor as well is the correct approach. Thank you again.
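A small sketch of why x[:, idx] differs from the answer above: with `:` on the batch dimension, idx is applied to every batch element, so each sequence contributes both rows. Two index tensors, on the other hand, are broadcast together and paired elementwise. The wanted vectors sit on the "diagonal" of the sliced result.

```python
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])
idx = torch.tensor([1, 2])

# ":" keeps the whole batch dim, so rows 1 AND 2 are taken from every sequence
sliced = x[:, idx]
print(sliced.shape)  # torch.Size([2, 2, 3])

# Two index tensors are paired elementwise: (0, idx[0]) and (1, idx[1])
paired = x[torch.arange(x.size(0)), idx]
print(paired.shape)  # torch.Size([2, 3])

# The wanted rows are sliced[b, b] for each batch index b
b = torch.arange(x.size(0))
print(torch.equal(sliced[b, b], paired))  # True
```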

Thanks from 2022 XD, it works pretty well!

Hi ptrblck, I have a similar situation. However, I have a boolean index mask:

[[False, True, False],
[False, True, True]]
I would like to select the corresponding vectors:

[4,5,6]

[14,15,16]
[17,18,19]

Then I would like to pad them with -1. The expected result is:

[[[4, 5, 6],
  [-1, -1, -1],
  [-1, -1, -1]],

 [[14, 15, 16],
  [17, 18, 19],
  [-1, -1, -1]]]

Is it possible to do this without a for loop?
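For comparison, here is a minimal sketch of the loop-based version being asked about. It can serve as a reference when checking a vectorized solution. The helper name select_and_pad_loop and its pad_value parameter are hypothetical.

```python
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])
mask = torch.tensor([[False, True, False],
                     [False, True, True]])


def select_and_pad_loop(x, mask, pad_value=-1):
    """Hypothetical reference helper: plain per-batch loop."""
    res = torch.full_like(x, pad_value)
    for b in range(x.size(0)):
        rows = x[b][mask[b]]          # selected vectors of this batch element
        res[b, :rows.size(0)] = rows  # move them to the front; the rest stays padded
    return res


print(select_and_pad_loop(x, mask))
```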

Something like this should work:

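```python
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])

mask = torch.tensor([[False, True, False],
                     [False, True, True]])

# Output pre-filled with the padding value
res = torch.full_like(x, -1)
# Batch index of each selected vector
idx1 = mask.nonzero()[:, 0]
# Destination slot inside each batch element: 0, 1, ..., (number of Trues - 1)
idx2 = torch.cat([torch.arange(a) for a in mask.cumsum(1).max(1).values])
# Write the selected vectors (in row-major order) into the front slots
res[idx1, idx2] = x[mask]

print(res)
# tensor([[[ 4,  5,  6],
#          [-1, -1, -1],
#          [-1, -1, -1]],
#
#         [[14, 15, 16],
#          [17, 18, 19],
#          [-1, -1, -1]]])
```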
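The list comprehension for idx2 still loops over the batch in Python. A possible fully vectorized variant (a sketch, not part of the original answer) uses the running count of True values: at each selected position, mask.long().cumsum(1) minus 1 is exactly that vector's destination slot. It relies on mask.nonzero(), boolean-mask indexing, and the cumsum lookup all enumerating the True positions in the same row-major order.

```python
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])
mask = torch.tensor([[False, True, False],
                     [False, True, True]])

res = torch.full_like(x, -1)
# Batch index of every selected vector
idx1 = mask.nonzero()[:, 0]
# Running count of Trues along the sequence dim, read at the True positions;
# subtracting 1 gives 0-based destination slots without a Python loop.
idx2 = mask.long().cumsum(1)[mask] - 1
print(idx2)  # tensor([0, 0, 1])
res[idx1, idx2] = x[mask]

# Same slots as the list-comprehension version (mask.sum(1) == cumsum max)
idx2_loop = torch.cat([torch.arange(n) for n in mask.sum(1)])
print(torch.equal(idx2, idx2_loop))  # True
print(res)
```

Since the cumulative sum of a 0/1 mask never decreases, mask.cumsum(1).max(1).values in the answer above equals mask.sum(1), which is what the final check compares against.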