Shuffling every slice of a tensor using different permutations

I have a tensor X of shape (a, b, c) and a matrix of permutations (not a permutation matrix) P of shape (a, b), where each row of P is an output of torch.randperm(). I want to shuffle X in the following way:

for i in range(a):
    Y[i] = X[i][P[i]]
return Y

What’s the best way to achieve this? Thanks.

Hi Wasabi!

Use gather():

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> a = 2
>>> b = 3
>>> c = 5
>>> X = torch.arange (a * b * c).reshape (a, b, c)
>>> X
tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14]],

        [[15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]]])
>>> P = torch.randn (a, b).argsort (dim = 1)
>>> P
tensor([[2, 1, 0],
        [1, 0, 2]])
>>> Y = X.gather (1, P.unsqueeze (-1).repeat (1, 1, c))
>>> Y
tensor([[[10, 11, 12, 13, 14],
         [ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],

        [[20, 21, 22, 23, 24],
         [15, 16, 17, 18, 19],
         [25, 26, 27, 28, 29]]])
>>> _ = Y.zero_()
>>> for i in range(a):
...     Y[i] = X[i][P[i]]
...
>>> Y
tensor([[[10, 11, 12, 13, 14],
         [ 5,  6,  7,  8,  9],
         [ 0,  1,  2,  3,  4]],

        [[20, 21, 22, 23, 24],
         [15, 16, 17, 18, 19],
         [25, 26, 27, 28, 29]]])

Best.

K. Frank

Hi KFrank, thanks for your answer. Is there a way that I can use gather() without repeating P along a certain axis? Say that my X is a tensor with more axes (a, b, c, d, …); is there a way to naturally broadcast P?

Hi Wasabi!

No. Quoting from the documentation for torch.gather():

Note that input and index do not broadcast against each other.
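To illustrate what the quoted note means in practice (a small sketch, not from the original thread): an index of shape (a, b, 1) is legal for gather(), but it is not broadcast up to (a, b, c) — you simply get an (a, b, 1) result back, so the trailing dimension has to be expanded explicitly.

```python
import torch

X = torch.arange(2 * 3 * 5).reshape(2, 3, 5)
P = torch.randn(2, 3).argsort(dim=1)

# index shape (2, 3, 1): accepted, but NOT broadcast to (2, 3, 5)
out = X.gather(1, P.unsqueeze(-1))
print(out.shape)  # torch.Size([2, 3, 1]) -- only one element per row gathered
```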

(Also, I’m not aware of any way to do what you want with pytorch
indexing operations other than gather(), although I may have
overlooked something.)

Best.

K. Frank

I’m old enough to remember what we did before broadcasting, and that was to use .expand. Expand will create a view(!) with stride 0 for the expanded dimension, so it is much more efficient than repeat. (Personally, I don’t remember when I last saw a use of repeat that should not have been expand instead.)

Y = X.gather(1, P.unsqueeze(-1).expand(-1, -1, c))
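You can see the difference by comparing strides (a quick sketch): expand returns a view whose expanded dimension has stride 0, so all rows alias the same storage, while repeat materializes a full copy.

```python
import torch

t = torch.arange(3)
e = t.unsqueeze(0).expand(4, -1)  # view: no memory copied
r = t.unsqueeze(0).repeat(4, 1)   # new tensor: 4x the storage

print(e.stride())  # (0, 1): stride 0 means all 4 rows alias the same data
print(r.stride())  # (3, 1): each row has its own storage
```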

Gather is the most elegant way to do it by far, but there also is the option of explicitly indexing which might be the most immediate transcription of the loop you wrote:

i0 = torch.arange(a)[:, None, None].expand(-1, b, c)
i1 = P.unsqueeze(-1).expand(-1, -1, c)
i2 = torch.arange(c)[None, None, :].expand(a, b, -1)
print((X[i0, i1, i2] == Y).all())

(But I don’t know where the third dimension came from; you can leave that out to get closer to the initial message.)
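For the follow-up question about a tensor with more trailing axes, the same expand trick generalizes: give P a singleton dimension for every axis after the shuffled one, then expand it to X’s shape before gathering. A sketch for a hypothetical 4-d case (shapes chosen for illustration):

```python
import torch

a, b, c, d = 2, 3, 4, 5
X = torch.arange(a * b * c * d).reshape(a, b, c, d)
P = torch.randn(a, b).argsort(dim=1)  # one permutation of b per slice

# add a singleton for every axis after dim 1, then expand (no copy)
idx = P.reshape(a, b, 1, 1).expand(-1, -1, c, d)
Y = X.gather(1, idx)

# agrees with the explicit loop from the original question
for i in range(a):
    assert (Y[i] == X[i][P[i]]).all()
```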

Best regards

Thomas


Thanks! I guess it’s not that bad to expand axes after all.