# Shuffling the columns of a matrix independently

Hi, assume I have an `(m, n)` tensor, say `mat`.
I would like to have each of its `n` columns shuffled in a different manner.

A naive implementation, with `mat` being the `(m, n)` tensor, would be:

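```python
import torch

mat = torch.rand(4, 3)  # example (m, n) tensor
res = mat.clone()
for i in range(res.shape[1]):  # loop over the n columns
    ind = torch.randperm(res.shape[0], device=mat.device)  # fresh permutation of the m rows
    res[:, i] = res[ind, i]
```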

Can I do the same in a vectorized manner, without the `for` loop?


Hi!

One way would be to use advanced indexing and the stdlib function `random.shuffle`. I applied it to a Python list rather than to the result of `torch.arange`, because `shuffle` seems to clash with `torch`'s semantics under the hood.

```python
import torch
mat = torch.linspace(1, 16, 16).view(4, 4)
```
```
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
```
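```python
import random

col_idxs = list(range(mat.shape[1]))  # one index per column
random.shuffle(col_idxs)              # shuffle the list of column indices in place
mat = mat[:, torch.tensor(col_idxs)]  # advanced indexing reorders whole columns
```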
```
tensor([[ 2.,  1.,  4.,  1.],
        [ 6.,  5.,  8.,  5.],
        [10.,  9., 12.,  9.],
        [14., 13., 16., 13.]])
```
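The same whole-column shuffle can also be sketched without leaving torch, by drawing the column permutation with `torch.randperm` instead of shuffling a Python list. Like the list version, this moves entire columns, so every row receives the same reordering; it does not permute entries within a column.

```python
import torch

mat = torch.linspace(1, 16, 16).view(4, 4)

# One random permutation of the column indices, drawn directly in torch
perm = torch.randperm(mat.shape[1], device=mat.device)

# Advanced indexing moves whole columns; every row gets the same reordering
shuffled = mat[:, perm]
print(shuffled)
```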

Hope that helps!

Hi Antoine (and tymokvo)!

I can’t* think of a way to do this using built-in tensor
functions without a loop.

*) Well, actually I can, but with a cost in efficiency.

Try this (for `m = 4, n = 3`):

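```python
import torch
mat = torch.tensor ([[11.0, 12, 13],[21, 22, 23],[31, 32, 33],[41, 42, 43]])
# argsort of uniform noise: an independent random permutation for each column
ind = torch.rand (4, 3).argsort (dim = 0)
# scatter each column of mat into the positions given by its permutation
res = torch.zeros (4, 3).scatter_ (0, ind, mat)
```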
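The same sort trick can be sketched with `torch.gather` in place of `scatter_` (a swapped-in variant, not the code above). `gather` reads `res[i, j] = mat[ind[i, j], j]`, while `scatter_` writes `res[ind[i, j], j] = mat[i, j]`, i.e. it applies the inverse permutation. The inverse of a uniformly random permutation is itself uniformly random, so both give correctly distributed per-column shuffles, and both still pay the `m * log (m)` sorting cost per column.

```python
import torch

m, n = 4, 3
mat = torch.tensor([[11.0, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]])

# argsort of i.i.d. uniform noise along dim 0 yields an independent,
# uniformly random permutation of 0..m-1 in every column
ind = torch.rand(m, n).argsort(dim=0)

# res[i, j] = mat[ind[i, j], j]: each column is reordered by its own permutation
res = torch.gather(mat, 0, ind)
print(res)
```

Each column of `res` keeps its own values (11, 21, 31, 41 stay in the first column, and so on), only their order changes.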

Ideally, the cost of independently permuting all of the columns
should be `m * n`. (You have `n` columns and the
cost of randomly permuting a length-`m` column
is `m`.)

But the cost of sorting a length-`m` column
is `m * log (m)`, so my scheme has cost
`n * m * log (m)`.

The point is that I can’t figure out how to get
the randomly permuted columns of indices
without a loop or using the sort trick.
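One further option, not mentioned in the thread: the Fisher-Yates shuffle can be vectorized across columns, so that the Python loop runs over the `m` rows instead of the `n` columns. It still uses a loop, so it does not answer the question as asked, but its total work is `m * n` rather than `n * m * log (m)`. This is a sketch with a hypothetical helper name, and it has not been benchmarked against the other approaches here.

```python
import torch

def shuffle_columns_fisher_yates(mat):
    """Hypothetical helper: independently shuffle every column of a 2-D tensor.

    Runs Fisher-Yates on all n columns at once; the Python loop is over the m rows.
    """
    res = mat.clone()
    m, n = res.shape
    cols = torch.arange(n, device=res.device)
    for i in range(m - 1, 0, -1):
        # For each column c, draw j[c] uniformly from 0..i (inclusive)
        j = torch.randint(0, i + 1, (n,), device=res.device)
        # Swap res[i, c] with res[j[c], c] in every column simultaneously
        row_i = res[i].clone()
        res[i] = res[j, cols]
        res[j, cols] = row_i
    return res

mat = torch.linspace(1, 12, 12).view(4, 3)
print(shuffle_columns_fisher_yates(mat))
```

Because each column writes only to its own entries, the per-column swaps never collide, which is what lets one `randint` call serve all `n` columns.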

(Note that tymokvo’s approach is applying the same
random permutation to each of the rows. Antoine is
asking for distinct random permutations for (in his
case) each of the columns, as his loop-based solution
does. Also, for reasons I don’t understand – tymokvo’s
code looks right for what it does – the final result in
tymokvo’s post has a duplicated column, (1, 5, 9, 13),
and a missing column (3, 7, 11, 15).)
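One possible explanation for the duplicated column, offered only as a hypothesis: in CPython, `random.shuffle` swaps elements with `x[i], x[j] = x[j], x[i]`. If `x` is a tensor rather than a list, `x[j]` and `x[i]` are views that share storage, so by the time the second assignment runs, `x[i]` already holds the old `x[j]` and the swap degenerates into a copy. This would also match the remark that `shuffle` "seems to go against `torch`'s semantics". The sketch below only demonstrates the effect; it does not confirm how the posted output was produced.

```python
import random
import torch

random.seed(0)

# Shuffling a Python list gives a true permutation
idx_list = list(range(100))
random.shuffle(idx_list)

# Shuffling a tensor in place: the tuple swap reads views, not copies,
# so values get duplicated and others are lost
idx_tensor = torch.arange(100)
random.shuffle(idx_tensor)

print(len(set(idx_list)))             # always 100
print(len(set(idx_tensor.tolist())))  # typically far fewer than 100
```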

Have fun!

K. Frank


Thanks, both!

@KFrank, your solution does work, but indeed, on a `(1000, 10000)` problem, I get:

scatter: 1.084769 ms

and with the naive `for` loop:

naive: 0.394080 ms
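A sketch of how such a comparison could be reproduced with `torch.utils.benchmark.Timer`, which, per its documentation, handles details such as CUDA synchronization that are easy to miss with hand-rolled timing (on a GPU, unsynchronized timings can be misleading). The function names are hypothetical, the example size is deliberately small, and no results are claimed here; scale `mat` up to `(1000, 10000)` to match the post.

```python
import torch
import torch.utils.benchmark as benchmark

def shuffle_loop(mat):
    # Naive version from the question: one randperm per column
    res = mat.clone()
    for i in range(res.shape[1]):
        ind = torch.randperm(res.shape[0], device=mat.device)
        res[:, i] = res[ind, i]
    return res

def shuffle_argsort(mat):
    # Sort trick from the thread: argsort of uniform noise, one permutation per column
    ind = torch.rand(mat.shape, device=mat.device).argsort(dim=0)
    return torch.zeros_like(mat).scatter_(0, ind, mat)

def time_ms(fn, mat, number=10):
    # Mean wall time per call, in milliseconds
    t = benchmark.Timer(stmt="fn(mat)", globals={"fn": fn, "mat": mat})
    return t.timeit(number).mean * 1e3

mat = torch.rand(100, 200)  # small example size for a quick run
for fn in (shuffle_loop, shuffle_argsort):
    print(f"{fn.__name__}: {time_ms(fn, mat):.3f} ms")
```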