Did torch.sort drastically improve in 1.10?


Did torch.sort on CUDA for large arrays drastically improve from version 1.9 to 1.10? The release notes of 1.10 don’t mention any major improvement to sorting.

I ask because the test below, which sorts 10^8 floating-point numbers, reaches a throughput of 85 MKPS (million keys sorted per second) on torch 1.9 and 1585 MKPS on torch 1.10. That’s a massive difference. Note that I’m running the same script on the same GPU (GeForce RTX 3070) and changing only the PyTorch version.

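import time
import torch

n = 10 ** 8
x = torch.randn((n, ), device='cuda')

# CUDA kernels launch asynchronously, so wait for the input tensor
# to be fully created before starting the clock
torch.cuda.synchronize()

# Start time measurement
t = time.perf_counter()

output, _ = torch.sort(x)

# Wait for the sort to finish on the GPU before stopping the clock
torch.cuda.synchronize()

# End time measurement
sort_time = time.perf_counter() - t

print(f'Latency: {sort_time}')
print(f'Throughput: {round((n / sort_time) / 1e6, 3)} MKPS')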

A few comments:

  1. I suspected that lazy evaluation might be involved. To test this, I added print(torch.diff(output, 1).prod()) after the sort and before stopping the timer; that value can only be computed once the sorted output actually exists. The results remained roughly the same (1489 MKPS).

  2. I’ve also reproduced these results on an RTX 2080 Ti, with roughly the same values as on the 3070.

  3. I’ve also run the above script in a for loop and averaged the results (e.g., to avoid issues with warmup and noise).
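A warmup-and-average loop with explicit synchronization can be sketched as follows. This is a minimal sketch rather than the exact script used for the numbers above; `measure_throughput` and its parameters are names made up for illustration, and the demo sorts a plain Python list on the CPU so that it runs without a GPU.

```python
import time
from statistics import mean


def measure_throughput(fn, n_keys, sync=None, warmup=3, repeats=10):
    """Return the average throughput in MKPS (million keys per second).

    fn   -- zero-argument callable that performs one sort
    sync -- optional zero-argument callable that blocks until pending
            device work has finished (hypothetical hook; on CUDA this
            would be torch.cuda.synchronize)
    """
    if sync is None:
        sync = lambda: None  # nothing to wait for on a synchronous backend

    # Warmup runs are not timed.
    for _ in range(warmup):
        fn()
    sync()

    latencies = []
    for _ in range(repeats):
        sync()                       # drain earlier work before starting the clock
        start = time.perf_counter()
        fn()
        sync()                       # wait for the sort itself to finish
        latencies.append(time.perf_counter() - start)

    return (n_keys / mean(latencies)) / 1e6


# CPU stand-in data so the sketch runs anywhere.
data = [((i * 2654435761) % 1000003) / 1000003 for i in range(100_000)]
mkps = measure_throughput(lambda: sorted(data), len(data))
print(f'Throughput: {round(mkps, 3)} MKPS')
```

On a GPU, the same helper would be called with fn=lambda: torch.sort(x) and sync=torch.cuda.synchronize, so that neither the warmup nor any still-running kernel leaks into the timed region.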

Was the sorting mechanism indeed drastically improved, or am I missing something?

1.9 had a perf bug that made sorting 1-d tensors, or multi-d tensors with few sorted slices, very slow. This bug was fixed for 1.10 (along with a minor perf improvement brought by the new CUDA version).
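To make "few sorted slices" concrete: sorting along one dimension sorts every 1-d slice along that dimension independently, so the number of slices is the product of all the other dimensions. The sketch below only does that arithmetic; `num_sorted_slices` is a hypothetical helper for illustration, not a PyTorch function, and it assumes a shape with at least one dimension.

```python
from math import prod


def num_sorted_slices(shape, dim=-1):
    """Number of independent 1-d slices sorted when sorting `shape` along `dim`."""
    dim = dim % len(shape)  # normalize negative dims (torch.sort defaults to dim=-1)
    return prod(s for i, s in enumerate(shape) if i != dim)


for shape in [(10**8,), (4, 25_000_000), (1_000_000, 100)]:
    print(shape, '->', num_sorted_slices(shape), 'slice(s)')
```

The 1-d tensor from the question is a single slice of 10^8 elements, which is the kind of case described above as hitting the slow path in 1.9.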

Interesting. Can you point me to the pull request for the bug?

By the way, what sorting algorithm does PyTorch implement for this scale?

Here’s the fix: "Bring back old algorithm for sorting on small number of segments" by zasdfgbnm (Pull Request #64127, pytorch/pytorch on GitHub). It links to the issue describing the regression, and that issue links to the original PR that introduced it.
PyTorch calls into cub for almost all (and definitely all large) inputs.
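For context, cub’s device-wide sorting primitives are, as far as I know, radix sorts. The sketch below shows the general idea of a least-significant-digit radix sort on float32 keys in plain Python. It is a CPU illustration of the technique only, not cub’s or PyTorch’s code; the function names are made up, and NaN handling is ignored.

```python
import struct


def float_to_key(x):
    """Map a float32 to a uint32 whose unsigned order matches the float order."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    # Negative floats: flip all bits. Non-negative floats: set the sign bit.
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits | 0x80000000


def key_to_float(k):
    """Inverse of float_to_key."""
    bits = k ^ 0x80000000 if k & 0x80000000 else k ^ 0xFFFFFFFF
    return struct.unpack('<f', struct.pack('<I', bits))[0]


def radix_sort_floats(values, bits_per_pass=8):
    """LSD radix sort: one stable bucketing pass per digit, lowest digit first."""
    keys = [float_to_key(v) for v in values]
    mask = (1 << bits_per_pass) - 1
    for shift in range(0, 32, bits_per_pass):
        buckets = [[] for _ in range(mask + 1)]
        for k in keys:
            buckets[(k >> shift) & mask].append(k)
        keys = [k for bucket in buckets for k in bucket]  # stable concatenation
    return [key_to_float(k) for k in keys]


print(radix_sort_floats([3.5, -1.25, 0.0, -7.0, 2.0, 0.5]))
```

The number of passes depends only on the key width, not on the number of keys, which is why this family of algorithms suits very large inputs such as the 10^8-element case discussed here.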