Linear-time algorithms that are traditionally used to shuffle data on CPUs, such as the Fisher-Yates method, may not be well suited to GPUs because of their inherent sequential dependencies. Does the torch.randperm implementation on NVIDIA GPUs offer better performance than Fisher-Yates (e.g., better than O(n), so that time increases less than linearly with length)?
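While I track down the documentation, one rough empirical check is to time torch.randperm directly. A minimal sketch, assuming a CUDA device is available (CUDA launches are asynchronous, so you have to synchronize before reading the clock):

import time
import torch

# Time torch.randperm on the GPU for increasing n.
for n in [10**5, 10**6, 10**7, 10**8]:  # n = 10**8 allocates ~800 MB of int64
    torch.randperm(n, device="cuda")    # warm-up: allocator and kernel caching
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    torch.randperm(n, device="cuda")
    torch.cuda.synchronize()
    print(f"n={n:>9}: {time.perf_counter() - t0:.5f} s")

Of course, this only measures wall time on one machine; it doesn't by itself separate the algorithm's asymptotic cost from memory bandwidth and launch overhead, which is why I'd like to know the actual algorithm.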
I've done some grepping (grep -R "Algorithm of randperm") in the source, and I found two occurrences in Randperm.cuh:
../include/ATen/native/cuda/Randperm.cuh:// See note [Algorithm of randperm]
../include/ATen/native/cuda/Randperm.cuh:// See note [Algorithm of randperm]
I can’t understand the code, and I am not certain of the meaning of
See note [Algorithm of randperm]
Does it mean one of these?
1. Somewhere else, the "Algorithm of randperm" is explained in a note.
2. The code immediately following is the implementation of the "Algorithm of randperm", and serves as the "read the code" documentation for the algorithm.
Hopefully, the meaning is #1. If so, where would that note be?
Update: I found it in the PyTorch GitHub repo, in the file Randperm.cu, which is not in the distribution.
Here is the note, verbatim:
// [Algorithm of randperm]
//
// randperm is implemented by sorting an arange tensor of size n with randomly
// generated keys. When random keys are different from each other, all different
// permutations have the same probability.
//
// However, there is a pitfall here:
// For better performance, these N random keys are generated independently,
// and there is no effort to make sure they are different at the time of generation.
// When two keys are identical, stable sorting algorithms will not permute these two keys.
// As a result, (0, 1) will appear more often than (1, 0).
//
// To overcome this pitfall we first carefully choose the number of bits in these keys,
// so that the probability of having duplicate keys is under a threshold. Let q be the
// threshold probability for having non-duplicate keys, then it can be proved that[1]
// the number of bits required is: ceil(log2(n - (6 n^2 + 1) / (12 log(q))))
//
// Then after sort, we launch a separate kernel that additionally shuffles any islands
// of values whose keys matched. The algorithm of this kernel is as follows:
// Each thread reads its key and the keys of its neighbors to tell if it's part of an island.
// For each island, the first thread in the island sees a key match at index i+1 but not index i-1.
// This thread considers itself the "island leader". The island leader then reads more indices to
// the right to figure out how big the island is. Most likely, the island will be very small,
// just a few values. The island leader then rolls that many RNG, uses them to additionally
// shuffle values within the island using serial Fisher-Yates, and writes them out.
//
// Reference
// [1] https://osf.io/af2hy/
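To check my reading of the note, here is a small sketch of the idea in PyTorch itself. This is only an illustration of the technique the note describes, not the actual CUDA kernel: it sorts arange(n) by independently generated random keys, with the key width chosen by the note's formula, and it omits the tie-breaking "island" shuffle:

import math
import torch

# Key width from the note: enough bits that the probability of all n keys
# being distinct is at least q.
def required_key_bits(n, q=0.9):
    return math.ceil(math.log2(n - (6 * n**2 + 1) / (12 * math.log(q))))

n = 8
bits = required_key_bits(n)              # 9 bits for n=8, q=0.9
keys = torch.randint(0, 2**bits, (n,))  # keys generated independently, so duplicates are possible
_, perm = torch.sort(keys, stable=True)  # a stable sort alone would bias (0, 1) over (1, 0) on ties
print(perm)                              # the indices that sort the keys = a random permutation of arange(n)

With distinct keys, this gives every permutation equal probability, and the "island" kernel described above handles the residual ties. Since the heavy step is a sort over n keys, the total work is still at least linear in n, but it parallelizes in a way serial Fisher-Yates cannot.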
However, with that note alone I am still uncertain of the computation time. I'll have to read the reference, https://osf.io/af2hy/. If that explains it, I'll be able to answer my own question.