Indexing a 4D tensor

Hi, I want to use the torch.gather function, but I am confused about how its indexing works.

In my case, I have a 4D tensor of size [B, N, K, C].
I want to sample S points along the N dimension, and I have an index tensor of size [B, S].
In the end, I want a tensor of size [B, S, K, C].

How can I achieve this?

This should work with plain advanced indexing (no torch.gather needed):

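import torch

B, N, K, C, S = 2, 3, 4, 5, 6

x = torch.randn(B, N, K, C)
idx = torch.randint(0, N, (B, S))  # for each batch, S indices into dim N

# a [B, 1] batch index broadcasts against idx ([B, S]), selecting x[b, idx[b, s]]
out = x[torch.arange(x.size(0)).unsqueeze(1), idx]
print(out.shape)  # torch.Size([2, 6, 4, 5]) = [B, S, K, C]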
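Since the question asks about torch.gather, here is a minimal sketch of the equivalent gather call (the helper name `gather_points` is hypothetical). torch.gather needs an index tensor with the same number of dimensions as the input, so `idx` is expanded from [B, S] to [B, S, K, C] before gathering along dim 1.

```python
import torch

def gather_points(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick points along dim 1 of x [B, N, K, C] using idx [B, S]."""
    B, N, K, C = x.shape
    # gather requires index.dim() == x.dim(), so broadcast idx
    # from [B, S] to [B, S, K, C] (expand creates a view, no copy)
    index = idx[:, :, None, None].expand(-1, -1, K, C)
    # out[b, s, k, c] = x[b, index[b, s, k, c], k, c]
    return torch.gather(x, 1, index)

B, N, K, C, S = 2, 3, 4, 5, 6
x = torch.randn(B, N, K, C)
idx = torch.randint(0, N, (B, S))

out_gather = gather_points(x, idx)
out_index = x[torch.arange(B).unsqueeze(1), idx]  # advanced-indexing version
print(out_gather.shape)                    # torch.Size([2, 6, 4, 5])
print(torch.equal(out_gather, out_index))  # True
```

Both give the same result; the indexing version is shorter, while the gather version makes the per-element mapping explicit.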

It seems to be working, thank you so much!