Torch.index_select does not behave as expected

I am trying to use torch.index_select, but it does not seem to work as I expect for tensors with more than 2 dimensions.
For the following example, I expect the output to have shape 3x1x5x5, but it gives me 3x3x5x5:

import torch

x = torch.randn(3, 3, 5, 5)
indices = torch.tensor([1, 1, 2])
print(torch.index_select(x, 1, indices).shape)
# torch.Size([3, 3, 5, 5])
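A quick check of what index_select actually returns here may help explain the shape. The sketch below (reusing the same x and indices as above) compares each output slice with the input. It suggests that every batch element gets the same three channels, in the order given by indices, which is also what plain x[:, indices] does.

```python
import torch

x = torch.randn(3, 3, 5, 5)
indices = torch.tensor([1, 1, 2])

out = torch.index_select(x, 1, indices)
print(out.shape)  # torch.Size([3, 3, 5, 5])

# The same channel list [1, 1, 2] is applied to every batch element.
for i in range(x.shape[0]):
    assert torch.equal(out[i, 0], x[i, 1])
    assert torch.equal(out[i, 1], x[i, 1])
    assert torch.equal(out[i, 2], x[i, 2])

# Equivalent to indexing dim 1 with the index tensor directly.
assert torch.equal(out, x[:, indices])
```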

I want to get channel 1 of x from the first batch element, channel 1 from the second batch element, and channel 2 from the third batch element.

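Not sure how you expected index_select to work, but it selects the same indices along the given dimension for every batch element: with 3 indices along dim 1 you get 3 channels per sample, hence torch.Size([3, 3, 5, 5]). To pick one channel per batch element instead, you can merge the batch and channel dimensions and index the flattened tensor like this: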

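b, c, w, h = x.shape
# Row i*c + indices[i] of the (b*c, w, h) view is x[i, indices[i]]
new_ind = indices.view(-1) + torch.arange(b) * c
out = x.reshape(b * c, w, h)[new_ind]  # torch.Size([3, 5, 5])
out = out.unsqueeze(1)                  # torch.Size([3, 1, 5, 5])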
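The same per-batch selection can probably be written without computing flat indices by hand. The sketch below shows two standard alternatives, advanced indexing that pairs torch.arange(b) with indices, and torch.gather along dim 1, and checks both against the flattened approach. It assumes one channel index per batch element (indices has length b), as in the question.

```python
import torch

x = torch.randn(3, 3, 5, 5)
indices = torch.tensor([1, 1, 2])  # one channel index per batch element
b, c, h, w = x.shape

# Option 1: advanced indexing; row i of the result is x[i, indices[i]]
adv = x[torch.arange(b), indices]          # torch.Size([3, 5, 5])
adv = adv.unsqueeze(1)                     # torch.Size([3, 1, 5, 5])

# Option 2: torch.gather along dim 1; the index needs as many dims as x
gather_idx = indices.view(b, 1, 1, 1).expand(b, 1, h, w)
gathered = torch.gather(x, 1, gather_idx)  # torch.Size([3, 1, 5, 5])

# Reference: the flattened-view approach from the answer
flat = x.reshape(b * c, h, w)[indices + torch.arange(b) * c].unsqueeze(1)

print(adv.shape, gathered.shape)
assert torch.equal(adv, gathered)
assert torch.equal(adv, flat)
```

Advanced indexing is usually the most readable; gather is handy when the selection along dim 1 is itself a tensor computed inside a model.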