Batched shuffling of feature vectors

Hi, is there an idiomatic way to do a batched shuffle of feature vectors?

I want to generate K negative samples per batch of data. That is, given a tensor with shape (Batch x Features), I want to generate noise with shape (K x Batch x Features), where each noise vector is a shuffled version of the corresponding datum's features. For example, given a batch like [ [ a b c ] [ a d e ] [ b d e ] ], I want to generate two negative examples per datum like so:

 a b c     a c b | c b a
 a d e ->  a e d | d e a 
 b d e     b d e | e b d
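To pin down the shape contract, here is a small checker. It is a sketch that assumes the letters are encoded as integers (a=0, b=1, c=2, d=3, e=4); the helper name `check_negatives` is hypothetical. Sorting along the feature axis discards the ordering, so any shuffle of a row sorts to the same values as the row itself.

```python
import torch

def check_negatives(x: torch.Tensor, noise: torch.Tensor, k: int) -> bool:
    """Hypothetical helper: True if noise has shape (k, B, F) and every
    noise[i, b] is a permutation of the features in x[b]."""
    batch, feat = x.shape
    if noise.shape != (k, batch, feat):
        return False
    # Sorting each row removes the ordering, so a row and any of its
    # shuffles become identical after sorting.
    x_sorted = x.sort(dim=-1).values.unsqueeze(0).expand(k, -1, -1)
    noise_sorted = noise.sort(dim=-1).values
    return torch.equal(x_sorted, noise_sorted)

# The example above, with letters encoded as integers (assumption: a=0, b=1, ...).
x = torch.tensor([[0, 1, 2],   # a b c
                  [0, 3, 4],   # a d e
                  [1, 3, 4]])  # b d e
noise = torch.tensor([[[0, 2, 1], [0, 4, 3], [1, 3, 4]],   # a c b / a e d / b d e
                      [[2, 1, 0], [3, 4, 0], [4, 1, 3]]])  # c b a / d e a / e b d
print(check_negatives(x, noise, k=2))  # True
```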

I’ve tried doing this using randperm, but it is prohibitively slow, especially on the GPU, and requires a lot of bookkeeping code.

I’m not sure how exactly you are shuffling your input. Is there any specific logic, or are you just randomly sampling shuffle indices?
As far as I understand, randperm seems to work but is too slow?

I want to shuffle at the feature level, over a whole batch.
So given a datum like [a b c], I want to generate K shuffles of it, e.g. two shuffled versions: [ [ a c b ] [ c b a ] ], and do this for every datum in the batch.

The issue is that each call to randperm generates only one permutation, so I need to make (K x Batch) calls to randperm (which is very slow) and then index the shuffles out of the original data.
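For reference, the loop described above might look roughly like the sketch below (the function name `negatives_with_loop` is hypothetical, not from the thread). It issues K x Batch separate `torch.randperm` calls plus K x Batch small indexing operations; on a GPU each of these is a tiny kernel launch driven from Python, which is plausibly where most of the time goes.

```python
import torch

def negatives_with_loop(x: torch.Tensor, k: int) -> torch.Tensor:
    """Naive sketch: one torch.randperm call per (sample, datum) pair."""
    batch, feat = x.shape
    noise = torch.empty((k, batch, feat), dtype=x.dtype, device=x.device)
    for i in range(k):
        for b in range(batch):
            # A fresh permutation of the feature indices for this datum.
            perm = torch.randperm(feat, device=x.device)
            noise[i, b] = x[b, perm]
    return noise

x = torch.arange(12).reshape(3, 4)
noise = negatives_with_loop(x, k=2)
print(noise.shape)  # torch.Size([2, 3, 4])
```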

Now that I think about it, it wouldn’t be a problem if randperm could return multiple permutations at once.

I recently wanted to do the same thing. In case anyone is interested, the following is my solution:

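import torch

n_batch = 8
n_feat = 11

# One row of uniform random numbers per datum. Sorting each row and keeping
# the indices (argsort) turns it into an independent random permutation of
# 0..n_feat-1, so a single call produces n_batch permutations at once.
rand = torch.rand(n_batch, n_feat)
batch_rand_perm = rand.argsort(dim=1)
print(batch_rand_perm)
# Example output (values differ on every run):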
tensor([[ 1,  2,  4,  6,  0,  9,  3,  7,  8,  5, 10],
        [10,  6,  9,  2,  7,  4,  3,  5,  1,  0,  8],
        [ 8,  7,  4,  6,  0,  1,  9, 10,  2,  5,  3],
        [ 1,  3,  5,  4,  6,  8,  2,  9,  0,  7, 10],
        [ 3, 10,  7,  1,  4,  5,  0,  8,  9,  6,  2],
        [ 2, 10,  0,  1,  6,  8,  7,  9,  5,  3,  4],
        [ 5,  9,  8, 10,  0,  1,  6,  7,  4,  2,  3],
        [ 0,  9,  8,  4,  1,  3,  2,  7,  5, 10,  6]])
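Extending this to the original question, a sketch (the name `batched_shuffle` is hypothetical) draws a (K, Batch, Features) random tensor, argsorts it along the last axis, and uses `torch.gather` to pull features out of the data broadcast to K copies. Everything stays on `x.device`, so no per-row Python loop runs on the GPU. With float32 noise, ties are possible in principle but very rare, so each row is, for practical purposes, a uniformly random permutation.

```python
import torch

def batched_shuffle(x: torch.Tensor, k: int) -> torch.Tensor:
    """Return k independent feature-wise shuffles of every row of x.

    x has shape (B, F); the result has shape (k, B, F) and result[i, b]
    is a permutation of x[b].
    """
    batch, feat = x.shape
    # One row of uniform noise per (sample, datum); argsort along the
    # feature axis gives one random permutation per row.
    perms = torch.rand(k, batch, feat, device=x.device).argsort(dim=-1)
    # Broadcast x to (k, B, F) without copying, then pick features
    # along the last axis according to each permutation.
    return torch.gather(x.unsqueeze(0).expand(k, -1, -1), dim=-1, index=perms)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.arange(12.0, device=device).reshape(3, 4)
noise = batched_shuffle(x, k=5)
print(noise.shape)  # torch.Size([5, 3, 4])
```

The cost is one random tensor, one argsort and one gather, regardless of K or the batch size.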