Batch version of torch.randperm

I have a batch of input vectors and I want to shuffle the entries of each vector before feeding them to a NN.
I noticed that torch.randperm doesn't have an option to generate multiple samples at once. In that case, which of the following is better, both in terms of computation time and in terms of possible drawbacks for learning?

My input is x with shape (batch_size, vec_size):

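# Option 1: a separate permutation for each sample, built in a loop
shuffled_indices = torch.empty(0, vec_size, dtype=torch.long)
for i in range(batch_size):
    shuffled_indices = torch.cat([shuffled_indices, torch.randperm(vec_size).unsqueeze(0)])
x = x[shuffled_indices]

# Option 2: the same permutation repeated for every sample
shuffled_indices = torch.randperm(vec_size).unsqueeze(0).repeat(batch_size, 1)
x = x[shuffled_indices]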

Notice that these are two different approaches: in the first I use a loop to generate a separate shuffled order for each sample, in the second all samples in the batch are shuffled in the same order. I'm trying to figure out whether shuffling the entire batch in one order would add an unwanted bias to the batch, and if so, whether that matters less than the time spent in the for loop of option 1.

P.S.
Also notice that I assign the new order by simply indexing the original tensor with the shuffled indices, but in PyTorch this doesn't give the desired outcome. So here is a bonus question: how do you use a multidimensional tensor as coordinates to pick certain elements from another multidimensional tensor?
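For the bonus question, here is a minimal sketch (not from the original thread; the sizes are illustrative and chosen square so that the naive indexing does not raise an index error). It shows why x[perms] misbehaves and two ways to treat perms as per-row coordinates: torch.gather along dim 1, and advanced indexing with a broadcast row index.

```python
import torch

batch_size, vec_size = 4, 4  # square so that x[perms] does not raise an index error
x = torch.arange(batch_size * vec_size, dtype=torch.float32).reshape(batch_size, vec_size)

# One independent permutation per row, shape (batch_size, vec_size)
perms = torch.stack([torch.randperm(vec_size) for _ in range(batch_size)])

# A single LongTensor index only indexes dim 0: each entry selects a whole ROW,
# so the result has shape (batch_size, vec_size, vec_size), not (batch_size, vec_size).
print(x[perms].shape)  # torch.Size([4, 4, 4])

# torch.gather treats perms as coordinates along dim 1:
# shuffled[i, j] = x[i, perms[i, j]]
shuffled = torch.gather(x, 1, perms)

# Same result with advanced indexing: a (batch_size, 1) row index broadcasts against perms.
rows = torch.arange(batch_size).unsqueeze(1)
print(torch.equal(shuffled, x[rows, perms]))  # True
```

torch.gather generalizes to tensors with more dimensions as long as the index tensor has the same number of dimensions as the source.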


You should never use that x = torch.cat([x, y]) pattern in a loop. It does O(n^2) copying, and the cost is noticeable in practice. Instead, you can preallocate the result with torch.empty and then fill each row using torch.randperm with the out= argument.
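A minimal sketch of the preallocation suggestion, assuming a (batch_size, n) LongTensor of indices is wanted; the helper name batch_randperm_loop is hypothetical.

```python
import torch

def batch_randperm_loop(batch_size: int, n: int) -> torch.Tensor:
    """Hypothetical helper: (batch_size, n) LongTensor whose rows are independent permutations of 0..n-1."""
    # Allocate the result once instead of growing it with torch.cat.
    perms = torch.empty(batch_size, n, dtype=torch.long)
    for i in range(batch_size):
        # Write each permutation directly into row i (a contiguous view of perms).
        torch.randperm(n, out=perms[i])
    return perms

perms = batch_randperm_loop(3, 5)
print(perms)
```

Each row is still produced by a separate randperm call, so the Python loop overhead remains, but the quadratic copying is gone.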

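An alternative that generates the whole batch in one go, which you might benchmark against the loop, is to generate a matrix of random values and sort it along one dimension: the argsort of each row is a random permutation. This isn't ideal in terms of complexity (sorting costs O(n log n) per row, compared with O(n) for a Fisher-Yates shuffle), but it should generally work very well.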
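A rough benchmarking sketch comparing the preallocated loop with the sort-based approach. Both helper names are hypothetical, the sizes are illustrative, and the timings depend entirely on your hardware and PyTorch build, so no numbers are claimed here.

```python
import timeit
import torch

def perms_loop(batch_size: int, n: int) -> torch.Tensor:
    # Hypothetical helper: preallocated result, one randperm call per row.
    out = torch.empty(batch_size, n, dtype=torch.long)
    for i in range(batch_size):
        torch.randperm(n, out=out[i])
    return out

def perms_argsort(batch_size: int, n: int) -> torch.Tensor:
    # Hypothetical helper: sort random keys along each row; argsort gives one permutation per row.
    return torch.argsort(torch.rand(batch_size, n), dim=1)

batch_size, n, reps = 256, 128, 100  # illustrative sizes
for fn in (perms_loop, perms_argsort):
    seconds = timeit.timeit(lambda: fn(batch_size, n), number=reps)
    print(f"{fn.__name__}: {seconds / reps * 1e3:.3f} ms per call")
```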

Of course, using a single permutation is faster, and if you have enough minibatches it might not matter much, but I would be very wary of spurious correlations (just as people found out that the interdependence between samples introduced by batch norm often does not matter much but in some cases turns out to be really bad).
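To make the concern concrete: with one shared permutation, every sample in the minibatch moves a given feature to the same position, so the layout is perfectly aligned across the batch; with independent permutations it is not. A small sketch checking this (the helper name positions_aligned and the sizes are hypothetical):

```python
import torch

def positions_aligned(perms: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each output position, True if every row took it from the same source index."""
    return (perms == perms[0]).all(dim=0)

batch_size, n = 64, 10  # illustrative sizes
shared = torch.randperm(n).unsqueeze(0).repeat(batch_size, 1)
independent = torch.argsort(torch.rand(batch_size, n), dim=1)

print(positions_aligned(shared))       # all True: identical layout across the batch
print(positions_aligned(independent))  # with these sizes, almost surely all False
```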

Best regards

Thomas

Hello

import torch 


x = torch.rand(3, 5)
print(x)

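# CASE 1: an independent permutation for every row
# argsort of random values gives one random permutation per row
indices = torch.argsort(torch.rand(*x.shape), dim=-1)
# The (3, 1) row index broadcasts against the (3, 5) column indices: result[i, j] = x[i, indices[i, j]]
result = x[torch.arange(x.shape[0]).unsqueeze(-1), indices]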
print(indices)
print(result)

# CASE 2: a single permutation shared by all rows (the columns are shuffled together)
indices = torch.randperm(x.shape[1])
result = x[:, indices]
print(indices)
print(result)
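The CASE 1 idea can be wrapped in a small reusable helper. This is a sketch, not from the thread: the name batch_shuffle is hypothetical, and it swaps the advanced indexing for torch.gather so it also works when x has more than two dimensions (only the last one is shuffled). The optional torch.Generator, which must live on the same device as x, makes the shuffle reproducible.

```python
from typing import Optional

import torch

def batch_shuffle(x: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Hypothetical helper: shuffle the last dimension of x independently for every row (CASE 1)."""
    # Random keys with the same shape as x; argsort along the last dim gives one permutation per row.
    indices = torch.argsort(torch.rand(x.shape, generator=generator, device=x.device), dim=-1)
    # gather along the last dim: out[..., j] = x[..., indices[..., j]]
    return torch.gather(x, -1, indices)

x = torch.rand(3, 5)
y1 = batch_shuffle(x, generator=torch.Generator().manual_seed(0))
y2 = batch_shuffle(x, generator=torch.Generator().manual_seed(0))
print(torch.equal(y1, y2))  # True: same seed, same shuffle
```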