Indexing a 4D tensor

Hi, I want to use torch.gather for this, but I am confused about how the indexing works.

In my case, I have a 4D tensor of size [B, N, K, C].
I want to sample S points along dimension N, using an index tensor of size [B, S].
In the end, I want a tensor of size [B, S, K, C].

How can I achieve this?

This should work:

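```python
import torch

B, N, K, C, S = 2, 3, 4, 5, 6

x = torch.randn(B, N, K, C)
idx = torch.randint(0, N, (B, S))

# arange(B).unsqueeze(1) has shape [B, 1] and broadcasts against idx [B, S],
# so out[b, s] = x[b, idx[b, s]]
out = x[torch.arange(x.size(0)).unsqueeze(1), idx]
print(out.shape)
# torch.Size([2, 6, 4, 5]) = [B, S, K, C]
```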
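Since the question mentions torch.gather, here is a minimal sketch of the same lookup written with it, assuming the same sizes as above. torch.gather needs an index tensor with the same number of dimensions as the input, so idx is first expanded from [B, S] to [B, S, K, C]; expand returns a view, so no data is copied.

```python
import torch

B, N, K, C, S = 2, 3, 4, 5, 6

x = torch.randn(B, N, K, C)
idx = torch.randint(0, N, (B, S))

# torch.gather needs an index with the same number of dims as x,
# so broadcast idx from [B, S] to [B, S, K, C] (expand creates a view, no copy).
index = idx[:, :, None, None].expand(B, S, K, C)

# out[b, s, k, c] = x[b, index[b, s, k, c], k, c] = x[b, idx[b, s], k, c]
out_gather = torch.gather(x, 1, index)

# Same result as the advanced-indexing version.
out_index = x[torch.arange(B).unsqueeze(1), idx]
print(out_gather.shape)  # torch.Size([2, 6, 4, 5])
print(torch.equal(out_gather, out_index))  # True
```

For this kind of lookup, advanced indexing is usually shorter to write, while gather makes the per-element rule explicit.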

It seems to be working, thank you so much!

@ptrblck Thank you for your response.

I have a follow-up question: How do I index a more “inner” dimension?

Say, for the exact same setting, except that I would like to index dimension 2 of x:

```python
x = torch.randn(B, N, K, C)
idx = torch.randint(0, S, (B, N, S))
```

And, closer to my actual use case: what would I need to do if I have an index tensor with idx.shape == (B, S) that I want to use to index dim 2 (K)? That is, for each index in dim 0 (B), I want to select the same positions in dim 2 (K) for every index in dim 1 (N).

Thank you in advance

I don’t know what your expected output shape would be, as the idx tensor should most likely contain values in [0, K - 1], but this might work:

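```python
import torch

B, N, K, C, S = 2, 3, 4, 5, 6

x = torch.randn(B, N, K, C)
idx = torch.randint(0, K, (B, N, S))

# Index shapes [B, 1, 1], [N, 1] and [B, N, S] broadcast to [B, N, S]
out = x[torch.arange(x.size(0))[:, None, None], torch.arange(x.size(1))[:, None], idx]
print(out.shape)
# torch.Size([2, 3, 6, 5]) = [B, N, S, C]
```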
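For the last part of the follow-up (an index tensor of shape [B, S] that selects the same K positions for every n in dim 1), a minimal sketch, assuming the same sizes and that idx holds values in [0, K - 1]: insert a singleton N dimension into idx so it broadcasts against the B and N index tensors. The three index tensors have shapes [B, 1, 1], [1, N, 1] and [B, 1, S], which broadcast to [B, N, S], so the result is [B, N, S, C]. A torch.gather version along dim 2 is included for comparison.

```python
import torch

B, N, K, C, S = 2, 3, 4, 5, 6

x = torch.randn(B, N, K, C)
idx = torch.randint(0, K, (B, S))  # one set of K-positions per batch element

# Index tensors of shape [B, 1, 1], [1, N, 1] and [B, 1, S] broadcast to [B, N, S];
# the untouched last dim C is appended, giving [B, N, S, C].
b_idx = torch.arange(B)[:, None, None]
n_idx = torch.arange(N)[None, :, None]
out = x[b_idx, n_idx, idx[:, None, :]]
print(out.shape)  # torch.Size([2, 3, 6, 5]) = [B, N, S, C]

# Equivalent with torch.gather along dim 2: expand idx to [B, N, S, C],
# so out_gather[b, n, s, c] = x[b, n, idx[b, s], c].
index = idx[:, None, :, None].expand(B, N, S, C)
out_gather = torch.gather(x, 2, index)
print(torch.equal(out, out_gather))  # True
```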

Yes, this seems to work. Thank you very much :slight_smile: