import torch

selection_randn = torch.randn(4, 4, 3, 4)
curr_idx = torch.tensor([2, 3, 1, 0])
selection_randn[:, curr_idx, :, :].shape  # torch.Size([4, 4, 3, 4])
I want to select one element along dim 1 for each row of dim 0, using the indices 2, 3, 1, 0 (element 2 from row 0, element 3 from row 1, and so on).
How can I do that?
I expect the result to have size (4, 3, 4), but I get the same size as the input, (4, 4, 3, 4).
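A short sketch of why the shape does not change, reusing the tensors from the question (values are random). Indexing only dim 1 with a 1-D index tensor applies the whole index list to every row of dim 0, so dim 1 is reordered rather than collapsed:

```python
import torch

# Same shapes as in the question; the values are random.
selection_randn = torch.randn(4, 4, 3, 4)
curr_idx = torch.tensor([2, 3, 1, 0])

# Indexing dim 1 alone: every row of dim 0 gets all four indices,
# so the result is just dim 1 permuted and the shape is unchanged.
permuted = selection_randn[:, curr_idx, :, :]
print(permuted.shape)  # torch.Size([4, 4, 3, 4])

# Position j of dim 1 in the result is position curr_idx[j] of the input,
# for every row i of dim 0.
for i in range(4):
    for j in range(4):
        assert torch.equal(permuted[i, j], selection_randn[i, curr_idx[j].item()])
```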
I think this will work. It pairs each row of dim 0 with its own index in dim 1:
selection_randn[range(len(curr_idx)), curr_idx]  # torch.Size([4, 3, 4])
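A minimal sketch checking the answer, assuming the same shapes as in the question. It uses `torch.arange` in place of `range` (both index the same rows here), and also shows `torch.gather` as an alternative way to get the same result; the gather variant is an illustration, not part of the original answer:

```python
import torch

selection_randn = torch.randn(4, 4, 3, 4)
curr_idx = torch.tensor([2, 3, 1, 0])

# Two index tensors are broadcast together: row i of dim 0 is paired with
# curr_idx[i] of dim 1, so dims 0 and 1 collapse into one dim of length 4.
rows = torch.arange(len(curr_idx))
picked = selection_randn[rows, curr_idx]
print(picked.shape)  # torch.Size([4, 3, 4])

# Alternative with torch.gather: expand the indices to shape (4, 1, 3, 4),
# gather along dim 1, then drop the singleton dim.
gather_idx = curr_idx.view(-1, 1, 1, 1).expand(-1, 1, 3, 4)
picked_gather = torch.gather(selection_randn, 1, gather_idx).squeeze(1)
assert torch.equal(picked, picked_gather)

# Explicit check: row i of the result is selection_randn[i, curr_idx[i]].
for i in range(len(curr_idx)):
    assert torch.equal(picked[i], selection_randn[i, curr_idx[i].item()])
```

Advanced indexing is usually the more readable choice here; `gather` becomes handy when the index has to vary over the trailing dims as well.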