Faster quantile computation with libtorch

The implementation of torch::quantile does a full sort, which makes it significantly slower than, say, the implementation in NumPy, which uses a partial partition (np.partition / introselect).

I don’t see a partition equivalent in libtorch (C++)… am I missing something?

import time
import torch
import numpy as np


def bench(label, fn):
    """Time a single call to fn and print the elapsed seconds.

    Uses time.perf_counter(), the recommended monotonic high-resolution
    clock for benchmarking (time.time() is wall-clock time and can jump).
    """
    start = time.perf_counter()
    fn()
    elapsed = time.perf_counter() - start
    print(label, elapsed)


t = torch.rand(1_000_000)
a = t.numpy()  # zero-copy view of the same buffer, so both libs see identical data

bench("torch median", lambda: t.median())
bench("torch quant ", lambda: t.quantile(torch.tensor([0.25, 0.75])))
bench("numpy median", lambda: np.median(a))
bench("numpy quant ", lambda: np.quantile(a, [0.25, 0.75]))

Results

torch median 0.013309478759765625
torch quant  0.10049819946289062
numpy median 0.012769222259521484
numpy quant  0.014006376266479492

Using std::nth_element I can get a ~3x speedup over torch::quantile, but it is still behind NumPy's performance.

// Compute quantiles of a 1-D float32 tensor with std::nth_element instead of
// the full sort that torch::quantile performs.
//
// Matches the default semantics of torch::quantile / np.quantile: linear
// interpolation between the two closest order statistics when (n-1)*q is not
// an integer. (The original version truncated the index, which silently
// returned the lower order statistic and disagreed with both libraries.)
//
// t: 1-D, non-empty float32 tensor; its data is copied, t is not modified.
// q: 1-D float32 tensor of quantiles in [0, 1], sorted ascending.
// Returns a float32 tensor with the same shape as q.
// Throws std::invalid_argument / std::runtime_error on invalid input
// (explicit throws instead of assert(), which disappears under NDEBUG;
// comparing dtype().name() with == could also degenerate to a pointer
// comparison on some libtorch versions, so we compare scalar types).
torch::Tensor quantile(const torch::Tensor& t, const torch::Tensor& q) {
    if (t.scalar_type() != torch::kFloat || q.scalar_type() != torch::kFloat) {
        throw std::invalid_argument("quantile: t and q must be float32");
    }
    if (t.dim() != 1 || q.dim() != 1) {
        throw std::invalid_argument("quantile: t and q must be 1-D");
    }
    if (t.size(0) == 0) {
        throw std::invalid_argument("quantile: t must be non-empty");
    }
    if (!torch::equal(q, std::get<0>(q.sort()))) {
        throw std::runtime_error("quantiles q are not sorted");
    }

    // nth_element permutes the buffer in place, so work on a contiguous copy
    // (clone() alone may preserve a non-contiguous layout).
    auto tmp = t.clone().contiguous();
    auto res = torch::empty_like(q);

    const int64_t n = tmp.size(0);
    float* const data = tmp.data_ptr<float>();
    float* const end = data + n;

    // Because q is sorted ascending, after selecting the pivot for q[i] every
    // element left of the pivot is <= pivot, so the next selection can start
    // at the previous pivot instead of the beginning of the buffer.
    float* start = data;

    for (int64_t i = 0; i < q.size(0); i++) {
        const float qi = q[i].item<float>();
        if (qi < 0.0f || qi > 1.0f) {
            throw std::invalid_argument("quantile: q values must lie in [0, 1]");
        }

        // Fractional position in the sorted order (numpy's default "linear"
        // interpolation method). Computed in double to avoid float rounding
        // on large n.
        const double pos = static_cast<double>(n - 1) * static_cast<double>(qi);
        const int64_t lo = static_cast<int64_t>(pos);
        const double frac = pos - static_cast<double>(lo);

        float* m = data + lo;
        std::nth_element(start, m, end);
        const float lower = *m;

        if (frac > 0.0 && lo + 1 < n) {
            // After nth_element, everything in (m, end) is >= *m; the next
            // order statistic is therefore the minimum of that upper range.
            const float upper = *std::min_element(m + 1, end);
            res[i] = lower + static_cast<float>(frac * (upper - lower));
        } else {
            res[i] = lower;
        }

        start = m;
    }

    return res;
}