Hi, I want to use torch.gather, but I am confused about how its indexing works.
In my case, I have a 4D tensor of size [B, N, K, C].
I want to sample S points along dimension 1 (N), and I have an index tensor of size [B, S].
In the end, I want a tensor of size [B, S, K, C].
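A minimal sketch of one way to do this. It assumes idx holds int64 positions in [0, N). torch.gather requires the index tensor to have the same number of dimensions as the input, so idx is first broadcast over K and C. The helper name gather_points and the concrete sizes are illustrative only.

```python
import torch

def gather_points(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Select S slices along dim 1 of x [B, N, K, C] using idx [B, S] -> [B, S, K, C]."""
    B, _, K, C = x.shape
    S = idx.shape[1]
    # torch.gather needs an index with as many dims as x, so broadcast
    # idx from [B, S] to [B, S, K, C]; expand() returns a view, not a copy.
    index = idx[:, :, None, None].expand(B, S, K, C)
    # out[b, s, k, c] = x[b, index[b, s, k, c], k, c] = x[b, idx[b, s], k, c]
    return torch.gather(x, 1, index)


B, N, K, C, S = 2, 5, 4, 3, 6  # example sizes
x = torch.randn(B, N, K, C)
idx = torch.randint(0, N, (B, S))  # values must lie in [0, N)

out = gather_points(x, idx)
print(out.shape)  # torch.Size([2, 6, 4, 3])

# Equivalent advanced indexing: batch_idx [B, 1] broadcasts against idx [B, S].
batch_idx = torch.arange(B)[:, None]
assert torch.equal(out, x[batch_idx, idx])
```

For this particular case, the advanced-indexing form x[batch_idx, idx] is often shorter to read. The gather form makes the index-shape rule explicit.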
I have a follow-up question: how do I index a more “inner” dimension?
Say, in exactly the same setting, except that I would like to index dimension 2 of x:
import torch
B, N, K, C, S = 2, 5, 4, 3, 6  # example sizes
x = torch.randn(B, N, K, C)
idx = torch.randint(0, K, (B, N, S))  # indices into dim 2 must lie in [0, K)
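A sketch for indexing dim 2, assuming idx has shape [B, N, S] with values in [0, K). The sizes are illustrative. Because idx has one fewer dimension than x, only a trailing axis for C is needed before calling torch.gather. The result has shape [B, N, S, C].

```python
import torch

B, N, K, C, S = 2, 5, 4, 3, 6  # example sizes
x = torch.randn(B, N, K, C)
idx = torch.randint(0, K, (B, N, S))  # positions along dim 2, in [0, K)

# Broadcast idx from [B, N, S] to [B, N, S, C] by adding a trailing axis.
index = idx[..., None].expand(B, N, S, C)
# out[b, n, s, c] = x[b, n, idx[b, n, s], c]
out = torch.gather(x, 2, index)
print(out.shape)  # torch.Size([2, 5, 6, 3])

# Same result with advanced indexing on the first three dims (all adjacent),
# whose index tensors broadcast to [B, N, S].
b = torch.arange(B)[:, None, None]
n = torch.arange(N)[None, :, None]
assert torch.equal(out, x[b, n, idx])
```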
And, more closely related to my use case: what would I need to do if I have an index tensor with idx.shape == (B, S) that I want to use to index dim 2 (K)? That is, for each index in dim 0 (B), I want to select the same positions in dim 2 (K) for every index in dim 1 (N).
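A sketch for the shared-index case, assuming idx has shape [B, S] with values in [0, K) and that the result should have shape [B, N, S, C]. The sizes are illustrative. The idea is to insert singleton axes for N and C and expand them, so every n in dim 1 reuses the same idx[b] row.

```python
import torch

B, N, K, C, S = 2, 5, 4, 3, 6  # example sizes
x = torch.randn(B, N, K, C)
idx = torch.randint(0, K, (B, S))  # one set of K-positions per batch element

# Add axes for N and C, then broadcast [B, S] -> [B, N, S, C]; every n in
# dim 1 reuses the same idx[b] row.
index = idx[:, None, :, None].expand(B, N, S, C)
out = torch.gather(x, 2, index)  # out[b, n, s, c] = x[b, n, idx[b, s], c]
print(out.shape)  # torch.Size([2, 5, 6, 3])

# Advanced-indexing equivalent: idx[:, None, :] has shape [B, 1, S] and
# broadcasts with the batch and N index tensors to [B, N, S].
b = torch.arange(B)[:, None, None]
n = torch.arange(N)[None, :, None]
assert torch.equal(out, x[b, n, idx[:, None, :]])
```

Since expand() only creates a view, the [B, S] index is not physically copied N * C times.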