>>> import torch
>>> x = torch.arange(0, 3).view(1,3,1).expand(2,3,4) + 1
>>> permutation_indexes = torch.tensor([
...     [[0, 0, 1, 1],
...      [1, 2, 0, 2],
...      [2, 1, 2, 0]],
...     [[2, 2, 0, 0],
...      [0, 1, 1, 2],
...      [1, 0, 2, 1]]
... ])
>>> torch.all((permutation_indexes + 1) == x.gather(dim=1, index=permutation_indexes))
tensor(True)
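To spell out what the `gather` call above computes, here is a minimal loop-based sketch of the indexing rule along dim=1, `out[b, i, k] = x[b, index[b, i, k], k]`. The name `reference` is only for illustration; the tensors are the same as in the session above.

```python
import torch

# Same tensors as in the session above.
x = torch.arange(0, 3).view(1, 3, 1).expand(2, 3, 4) + 1
permutation_indexes = torch.tensor([
    [[0, 0, 1, 1],
     [1, 2, 0, 2],
     [2, 1, 2, 0]],
    [[2, 2, 0, 0],
     [0, 1, 1, 2],
     [1, 0, 2, 1]],
])

# Loop-based reference for gather along dim=1:
#   out[b, i, k] = x[b, index[b, i, k], k]
reference = torch.empty_like(permutation_indexes)
for b in range(permutation_indexes.size(0)):
    for i in range(permutation_indexes.size(1)):
        for k in range(permutation_indexes.size(2)):
            source_row = permutation_indexes[b, i, k].item()
            reference[b, i, k] = x[b, source_row, k]

gathered = x.gather(dim=1, index=permutation_indexes)
print(torch.equal(reference, gathered))
```

Only the index along dim 1 is looked up; the positions along dimensions 0 and 2 are carried over unchanged.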
As shown above, we can gather elements from a tensor ‘x’ using an index tensor ‘permutation_indexes’. However, I also need the reverse operation: putting the values of another tensor ‘y’ into the locations that ‘gather’ reads from:
>>> y = torch.arange(0, 3).view(1,3,1).expand_as(x)
>>> print(y)
tensor([[[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]],
        [[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]]])
>>> inverse_permutation_indexes = x.gather_put(dim=1, index=permutation_indexes, tensor=y)
>>> print(inverse_permutation_indexes)
tensor([[[0, 0, 1, 2],
         [1, 2, 0, 0],
         [2, 1, 2, 1]],
        [[1, 2, 0, 0],
         [2, 1, 1, 2],
         [0, 0, 2, 1]]])
>>> torch.all(x == x.gather(dim=1, index=permutation_indexes).gather(dim=1, index=inverse_permutation_indexes))
tensor(True)
Of course, PyTorch has no function named ‘gather_put’; I made it up to show the behavior I want. How can I get this result without writing for loops over the other dimensions (dimensions 0 and 2)?
NOTE: This question came up while implementing the inverse permutation idea. If you follow the link, the idea is pretty simple for a 1-dimensional tensor (a vector or list). For a tensor of arbitrary dimensionality (2D, 3D and higher), though, I don't see how to do it without serialized (for-loop) code.
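For reference, here is a minimal sketch of one possible approach, not necessarily the only or best one. It assumes that the made-up ‘gather_put’ should write `y[b, i, k]` to position `[b, index[b, i, k], k]`, which is what `Tensor.scatter_` does along dim=1. The 1-dimensional case is shown first for comparison.

```python
import torch

# 1-D case: the inverse permutation satisfies inverse[perm[i]] = i.
perm = torch.tensor([2, 0, 1])
inverse = torch.empty_like(perm)
inverse[perm] = torch.arange(perm.numel())

# N-D case along dim=1, using the tensors from the question.
x = torch.arange(0, 3).view(1, 3, 1).expand(2, 3, 4) + 1
permutation_indexes = torch.tensor([
    [[0, 0, 1, 1],
     [1, 2, 0, 2],
     [2, 1, 2, 0]],
    [[2, 2, 0, 0],
     [0, 1, 1, 2],
     [1, 0, 2, 1]],
])
# y holds the row number i at every position [b, i, k];
# .contiguous() materializes the expanded view before scattering.
y = torch.arange(0, 3).view(1, 3, 1).expand_as(x).contiguous()

# scatter_ along dim=1 writes: out[b, index[b, i, k], k] = y[b, i, k]
inverse_permutation_indexes = torch.zeros_like(permutation_indexes).scatter_(
    1, permutation_indexes, y
)

# Gathering with the permutation and then with its inverse restores x.
round_trip = x.gather(dim=1, index=permutation_indexes).gather(
    dim=1, index=inverse_permutation_indexes
)
print(torch.equal(round_trip, x))
```

Because every slice of ‘permutation_indexes’ along dim 1 is a permutation, each target position is written exactly once, so the zero-initialized values are all overwritten. For index tensors that are permutations along the chosen dimension, `torch.argsort(permutation_indexes, dim=1)` should give the same result.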