Let’s say we have a 4D tensor of shape (B, C, w, h). I want to sort its channels by their average value, i.e. the mean over each w x h map (a global average pool). That per-channel average should be the key used to reorder the tensor along the C axis. Could you suggest an efficient implementation that does not need a loop?

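For example, torch.sort(tensor, dim=1) sorts the values element-wise along the channel axis, so each spatial position is sorted independently and the contents of the 2D (w x h) maps get mixed across channels. I don’t want to swap any rows or columns or otherwise rearrange values inside a map; I only need to reorder the whole 2D maps according to their average-pooled values.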
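
To make the goal concrete, here is a minimal sketch of one loop-free way this could be done, assuming the sort should be applied independently for each item in the batch. The function name `sort_channels_by_mean` and the `descending` parameter are hypothetical, chosen only for illustration; the PyTorch calls used are `Tensor.mean`, `Tensor.argsort`, and `torch.gather`.

```python
import torch

def sort_channels_by_mean(x, descending=False):
    """Reorder the channels of x (B, C, w, h) by their spatial mean.

    Hypothetical helper: each (w, h) map is moved as a whole, so the
    values inside a map are never rearranged.
    """
    # Per-channel average over the spatial dimensions -> shape (B, C)
    means = x.mean(dim=(2, 3))
    # Channel order for each batch item -> shape (B, C)
    order = means.argsort(dim=1, descending=descending)
    # Broadcast the channel indices over (w, h) so gather picks whole maps
    index = order[:, :, None, None].expand_as(x)
    return torch.gather(x, 1, index), order

# Small usage example with constant maps so the means are obvious
x = torch.zeros(2, 3, 2, 2)
x[0, 0], x[0, 1], x[0, 2] = 3.0, 1.0, 2.0
x[1, 0], x[1, 1], x[1, 2] = 0.0, 5.0, -1.0
sorted_x, order = sort_channels_by_mean(x)
```

The same reordering could probably also be written with advanced indexing, e.g. `x[torch.arange(x.shape[0])[:, None], order]`, which avoids building the expanded index tensor. Note that ties between equal means may come out in either order, since `argsort` is not stable by default.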