Hi Wasabi!
Use `gather()` along dimension 1, with the permutation indices expanded
(repeated) across the last dimension:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> a = 2
>>> b = 3
>>> c = 5
>>> X = torch.arange (a * b * c).reshape (a, b, c)
>>> X
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14]],

        [[15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]]])
>>> P = torch.randn (a, b).argsort (dim = 1)
>>> P
tensor([[2, 1, 0],
        [1, 0, 2]])
>>> Y = X.gather (1, P.unsqueeze (-1).repeat (1, 1, c))
>>> Y
tensor([[[10, 11, 12, 13, 14],
         [ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],

        [[20, 21, 22, 23, 24],
         [15, 16, 17, 18, 19],
         [25, 26, 27, 28, 29]]])
>>> # cross-check: clear Y and rebuild it with an explicit per-batch loop
>>> _ = Y.zero_()
>>> for i in range(a):
...     Y[i] = X[i][P[i]]
...
>>> Y
tensor([[[10, 11, 12, 13, 14],
         [ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],

        [[20, 21, 22, 23, 24],
         [15, 16, 17, 18, 19],
         [25, 26, 27, 28, 29]]])
Best.
K. Frank