Shuffling the columns of a matrix independently

Hi, assume I have an (m, n) tensor, say mat.
I would like each of its n columns to be shuffled with a different random permutation.

A naive implementation, with mat being the (m, n) tensor, would be:

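import torch

res = mat.clone()  # mat is the (m, n) tensor
for i in range(res.shape[1]):
    # an independent random permutation of the m rows for column i
    ind = torch.randperm(res.shape[0], device=res.device)
    res[:, i] = res[ind, i]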

Can I do the same in a vectorized manner, without the for loop?


Hi!

One way would be to use advanced indexing together with the standard-library function random.shuffle. I shuffled a Python list rather than the result of torch.arange, because shuffle does not seem to work correctly on tensors under the hood.

>>> import random
>>> import torch
>>> mat = torch.linspace(1, 16, 16).view(4, 4)
>>> mat
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
>>> col_idxs = list(range(mat.shape[1]))
>>> random.shuffle(col_idxs)
>>> mat = mat[:, torch.tensor(col_idxs)]
>>> mat
tensor([[ 2.,  1.,  4.,  1.],
        [ 6.,  5.,  8.,  5.],
        [10.,  9., 12.,  9.],
        [14., 13., 16., 13.]])
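The same column shuffle can also be written without the standard library by drawing the permutation with torch.randperm, which keeps everything on mat's device. This is a minimal sketch; the names perm and shuffled are illustrative. Note that it reorders whole columns, applying one permutation to every row, so it does not give each column its own independent shuffle.

```python
import torch

mat = torch.linspace(1, 16, 16).view(4, 4)
# Draw one random permutation of the column indices on mat's device.
perm = torch.randperm(mat.shape[1], device=mat.device)
# Advanced indexing reorders whole columns: every row gets the same permutation.
shuffled = mat[:, perm]
print(shuffled)
```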

Hope that helps!

Hi Antoine (and tymokvo)!

I can’t* think of a way to do this using built-in tensor
functions without a loop.

*) Well, actually I can, but at a cost in efficiency.

Try this (for m = 4, n = 3):

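import torch
mat = torch.tensor ([[11.0, 12, 13],[21, 22, 23],[31, 32, 33],[41, 42, 43]])
# argsort of i.i.d. uniform noise gives an independent random permutation per column
ind = torch.rand (4, 3).argsort (dim = 0)
# scatter_ writes res[ind[i, j], j] = mat[i, j], moving each entry to its shuffled row
res = torch.zeros (4, 3).scatter_ (0, ind, mat)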
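The same sort trick can be written with gather instead of scatter_: gather reads res[i, j] = mat[ind[i, j], j], so it needs no zero-filled buffer, and since the inverse of a uniformly random permutation is also uniformly random, each column is still shuffled independently. A minimal sketch; the function name shuffle_columns_sort is hypothetical, and the noise is drawn on mat's device by assumption.

```python
import torch


def shuffle_columns_sort(mat: torch.Tensor) -> torch.Tensor:
    """Permute the entries within each column of a 2-D tensor independently.

    Sort trick: argsort of i.i.d. uniform noise along dim 0 yields a separate
    random permutation of the m row indices for every one of the n columns.
    """
    m, n = mat.shape
    ind = torch.rand(m, n, device=mat.device).argsort(dim=0)
    # gather reads res[i, j] = mat[ind[i, j], j]
    return mat.gather(0, ind)


mat = torch.tensor([[11.0, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]])
res = shuffle_columns_sort(mat)
print(res)
```

Sorting each column of res recovers the corresponding column of mat, which is a quick way to check that every column was only permuted.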

The computational time complexity of your task
should be m * n. (You have n columns and the
cost of randomly permuting a length-m column
is m.)

But the cost of sorting a length-m column
is m * log (m), so my scheme has cost
n * m * log (m).

The point is that I can’t figure out how to get
the randomly permuted columns of indices
without a loop or using the sort trick.
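One middle ground, offered as an editorial sketch rather than something proposed in this thread, is to run Fisher-Yates on all columns at once. It still has a Python loop, but over the m rows instead of the n columns, and the total work stays m * n with no sort. The function name shuffle_columns_fisher_yates is hypothetical, and whether it is faster in practice depends on the shapes and device; it has not been benchmarked here.

```python
import torch


def shuffle_columns_fisher_yates(mat: torch.Tensor) -> torch.Tensor:
    """Shuffle each column of a 2-D tensor with its own random permutation.

    Fisher-Yates vectorized across columns: the Python loop runs m - 1 times,
    and each iteration swaps exactly one element in every column.
    """
    m, n = mat.shape
    res = mat.clone()
    cols = torch.arange(n, device=mat.device)
    for i in range(m - 1, 0, -1):
        # For every column, pick a swap partner uniformly from rows 0..i.
        j = torch.randint(0, i + 1, (n,), device=mat.device)
        row_i = res[i].clone()   # copy, because res[i] is a view
        res[i] = res[j, cols]    # advanced indexing returns a copy
        res[j, cols] = row_i     # the (j[c], c) targets are distinct across columns
    return res


torch.manual_seed(0)
mat = torch.arange(20.0).view(5, 4)
res = shuffle_columns_fisher_yates(mat)
print(res)
```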

(Note that tymokvo’s approach applies the same
random permutation to each of the rows. Antoine is
asking for distinct random permutations for (in his
case) each of the columns, as his loop-based solution
does. Also, although tymokvo’s code looks right for
what it does, for reasons I don’t understand the final
result in tymokvo’s post has a duplicated column,
(1, 5, 9, 13), and a missing column (3, 7, 11, 15).)
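One plausible explanation for the duplicated column, assuming the posted output was produced by calling random.shuffle on a tensor (such as the result of torch.arange) rather than on a list: random.shuffle swaps elements with x[i], x[j] = x[j], x[i], and integer-indexing a tensor returns views into the same storage. The first assignment overwrites the value the second one then reads back, so one value is duplicated and the other is lost. The sketch below reproduces a single such swap.

```python
import torch

t = torch.arange(4)
# The swap statement random.shuffle uses internally.
# t[1] and t[0] on the right-hand side are views, not copies:
# t[0] first receives 1, then t[1] reads the updated t[0] and also gets 1.
t[0], t[1] = t[1], t[0]
print(t)  # tensor([1, 1, 2, 3])
```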

Have fun!

K. Frank


Thanks, both.

@KFrank, your solution does work, but indeed, on a (1000, 10000) problem, I get:

scatter: 1.084769 ms

and with the naive for-loop version:

naive: 0.394080 ms
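To compare the two approaches on equal footing, one option is torch.utils.benchmark.Timer, which, according to the PyTorch docs, runs warm-ups and synchronizes asynchronous CUDA work when necessary. Plain wall-clock timing of CUDA kernels can otherwise be misleading. This sketch uses a small CPU tensor and function names chosen for illustration (shuffle_naive, shuffle_scatter). It reports numbers only for the machine it runs on and makes no claim about which approach wins.

```python
import torch
from torch.utils import benchmark


def shuffle_naive(mat: torch.Tensor) -> torch.Tensor:
    # Per-column loop from the question.
    res = mat.clone()
    for i in range(res.shape[1]):
        ind = torch.randperm(res.shape[0], device=res.device)
        res[:, i] = res[ind, i]
    return res


def shuffle_scatter(mat: torch.Tensor) -> torch.Tensor:
    # Sort trick: argsort of uniform noise, then scatter entries into place.
    ind = torch.rand(mat.shape, device=mat.device).argsort(dim=0)
    return torch.zeros_like(mat).scatter_(0, ind, mat)


# Small illustrative shape so the sketch runs quickly; not the (1000, 10000) case.
mat = torch.rand(100, 200)
timings = {}
for name, fn in [("naive", shuffle_naive), ("scatter", shuffle_scatter)]:
    timer = benchmark.Timer(stmt="fn(mat)", globals={"fn": fn, "mat": mat})
    measurement = timer.timeit(number=10)
    timings[name] = measurement.mean * 1e3  # seconds -> milliseconds
    print(f"{name}: {timings[name]:.3f} ms")
```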