Index select from tensors

import torch

selection_randn = torch.randn(4, 4, 3, 4)
curr_idx = torch.tensor([2, 3, 1, 0])

selection_randn[:, curr_idx, :, :].shape  # torch.Size([4, 4, 3, 4])

I want to select one entry per row along dim 1: index 2 for row 0, index 3 for row 1, index 1 for row 2 and index 0 for row 3.
How can I do that?

The expected size is (4, 3, 4), but I get the same size as the input, (4, 4, 3, 4).
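A short sketch of why the shape does not shrink: with `:` on dim 0, the index tensor on dim 1 is applied to every row, so all 4 x 4 row/index combinations are kept, and entry [i, j] is row i at position curr_idx[j]. The wanted per-row selection is the "diagonal" i == j of that result. The names `x`, `all_combos` and `diag` are illustrative only, and the diagonal is built with an explicit loop just to show the relationship, not as the recommended approach.

```python
import torch

x = torch.randn(4, 4, 3, 4)
curr_idx = torch.tensor([2, 3, 1, 0])

# ':' keeps all 4 rows of dim 0, and curr_idx is applied to every row,
# so every (row, index) combination ends up in the result.
all_combos = x[:, curr_idx]
print(all_combos.shape)  # torch.Size([4, 4, 3, 4])

# Entry [i, j] is row i of x taken at position curr_idx[j] on dim 1.
for i in range(4):
    for j in range(4):
        assert torch.equal(all_combos[i, j], x[i, curr_idx[j]])

# The per-row selection asked for is the i == j "diagonal" of that result.
diag = torch.stack([all_combos[i, i] for i in range(4)])
print(diag.shape)  # torch.Size([4, 3, 4])
```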

I think this will work:

selection_randn[torch.arange(len(curr_idx)), curr_idx].shape  # torch.Size([4, 3, 4])
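A minimal sketch of why this works, plus an equivalent formulation with `torch.gather` (a swapped-in alternative, not part of the original answer). When dim 0 and dim 1 are both indexed with tensors of the same length, PyTorch pairs them elementwise, so result[i] is x[i, curr_idx[i]]. `torch.gather` needs an index with as many dims as the input, so `curr_idx` is reshaped to (4, 1, 1, 1) and expanded over the last two dims. The names `rows`, `picked`, `gather_idx` and `gathered` are hypothetical, chosen for the example.

```python
import torch

x = torch.randn(4, 4, 3, 4)
curr_idx = torch.tensor([2, 3, 1, 0])
rows = torch.arange(len(curr_idx))  # tensor([0, 1, 2, 3])

# Advanced indexing pairs the two index tensors elementwise:
# picked[i] = x[rows[i], curr_idx[i]] = x[i, curr_idx[i]]
picked = x[rows, curr_idx]

# Equivalent with torch.gather along dim 1: the index must have the same
# number of dims as x, so curr_idx becomes shape (4, 1, 3, 4), then the
# singleton dim 1 is squeezed away.
gather_idx = curr_idx.view(-1, 1, 1, 1).expand(-1, 1, x.size(2), x.size(3))
gathered = torch.gather(x, 1, gather_idx).squeeze(1)

print(picked.shape)                   # torch.Size([4, 3, 4])
print(torch.equal(picked, gathered))  # True
```

The indexing form is the shorter one; the `gather` form can be handy when the selection is part of a larger pipeline that already builds index tensors per element.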