Very slow torch.median() compared to CuPy

Disclaimer: I didn't read the source code thoroughly; I just ran git grep and googled a bit, so what I say might be wrong.

PyTorch v1.8's median calls sort internally.

On the other hand, CuPy v8.5's median calls partition internally.

In NumPy, sort is O(n^2) in the worst case, while partition is O(n), according to https://stackoverflow.com/a/43589598/8335699. If the same holds for PyTorch and CuPy, that could explain the gap you found.
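To make the sort vs. partition difference concrete, here is a rough sketch in plain Python. It is not the PyTorch or CuPy implementation; the function names are made up for illustration, and the selection routine is a simple randomized quickselect (expected O(n), not the worst-case O(n) that a partition routine may guarantee). For even-length input it averages the two middle values, which is an assumption taken from NumPy's convention.

```python
import random


def median_by_sort(values):
    """Median via a full sort: every element ends up in its final position."""
    ordered = sorted(values)
    n = len(ordered)
    mid = n // 2
    if n % 2:
        return ordered[mid]
    return (ordered[mid - 1] + ordered[mid]) / 2


def select_kth(values, k):
    """Return the k-th smallest element (0-based) of a non-empty sequence.

    Randomized quickselect: after splitting around a pivot, only the part
    that contains position k is processed further, so the rest is never sorted.
    """
    items = list(values)
    while True:
        pivot = random.choice(items)
        smaller = [x for x in items if x < pivot]
        equal = [x for x in items if x == pivot]
        if k < len(smaller):
            items = smaller
        elif k < len(smaller) + len(equal):
            return pivot
        else:
            k -= len(smaller) + len(equal)
            items = [x for x in items if x > pivot]


def median_by_selection(values):
    """Median via selection: only the middle element(s) are located."""
    n = len(values)
    mid = n // 2
    if n % 2:
        return select_kth(values, mid)
    return (select_kth(values, mid - 1) + select_kth(values, mid)) / 2


if __name__ == "__main__":
    data = [random.random() for _ in range(10001)]
    # Both approaches agree on the result; they differ in how much work they do.
    print(median_by_sort(data) == median_by_selection(data))
```

Both functions return the same value; the second one just skips ordering the elements that cannot be the median, which is where the speed difference would come from if the libraries really differ in this way.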
