How to rearrange a tensor along a certain dimension?

For example:
My feature tensor has shape [B,C,H,W] and my weights have shape [B,C]. For each image, I want to reorder the feature along the channel dimension so that the channels follow that image's weights in descending order.
Is there a more elegant way to do this in a single operation, instead of indexing each image separately?
Any help would be much appreciated! My current code is as follows:

import torch
feature = torch.rand(4,6,3,3)
weight = torch.rand(4,6)
_, indices = torch.sort(weight, dim=1, descending = True)
z1 = feature[0,indices[0],:,:]
z2 = feature[1,indices[1],:,:]
z3 = feature[2,indices[2],:,:]
z4 = feature[3,indices[3],:,:]
z = torch.stack((z1,z2,z3,z4),dim=0)

You could directly index the tensor as seen here:

feature = torch.rand(4,6,3,3)
weight = torch.rand(4,6)
_, indices = torch.sort(weight, dim=1, descending = True)
z1 = feature[0,indices[0],:,:]
z2 = feature[1,indices[1],:,:]
z3 = feature[2,indices[2],:,:]
z4 = feature[3,indices[3],:,:]
z = torch.stack((z1,z2,z3,z4),dim=0)

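# The batch index has shape [B, 1] and broadcasts against indices of shape [B, C],
# so res[b, c] == feature[b, indices[b, c]] and res has shape [B, C, H, W].
res = feature[torch.arange(feature.size(0)).unsqueeze(1), indices]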

print((z - res).abs().max())
> tensor(0.)
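
For anyone who wants to see why the one-liner works, here is a minimal, self-contained sketch of the same idea with the shapes written out. It also shows torch.gather as one possible alternative; that variant is my own addition rather than part of the suggestion above, and the names batch_idx, idx and res_gather are just illustrative.

```python
import torch

feature = torch.rand(4, 6, 3, 3)   # [B, C, H, W]
weight = torch.rand(4, 6)          # [B, C]

# Per-image channel order, largest weight first: shape [B, C].
indices = weight.argsort(dim=1, descending=True)

# Advanced indexing: batch_idx is [B, 1] and broadcasts against indices [B, C],
# so res[b, c] == feature[b, indices[b, c]].
batch_idx = torch.arange(feature.size(0)).unsqueeze(1)
res = feature[batch_idx, indices]

# Alternative sketch with torch.gather: the index has to be expanded
# to the full output shape [B, C, H, W] before gathering along dim 1.
idx = indices[:, :, None, None].expand(-1, -1, feature.size(2), feature.size(3))
res_gather = torch.gather(feature, 1, idx)

print(res.shape)
print(torch.equal(res, res_gather))
```

The indexing version is shorter and does not need the expanded index, so it is probably the more convenient choice here; gather may read more clearly if you prefer an explicit dim argument.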

Thank you very much for your quick reply, which is exactly what I want.