How to quickly restore a batchwise-sorted tensor to its original order


I want to sort a tensor and apply a transform to the sorted values. Afterwards, I need to restore the result to the tensor's original order.
The operation is batchwise. Is there an efficient way to do this without a for loop?

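import torch

x = torch.randn(4, 10)
w = torch.randn(4, 10, 10)
x_t = torch.zeros_like(x)

values, indices = torch.sort(x, dim=1)  # sort each row independently
values_t = torch.bmm(values.unsqueeze(1), w).squeeze(1)  # batchwise transform, shape (4, 10)

# restore the original order row by row (this is the loop I want to remove)
for b in range(x.size(0)):
    x_t[b, indices[b]] = values_t[b]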



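You can use .scatter_ along dim 1, with the sort indices as the index tensor, to replace your for loop, e.g. x_t.scatter_(1, indices, values_t.reshape_as(x)).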
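A minimal sketch of the .scatter_ approach, assuming the same shapes as in the question (4 rows of length 10); the fixed random seed is added only to make it reproducible. It checks the scatter result against the original loop and also shows an equivalent alternative, not part of the original answer, that gathers with the inverse permutation torch.argsort(indices, dim=1).

```python
import torch

torch.manual_seed(0)  # fixed seed only so the sketch is reproducible

B, N = 4, 10
x = torch.randn(B, N)
w = torch.randn(B, N, N)

# Sort each row; indices[b, j] is the original column of the j-th smallest value in row b.
values, indices = torch.sort(x, dim=1)

# Batchwise transform: (B, 1, N) @ (B, N, N) -> (B, 1, N), then drop the singleton dim.
values_t = torch.bmm(values.unsqueeze(1), w).squeeze(1)

# Vectorized restore: x_t[b, indices[b, j]] = values_t[b, j] for all b, j.
x_t = torch.zeros_like(x).scatter_(1, indices, values_t)

# Reference result from the original for loop.
x_loop = torch.zeros_like(x)
for b in range(B):
    x_loop[b, indices[b]] = values_t[b]

# Alternative: gather with the inverse permutation of each row's sort indices.
inv = torch.argsort(indices, dim=1)
x_gather = values_t.gather(1, inv)

print(torch.equal(x_t, x_loop), torch.equal(x_t, x_gather))  # True True
```

Because each row of indices is a permutation, every position of the destination is written exactly once, so torch.empty_like(x) would work as the destination just as well as torch.zeros_like(x).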
