Batched shuffling of feature vectors

Hi, is there an idiomatic way to do a batched shuffle of feature vectors?

I want to generate K negative samples per datum in a batch, i.e., given a tensor with shape (Batch x Features), I want to generate noise with the shape (K x Batch x Features), where the features in the noise are shuffled versions of the original data. For example, given a batch like [ [ a b c ] [ a d e ] [ b d e ] ], I want to generate two negative examples per datum like so:

 a b c     a c b | c b a
 a d e ->  a e d | d e a 
 b d e     b d e | e b d

I’ve tried doing this using randperm, but it is prohibitively slow, especially on the GPU, and requires a lot of logistics and scaffolding.


I’m not sure exactly how you are shuffling your input. Is there any specific logic, or are you just randomly sampling shuffle indices?
As far as I understand, randperm works but is too slow?

I want to generate shuffles at the feature level over a batch.
So given a datum like [ a b c ], I want to generate K shuffles of it, e.g. two shuffled versions: [ [ a c b ] [ c b a ] ], with this generalized over batched data.

The issue is that the call to randperm only generates one permutation, so I need to make (K x Batch) calls to randperm (which is very slow) and then index out the shuffles from the original data.
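
For reference, the loop described above might look roughly like the sketch below. The function name `shuffle_features_loop` and the toy input are made up for illustration; this is only a guess at the slow baseline, not the exact code from the question.

```python
import torch


def shuffle_features_loop(x: torch.Tensor, k: int) -> torch.Tensor:
    """Naive baseline: one torch.randperm call per (negative, row) pair."""
    n_batch, n_feat = x.shape
    out = torch.empty(k, n_batch, n_feat, dtype=x.dtype, device=x.device)
    for i in range(k):
        for j in range(n_batch):
            # randperm returns a single permutation of 0..n_feat-1
            perm = torch.randperm(n_feat, device=x.device)
            out[i, j] = x[j, perm]
    return out


# Toy input (hypothetical): 3 rows with 3 features each
x = torch.arange(9.0).reshape(3, 3)
neg = shuffle_features_loop(x, k=2)
print(neg.shape)  # torch.Size([2, 3, 3])
```

This makes K x Batch separate calls to randperm, and each call is a small separate operation, which is presumably where most of the time goes.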

Now that I think about it, it wouldn’t be a problem if randperm could return multiple permutations.

I recently wanted to do the same thing. In case anyone is interested, the following is my solution:

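import torch

n_batch = 8
n_feat = 11

# Sorting uniform random values row by row yields one independent random
# permutation of the feature indices per batch row.
rand = torch.rand(n_batch, n_feat)
batch_rand_perm = rand.argsort(dim=1)
print(batch_rand_perm)

Example output (values vary from run to run):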
tensor([[ 1,  2,  4,  6,  0,  9,  3,  7,  8,  5, 10],
        [10,  6,  9,  2,  7,  4,  3,  5,  1,  0,  8],
        [ 8,  7,  4,  6,  0,  1,  9, 10,  2,  5,  3],
        [ 1,  3,  5,  4,  6,  8,  2,  9,  0,  7, 10],
        [ 3, 10,  7,  1,  4,  5,  0,  8,  9,  6,  2],
        [ 2, 10,  0,  1,  6,  8,  7,  9,  5,  3,  4],
        [ 5,  9,  8, 10,  0,  1,  6,  7,  4,  2,  3],
        [ 0,  9,  8,  4,  1,  3,  2,  7,  5, 10,  6]])
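
To get from these indices to the (K x Batch x Features) negatives asked for in the original question, one option is to draw the random values with an extra leading K dimension and apply the permutations with torch.gather. The following is a minimal sketch under that assumption; the helper name `batched_feature_shuffle` and the toy input are hypothetical.

```python
import torch


def batched_feature_shuffle(x: torch.Tensor, k: int) -> torch.Tensor:
    """Return k independently feature-shuffled copies of x.

    x has shape (batch, features); the result has shape (k, batch, features).
    """
    n_batch, n_feat = x.shape
    # One row of uniform noise per (negative, datum) pair; argsort along the
    # last dimension turns each row into a random permutation of the indices.
    rand = torch.rand(k, n_batch, n_feat, device=x.device)
    perm = rand.argsort(dim=-1)
    # Broadcast x over the k dimension (a view, no copy) and reorder the
    # features of every row according to its own permutation.
    return x.unsqueeze(0).expand(k, -1, -1).gather(-1, perm)


# Toy input (hypothetical): 3 rows with 4 features each
x = torch.arange(12.0).reshape(3, 4)
neg = batched_feature_shuffle(x, k=2)
print(neg.shape)  # torch.Size([2, 3, 4])
```

Everything here is a single batched operation, so there is no Python loop over K or the batch. Note that a random permutation can occasionally reproduce the original order, so some "negatives" may be identical to their source row.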