Very slow torch.median() compared to CuPy

To my surprise, torch.median() is much slower than the equivalent cupy.median(): roughly 4.5x slower on 1000x1000 matrices and more than an order of magnitude slower from 5000x5000 upwards (see the timings below). The gap grows with the matrix size. This is all the more surprising because, unlike CuPy, PyTorch returns element N // 2 - 1 (0-based) of the sorted array as the median for arrays with an even number of entries, instead of averaging the two middle values, which if anything should make it faster. On the other hand, it seems to use less memory than CuPy (about 25% less for 10,000x10,000 matrices). Any insight into what might cause this?
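To illustrate the two conventions, here is a NumPy-only sketch that needs no GPU. `lower_median` is a hypothetical helper that mimics the behaviour described above for torch.median (the lower of the two middle values for even-length input), while `np.median` averages the middle pair, as the question says cupy.median does.

```python
import numpy as np

def lower_median(x: np.ndarray) -> float:
    """Hypothetical helper: element (n - 1) // 2 of the sorted, flattened array.

    For even n this is element n // 2 - 1 (0-based), i.e. the smaller of the
    two middle values; for odd n it is the usual middle element.
    """
    flat = np.sort(x, axis=None)  # axis=None flattens before sorting
    return float(flat[(flat.size - 1) // 2])

x = np.array([4.0, 1.0, 3.0, 2.0], dtype=np.float32)
print(lower_median(x))      # 2.0, the lower of the middle pair (2.0, 3.0)
print(float(np.median(x)))  # 2.5, the average of the middle pair
```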

The following code can be used to reproduce the results:

```python
import cupy as cp
import numpy as np
from timeit import default_timer as timer
import torch

dim = [100, 1000, 5000, 10000]
n_runs = 10
n_warmup = 2
n_tot = n_runs + n_warmup

device = torch.device('cuda')

# CUDA events used to time the PyTorch runs on the GPU
t_start = torch.cuda.Event(enable_timing=True)
t_end = torch.cuda.Event(enable_timing=True)

# PyTorch
for d in dim:
    t_pt = np.zeros(n_tot)
    for n in range(n_tot):
        np.random.seed(0)
        x = np.random.randn(d, d).astype(np.float32)
        x_pt = torch.from_numpy(x).to(device)
        torch.cuda.synchronize()
        t_start.record()
        torch.median(x_pt)
        t_end.record()
        torch.cuda.synchronize()
        t_pt[n] = 1e-3 * t_start.elapsed_time(t_end)  # elapsed_time() returns ms
    # slice with [n_warmup:] for both statistics so the warm-up runs are excluded
    t_mean, t_std = t_pt[n_warmup:].mean(), t_pt[n_warmup:].std()
    print(f'PyTorch: {d} x {d}: {t_mean:.5f} +- {t_std:.5f}')

# CuPy
for d in dim:
    t_cp = np.zeros(n_tot)
    for n in range(n_tot):
        np.random.seed(0)
        x = np.random.randn(d, d).astype(np.float32)
        x_cp = cp.asarray(x)
        cp.cuda.Stream.null.synchronize()
        t0 = timer()  # wall-clock start; separate name so the CUDA event isn't shadowed
        cp.median(x_cp)
        cp.cuda.Stream.null.synchronize()
        t_cp[n] = timer() - t0
    t_mean, t_std = t_cp[n_warmup:].mean(), t_cp[n_warmup:].std()
    print(f'CuPy: {d} x {d}: {t_mean:.5f} +- {t_std:.5f}')
```

On an RTX 2080 Ti with CuPy 8.5.0 for CUDA 10.0 and torch 1.8.0, this gives the following mean and standard deviation (in seconds) over 10 runs after the warm-up phase:

```
PyTorch: 100 x 100: 0.00016 +- 0.00000
PyTorch: 1000 x 1000: 0.00186 +- 0.00000
PyTorch: 5000 x 5000: 0.05195 +- 0.00000
PyTorch: 10000 x 10000: 0.22688 +- 0.00000
CuPy: 100 x 100: 0.00024 +- 0.00000
CuPy: 1000 x 1000: 0.00041 +- 0.00000
CuPy: 5000 x 5000: 0.00470 +- 0.00000
CuPy: 10000 x 10000: 0.01795 +- 0.00000
```

Disclaimer: I didn’t read the source code thoroughly; I just git grep-ed and googled some information, so what I say might be wrong.

PyTorch v1.8’s median calls sort internally.

On the other hand, CuPy v8.5’s median calls partition internally.
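As a CPU analogue of the partition approach, here is a NumPy sketch of a partition-based median. `median_via_partition` is a hypothetical helper, not CuPy's actual implementation. `np.partition` moves the element at each requested kth index into its sorted position, with smaller elements before it and larger or equal ones after, which is all a median needs; no full sort is required.

```python
import numpy as np

def median_via_partition(a: np.ndarray) -> float:
    """Hypothetical helper: median of all elements using np.partition."""
    flat = np.asarray(a).ravel()
    n = flat.size
    mid = n // 2
    if n % 2 == 1:
        # Odd n: one partition puts the middle element at index mid.
        return float(np.partition(flat, mid)[mid])
    # Even n: place both middle elements (indices mid - 1 and mid), then average.
    part = np.partition(flat, [mid - 1, mid])
    return float((part[mid - 1] + part[mid]) / 2)

print(median_via_partition(np.array([4.0, 1.0, 3.0, 2.0])))  # 2.5
```

Averaging the middle pair matches np.median; returning `part[mid - 1]` instead would give the lower-median convention discussed in the question.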

In NumPy, sorting takes O(n log n) time on average (O(n^2) in the worst case for plain quicksort), whereas partition takes O(n), according to https://stackoverflow.com/a/43589598/8335699. If the same holds for PyTorch and CuPy, that could explain the gap you found.
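To make the sort-versus-selection point concrete: a median only needs one element in its sorted position, and a selection algorithm can find it without ordering everything else. Below is textbook randomized quickselect in plain Python. It only illustrates the idea; it is not the code that NumPy (whose `np.partition` defaults to `kind='introselect'`), CuPy, or PyTorch actually runs.

```python
import random

def quickselect(values, k, rng=None):
    """Return the k-th smallest element (0-based) of `values`.

    Randomized quickselect: each round keeps only the side of a random
    pivot that contains position k, so the expected total work is O(n);
    an unlucky sequence of pivots degrades it to O(n^2).
    """
    if not 0 <= k < len(values):
        raise IndexError("k is out of range")
    if rng is None:
        rng = random.Random()
    items = list(values)
    while True:
        pivot = items[rng.randrange(len(items))]
        lows = [v for v in items if v < pivot]
        n_equal = sum(1 for v in items if v == pivot)
        if k < len(lows):
            items = lows  # the answer is among the smaller elements
        elif k < len(lows) + n_equal:
            return pivot  # the answer equals the pivot
        else:
            k -= len(lows) + n_equal  # skip the smaller and equal elements
            items = [v for v in items if v > pivot]

print(quickselect([7, 1, 5, 3, 9], 2))  # 5, the median of five values
```

A full sort, by contrast, has to put all n elements in order, which costs O(n log n) comparisons even though the median needs only one or two of them.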
