Hi,
I want to sort a tensor and apply a transform to the sorted values. Afterwards, I need to restore the result to the tensor's original order.
The operation is batchwise. Is there an efficient way to do this without a for loop?
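import torch

x = torch.randn(4, 10)
w = torch.randn(4, 10, 10)
x_t = torch.zeros_like(x)
values, indices = torch.sort(x)  # sort each row independently (last dim)
# batchwise transform: (4, 1, 10) @ (4, 10, 10) -> (4, 1, 10), then drop the singleton dim
values_t = torch.bmm(values.unsqueeze(1), w).squeeze(1)
# put each transformed value back at its original position (this loop is what I want to remove)
for b in range(x.size(0)):
    x_t[b, indices[b]] = values_t[b]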
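One loop-free sketch, assuming the same shapes as above (batch size 4, row length 10): since each row of `indices` records the original column of every sorted value, `Tensor.scatter_` along dim 1 can write all rows back in a single call. An equivalent alternative inverts the permutation with `torch.argsort` and reads the values back with `torch.gather`. The names `B`, `N`, `inverse`, `x_t_gather` and `x_ref` are illustrative only.

```python
import torch

torch.manual_seed(0)
B, N = 4, 10  # illustrative sizes matching the example above
x = torch.randn(B, N)
w = torch.randn(B, N, N)

# Sort each row; indices[b, j] is the original column of the j-th smallest value in row b.
values, indices = torch.sort(x, dim=1)

# Batchwise transform of the sorted rows: (B, 1, N) @ (B, N, N) -> (B, 1, N) -> (B, N).
values_t = torch.bmm(values.unsqueeze(1), w).squeeze(1)

# Loop-free restore: x_t[b, indices[b, j]] = values_t[b, j] for every b and j at once.
x_t = torch.zeros_like(x).scatter_(1, indices, values_t)

# Alternative: argsort of a permutation is its inverse, so gather reads values back in original order.
inverse = torch.argsort(indices, dim=1)
x_t_gather = torch.gather(values_t, 1, inverse)

# Reference result from the original for loop, used only to compare.
x_ref = torch.zeros_like(x)
for b in range(B):
    x_ref[b, indices[b]] = values_t[b]

print(torch.equal(x_t, x_ref), torch.equal(x_t_gather, x_ref))
```

Because every row of `indices` is a permutation, no two source values target the same position, so the scatter is well defined. Both versions only copy values, so they should match the loop exactly.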
Thanks