Torch is slow compared to numpy

Hi everyone,

I created a small benchmark to compare different options we have for a larger software project. In this benchmark I implemented the same algorithm in numpy/cupy, pytorch and native cpp/cuda. The benchmark is attached below.

In all tests numpy was significantly faster than pytorch. Is there a reason for this, or am I using some of the pytorch operations the wrong way?

For N=500 I got the following results:

  1. native cpp/cuda 0.5s (CPU), 0.2s (GPU)
  2. numpy 2.3s (CPU)
  3. TorchScript 8.8s (CPU) / 30s (GPU)
  4. pytorch 11.7s (CPU) / 36.3s (GPU)
  5. cupy 69.2s (GPU)

For N=10000 I got the following results:

  1. native cpp/cuda 8s (CPU), 0.2s (GPU)
  2. numpy 24s (CPU)
  3. TorchScript 30s (CPU) / 31s (GPU)
  4. pytorch 35s (CPU) / 37s (GPU)
  5. cupy 69.2s (GPU)

Note that cpp+libtorch had results comparable to TorchScript.

Here is the benchmark in case anyone else is interested in reproducing the results:

from typing import Callable

import torch
from torch.types import Device

import pytest
import numpy

RES_1 = [71, 28, 89, 54, 71.5, 97.5, 54.5, 43, 23.5, 71]
RES_2 = [71.5, 97.5, 54.5, 43, 23.5, 71, 28, 89, 54, 71.5]

DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
    DEVICES.append(torch.device("cuda"))


@pytest.fixture(params=DEVICES, ids=[t.type for t in DEVICES])
def device(request) -> Device:
    return request.param


@pytest.fixture(params=[500, 10000])
def size(request) -> int:
    return request.param


def numpy_cupy_implementation(
    device: Device, N: int = 500, T: int = 24 * 60 * 60
) -> numpy.ndarray:

    if device.type == "cuda":
        import cupy as cunp
    else:
        import numpy as cunp

    indices = cunp.arange(N, dtype=cunp.int64)
    dummy = indices.copy()  # dummy tensor to have more operations
    location = cunp.zeros(N, dtype=cunp.float32)
    length = cunp.arange(N, dtype=cunp.float32)
    length %= 9
    length += 1
    sub_tensor = cunp.empty(N, dtype=cunp.float32)
    sub_tensor.fill(-0.5)
    for _ in range(T):
        location[indices] += length[dummy[dummy[indices]]]
        mask = location > 50
        location[dummy[indices[mask]]] += sub_tensor[mask]
        ind_mul = dummy[indices[location < 10]]
        location[ind_mul] = location[ind_mul] * 2

        location %= 100

    if device.type == "cuda":
        # cupy arrays do not implicitly convert to numpy; copy back to the host
        location = cunp.asnumpy(location)
    return location


def pytorch_implementation(
    device: Device, N: int = 500, T: int = 24 * 60 * 60
) -> torch.Tensor:
    indices = torch.arange(N, device=device, dtype=torch.int64)
    dummy = indices.clone()  # dummy tensor to have more operations
    location = torch.zeros(N, device=device, dtype=torch.float32)
    length = torch.arange(N, device=device, dtype=torch.float32).fmod_(9).add_(1)
    sub_tensor = torch.empty(N, device=device, dtype=torch.float32).fill_(-0.5)
    for _ in range(T):
        location.index_add_(
            0,
            indices,
            length.index_select(
                0, dummy.index_select(0, dummy.index_select(0, indices))
            ),
        )
        mask = location.gt(50)
        location.index_add_(
            0,
            dummy.index_select(0, indices.masked_select(mask)),
            sub_tensor.masked_select(mask),
        )
        ind_mul = dummy.index_select(0, indices.masked_select(location.lt(10)))
        location.index_copy_(0, ind_mul, location.index_select(0, ind_mul).mul_(2))

        location.fmod_(100)

    return location


@pytest.fixture(
    params=[
        pytorch_implementation,
        torch.jit.script(pytorch_implementation),
        numpy_cupy_implementation,
    ],
    ids=["torch", "TorchScript", "numpy/cupy"],
)
def method(request) -> Callable:
    return request.param


def test_pytorch_implementation_correctness(device):
    location = pytorch_implementation(device)
    numpy.testing.assert_array_equal(location[:10].cpu().numpy(), RES_1)
    numpy.testing.assert_array_equal(location[-10:].cpu().numpy(), RES_2)


def test_cupy_implementation_correctness(device):
    location = numpy_cupy_implementation(device)
    numpy.testing.assert_array_equal(location[:10], RES_1)
    numpy.testing.assert_array_equal(location[-10:], RES_2)


@pytest.mark.benchmark(
    min_time=0.1, max_time=0.5, min_rounds=1, disable_gc=True, warmup=False
)
def test_pytorch_implementation_performance(benchmark, device, method, size):
    benchmark(method, device, size)
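
Assuming the file is saved as test_benchmark.py (a name I picked; any pytest-collectable name works) and the pytest-benchmark plugin is installed, the performance tests can be run with:

pytest test_benchmark.py --benchmark-only

Dropping the --benchmark-only flag also runs the two correctness tests.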

Hi,

I am not sure how the pytest benchmark utilities work, but they are most likely not doing proper CUDA synchronization. So you most likely want to add the proper calls there (for pytorch, torch.cuda.synchronize(); not sure for cupy).
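
Something along these lines (an untested sketch, reusing your fixtures) should make sure all queued GPU work has finished before the timer stops:

def run_with_sync(method, device, size):
    out = method(device, size)
    if device.type == "cuda":
        # block until every kernel queued on the GPU has actually completed
        torch.cuda.synchronize()
    return out

def test_pytorch_implementation_performance(benchmark, device, method, size):
    benchmark(run_with_sync, method, device, size)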

Also, maybe you want to remove all the data initialization from the benchmark? Or do you actually care about that?
These allocations will add extra variance, as they need to perform CPU/GPU synchronization.

Also, code-wise, the problem is that your two implementations do different things, so that might explain the perf difference between the two.
Have you tried running the numpy code with pytorch directly (the contents of the for loop)? All these advanced indexing ops should work properly, afaik.
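
Concretely, the suggestion is to keep torch tensors but write the loop body exactly as in the numpy version; advanced indexing should accept it unchanged (untested sketch):

for _ in range(T):
    location[indices] += length[dummy[dummy[indices]]]
    mask = location > 50
    location[dummy[indices[mask]]] += sub_tensor[mask]
    ind_mul = dummy[indices[location < 10]]
    location[ind_mul] = location[ind_mul] * 2
    location %= 100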

Hi,

thank you for your answer!

Also, maybe you want to remove all the data initialization from the benchmark? Or do you actually care about that?

I somewhat care about this, but as we do 86,400 loops (T = 24 * 60 * 60) the initialization shouldn't have any measurable impact on the end result.

Also, code-wise, the problem is that your two implementations do different things, so that might explain the perf difference between the two.
Have you tried running the numpy code with pytorch directly (the contents of the for loop)? All these advanced indexing ops should work properly, afaik.

I think they do the same thing: I copied the pytorch code and replaced each operation with its numpy equivalent. I also added verification that the end results are the same for all methods (test_pytorch_implementation_correctness and test_cupy_implementation_correctness).

Copying the “numpy loop” into the pytorch implementation makes the results much worse (only tested on CPU):

  • TorchScript 15s (N=500)/ 77s(N=10000)
  • pytorch 24s (N=500) / 87s (N=10000)

This fits with my previous experience that calling the dedicated pytorch functions is a lot faster than going through the generic Python indexing operations.

I am not sure how the pytest benchmark utilities work, but they are most likely not doing proper CUDA synchronization. So you most likely want to add the proper calls there (for pytorch, torch.cuda.synchronize(); not sure for cupy).

I will check this, thanks!

Thanks for the details.

I don’t have a full answer but the remarks I would have are:

  • These indexing ops do not get as much love in PyTorch as core NN code (like linear/conv) and so we might be able to optimize them. If you find one in particular that is especially lacking, you can open an issue and we should be able to improve that.
  • For small sizes, we know that PyTorch has much more “framework overhead” than numpy. This is mostly due to the fact that we have extra features (like autograd, support for different device backends, etc.) that make the path down to the low-level code a bit longer (on the order of microseconds per call). This is not visible for large ops, but very small ops in a loop will suffer heavily from it (see the timing sketch below).
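
As a quick illustration of that per-call overhead (numbers will vary by machine), you can time a tiny elementwise op in both libraries; at this size the actual work is negligible, so the measurement is almost pure framework overhead:

import timeit

import numpy
import torch

a_np = numpy.zeros(10, dtype=numpy.float32)
a_t = torch.zeros(10, dtype=torch.float32)

# same tiny op, 100k calls each; the gap is dominated by per-call overhead
print("numpy:", timeit.timeit(lambda: a_np + 1, number=100_000))
print("torch:", timeit.timeit(lambda: a_t + 1, number=100_000))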

Thank you for your answer!

I expected some overhead, but this was a bit more than I thought.

  • These indexing ops do not get as much love in PyTorch as core NN code (like linear/conv) and so we might be able to optimize them. If you find one in particular that is especially lacking, you can open an issue and we should be able to improve that.

Just to confirm: I will check individual operations and then open issues for them?

Yes, if you find individual ops that are slower than they should be, you should open an issue for each op with the corresponding benchmark code and results.
That way, we can discuss what should be done for each op independently of the others.
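
For a single op, a self-contained snippet based on torch.utils.benchmark is one convenient format for such an issue (an illustrative sketch; plain timeit works just as well):

import torch
from torch.utils.benchmark import Timer

# sizes chosen to match the N=500 case from the benchmark above
x = torch.arange(500, dtype=torch.float32)
idx = torch.arange(500)

# measures just the op under discussion, isolated from the rest of the loop
timer = Timer(stmt="x.index_select(0, idx)", globals={"x": x, "idx": idx})
print(timer.timeit(10_000))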