How to efficiently index a tensor with a batch dimension

I have a tensor of shape (B, C, w, h). I want to sort its channels (C) by their average-pooled values, separately for each sample in the batch.
I have the code below for this, but the indexing step does not work:

import torch
tensor = torch.rand(2, 4, 3, 3)  # (B, C, w, h)
# average pooling: one value per channel, shape (B, C, 1, 1)
pooling = torch.nn.AdaptiveAvgPool2d(1)
tensor_pooled = pooling(tensor)
# sort the average-pooled values along the channel dimension
sorted_tensor_pooled, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)

# trying to reorder the original tensor with the indices from the sorted average-pooled values. This does not work:
sorted_tensor = tensor[sorted_indx]

How can I do this without using a loop?
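A minimal sketch of two loop-free options, assuming the same setup as the question. `tensor[sorted_indx]` indexes dimension 0 (the batch), not the channels, so the per-sample channel indices need to be applied along dimension 1 instead. Option 1 uses `torch.gather` with the (B, C, 1, 1) indices expanded to the full tensor shape; Option 2 uses advanced indexing, pairing a batch index of shape (B, 1) with the channel indices of shape (B, C). The names `batch_idx`, `chan_idx` and `sorted_tensor_2` are illustrative only.

```python
import torch

tensor = torch.rand(2, 4, 3, 3)  # (B, C, w, h)
B, C, w, h = tensor.shape

# Same as in the question: per-channel average, shape (B, C, 1, 1)
pooling = torch.nn.AdaptiveAvgPool2d(1)
tensor_pooled = pooling(tensor)
sorted_tensor_pooled, sorted_indx = torch.sort(tensor_pooled, 1, descending=True)

# Option 1: gather along the channel dimension (dim=1).
# sorted_indx is (B, C, 1, 1); expand() broadcasts it to (B, C, w, h) without copying,
# so sorted_tensor[b, c, i, j] = tensor[b, sorted_indx[b, c, 0, 0], i, j]
sorted_tensor = torch.gather(tensor, 1, sorted_indx.expand(-1, -1, w, h))

# Option 2: advanced indexing. batch_idx (B, 1) broadcasts against chan_idx (B, C),
# and the untouched trailing dims (w, h) are kept, giving shape (B, C, w, h).
batch_idx = torch.arange(B)[:, None]
chan_idx = sorted_indx.reshape(B, C)
sorted_tensor_2 = tensor[batch_idx, chan_idx]
```

Both options produce the same result. `gather` maps directly onto the (B, C, 1, 1) index tensor that `torch.sort` already returns, while advanced indexing avoids the `expand` but needs the indices flattened to (B, C).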