I am trying to use torch.index_select, but it does not seem to work as I expect for tensors with more than 2 dimensions.
In the following example I expect the output to have shape 3x1x5x5, but it gives me 3x3x5x5.
import torch

x = torch.randn(3, 3, 5, 5)
indices = torch.tensor([1, 1, 2])
print(torch.index_select(x, 1, indices).shape)  # prints torch.Size([3, 3, 5, 5])
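A small check that may help explain the 3x3x5x5 result. torch.index_select(x, 1, indices) applies the same list of indices to every batch element, so dimension 1 of the output has length len(indices) == 3. It should behave like x[:, indices], which the sketch below asserts (the variable name `out` is just for illustration):

```python
import torch

x = torch.randn(3, 3, 5, 5)
indices = torch.tensor([1, 1, 2])

# index_select uses the SAME indices [1, 1, 2] for every batch element,
# so dim 1 of the result has len(indices) == 3 entries
out = torch.index_select(x, 1, indices)
print(out.shape)  # torch.Size([3, 3, 5, 5])

# Equivalent to keeping all of dim 0 and indexing dim 1 with the list
assert torch.equal(out, x[:, indices])
# e.g. batch 0 holds channels 1, 1, 2 of x[0]
assert torch.equal(out[0], x[0, [1, 1, 2]])
```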
What I want is one channel per batch element, chosen by indices: channel 1 of x in batch 1, channel 1 of x in batch 2, and channel 2 of x in batch 3, giving a 3x1x5x5 result.
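A minimal sketch of two ways to pick one channel per batch element, assuming indices holds exactly one channel index for each batch entry (the names `batch`, `picked`, `idx` and `gathered` are illustrative). Option 1 pairs batch i with channel indices[i] through advanced indexing and then restores the channel dimension with unsqueeze. Option 2 uses torch.gather, which needs an index tensor with the same number of dimensions as x, so indices is reshaped and expanded to 3x1x5x5 first:

```python
import torch

x = torch.randn(3, 3, 5, 5)
indices = torch.tensor([1, 1, 2])  # one channel index per batch element

# Option 1: advanced indexing pairs batch i with channel indices[i]
batch = torch.arange(x.size(0))    # tensor([0, 1, 2])
picked = x[batch, indices]         # shape (3, 5, 5)
picked = picked.unsqueeze(1)       # shape (3, 1, 5, 5)

# Option 2: gather along dim 1; the index must have as many dims as x,
# so reshape to (3, 1, 1, 1) and expand to (3, 1, 5, 5)
idx = indices.view(-1, 1, 1, 1).expand(-1, 1, x.size(2), x.size(3))
gathered = torch.gather(x, 1, idx)  # shape (3, 1, 5, 5)

print(picked.shape)  # torch.Size([3, 1, 5, 5])
assert torch.equal(picked, gathered)
```

Advanced indexing is probably the more readable option here; gather becomes handy when you need to select different entries at every spatial position rather than one whole channel per batch element.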