Put a tensor into a gathered tensor (inverse permutation)

>>> import torch
>>> x = torch.arange(0, 3).view(1,3,1).expand(2,3,4) + 1
>>> permutation_indexes = torch.tensor([
...     [[0, 0, 1, 1],
...      [1, 2, 0, 2],
...      [2, 1, 2, 0]],
...     [[2, 2, 0, 0],
...      [0, 1, 1, 2],
...      [1, 0, 2, 1]]
... ])
>>> torch.all((permutation_indexes + 1) == x.gather(dim=1, index=permutation_indexes))
tensor(True)
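To spell out what `gather` does here, the following slow, loop-based sketch (for illustration only, not something to use in practice) rebuilds the same result element by element. It relies only on the documented indexing rule for dim=1, out[i][j][k] = x[i][index[i][j][k]][k]; the names `gathered` and `expected` are introduced just for this example.

```python
import torch

# Same tensors as in the snippet above.
x = torch.arange(0, 3).view(1, 3, 1).expand(2, 3, 4) + 1
permutation_indexes = torch.tensor([
    [[0, 0, 1, 1],
     [1, 2, 0, 2],
     [2, 1, 2, 0]],
    [[2, 2, 0, 0],
     [0, 1, 1, 2],
     [1, 0, 2, 1]],
])

# Vectorized gather along dim 1.
gathered = x.gather(dim=1, index=permutation_indexes)

# Explicit loop implementing out[i][j][k] = x[i][index[i][j][k]][k].
expected = torch.empty_like(permutation_indexes)
for i in range(permutation_indexes.size(0)):
    for j in range(permutation_indexes.size(1)):
        for k in range(permutation_indexes.size(2)):
            expected[i, j, k] = x[i, permutation_indexes[i, j, k].item(), k]

print(torch.equal(gathered, expected))  # True
```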

As you can see in the code snippet above, we can gather elements from a tensor ‘x’ with an index tensor ‘permutation_indexes’. However, I also need to be able to put another tensor ‘y’ back into the locations that were gathered:

>>> y = torch.arange(0, 3).view(1,3,1).expand_as(x)
>>> print(y)
tensor([[[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]],
        [[0, 0, 0, 0],
         [1, 1, 1, 1],
         [2, 2, 2, 2]]])
>>> inverse_permutation_indexes = x.gather_put(dim=1, index=permutation_indexes, tensor=y)
>>> print(inverse_permutation_indexes)
tensor([[[0, 0, 1, 2],
         [1, 2, 0, 0],
         [2, 1, 2, 1]],
        [[1, 2, 0, 0],
         [2, 1, 1, 2],
         [0, 0, 2, 1]]])
>>> torch.all(x == x.gather(dim=1, index=permutation_indexes).gather(dim=1, index=inverse_permutation_indexes))
tensor(True)

Of course, PyTorch has no function named ‘gather_put’; it is only used above to describe what I want. How can I do this without looping over the other dimensions (dimensions 0 and 2)?
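For reference, here is a minimal loop-based sketch of the wished-for operation along dim=1, i.e. exactly the "for loops over dimensions 0 and 2" version the question wants to avoid. The function name `gather_put_reference` is hypothetical (it is not a PyTorch API); it writes out[i][index[i][j][k]][k] = src[i][j][k], so each value of `src` lands at the position along dim 1 that `gather` read from.

```python
import torch


def gather_put_reference(index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based `gather_put` along dim=1 (illustration only)."""
    # Plain contiguous output (src may be an expanded, stride-0 view).
    out = torch.empty(index.shape, dtype=src.dtype)
    n0, _, n2 = index.shape
    for i in range(n0):          # loop over dimension 0 ...
        for k in range(n2):      # ... and over dimension 2
            # Put src[i, :, k] at the rows of dim 1 that gather read from.
            out[i, index[i, :, k], k] = src[i, :, k]
    return out


x = torch.arange(0, 3).view(1, 3, 1).expand(2, 3, 4) + 1
permutation_indexes = torch.tensor([
    [[0, 0, 1, 1], [1, 2, 0, 2], [2, 1, 2, 0]],
    [[2, 2, 0, 0], [0, 1, 1, 2], [1, 0, 2, 1]],
])
y = torch.arange(0, 3).view(1, 3, 1).expand_as(x)

inverse_permutation_indexes = gather_put_reference(permutation_indexes, y)
roundtrip = x.gather(dim=1, index=permutation_indexes).gather(
    dim=1, index=inverse_permutation_indexes)
print(torch.all(x == roundtrip))  # tensor(True)
```

Such a reference is mainly useful for checking a vectorized version against it on small inputs.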

NOTE: This question sprang from implementing the inverse permutation idea. If you follow the link, the idea is pretty simple for a 1-dimensional tensor (vector or list). But it is nearly impossible to do for a tensor of arbitrary dimensionality (2D, 3D and higher) without using serialized (for-loop) code.
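The link itself is not reproduced here, so the following is only a sketch of the common 1-D trick, assuming that is the idea being referred to: for a permutation `perm` of 0..n-1, a single indexed assignment inv[perm] = arange(n) produces the inverse. The values in `perm` are made up for illustration.

```python
import torch

# Example 1-D permutation of 0..3 (illustrative values).
perm = torch.tensor([2, 0, 3, 1])

# Inverse permutation: inv[perm[j]] = j for every j, in one assignment.
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.numel())

print(inv)        # tensor([1, 3, 0, 2])
print(perm[inv])  # tensor([0, 1, 2, 3])
```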

It seems rather bold to say that it’s nearly impossible to do without for loops.

For example, I can replace the line defining inverse_permutation_indexes with

inverse_permutation_indexes = torch.empty_like(permutation_indexes).scatter_(dim=1, index=permutation_indexes, src=y)

and the final check still evaluates to tensor(True). The same result can also be obtained with advanced indexing.
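A self-contained sketch of both vectorized variants, reusing the tensors from the question. The advanced-indexing version builds broadcastable index tensors for dimensions 0 and 2 (`i_idx`, `k_idx`, names chosen here for illustration) so that every entry of `permutation_indexes` is paired with its own (i, k) position. The `argsort` line is an extra cross-check that only applies to this special case, where y[i][j][k] == j and each slice along dim 1 is a true permutation (no ties).

```python
import torch

x = torch.arange(0, 3).view(1, 3, 1).expand(2, 3, 4) + 1
permutation_indexes = torch.tensor([
    [[0, 0, 1, 1], [1, 2, 0, 2], [2, 1, 2, 0]],
    [[2, 2, 0, 0], [0, 1, 1, 2], [1, 0, 2, 1]],
])
y = torch.arange(0, 3).view(1, 3, 1).expand_as(x)

# 1) scatter_ along dim 1: out[i][index[i][j][k]][k] = src[i][j][k].
inv_scatter = torch.empty_like(permutation_indexes).scatter_(
    dim=1, index=permutation_indexes, src=y)

# 2) Advanced indexing: i_idx (2,1,1) and k_idx (1,1,4) broadcast with
#    permutation_indexes (2,3,4), giving out[i, perm[i,j,k], k] = y[i,j,k].
n0, _, n2 = permutation_indexes.shape
i_idx = torch.arange(n0).view(n0, 1, 1)
k_idx = torch.arange(n2).view(1, 1, n2)
inv_index = torch.empty_like(permutation_indexes)
inv_index[i_idx, permutation_indexes, k_idx] = y

# 3) Special case only (y[i][j][k] == j): argsort of a permutation is its inverse.
inv_argsort = torch.argsort(permutation_indexes, dim=1)

print(torch.equal(inv_scatter, inv_index))    # True
print(torch.equal(inv_scatter, inv_argsort))  # True
roundtrip = x.gather(dim=1, index=permutation_indexes).gather(dim=1, index=inv_scatter)
print(torch.all(x == roundtrip))              # tensor(True)
```

Because each slice along dim 1 is a permutation, no position is written twice, so both the scatter_ and the indexed assignment are deterministic here.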

Best regards

Thomas


Please excuse my bold statement that the above operation is nearly impossible; it seemed so before I read your excellent answer. I had been looking for an answer non-stop until now. Thank you!