Hi everyone,
I created a small benchmark to compare the different options we have for a larger software project. In this benchmark I implemented the same algorithm in NumPy/CuPy, PyTorch, and native C++/CUDA. The benchmark script is attached below.
In all tests NumPy was significantly faster than PyTorch. Is there a reason for this, or am I using some PyTorch operations the wrong way?
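For context on my question: my suspicion is that a fixed per-call dispatch overhead dominates at these tensor sizes. Here is a small diagnostic sketch (my own addition, using torch.utils.benchmark, and not part of the benchmark itself) that times a single index_select at both sizes; if the per-call time barely changes with N, overhead rather than compute is the bottleneck:

# Hedged sketch: measure the fixed per-call cost of one op from the loop below.
import torch
from torch.utils.benchmark import Timer

for n in (500, 10000):
    x = torch.zeros(n)
    idx = torch.arange(n)
    t = Timer(stmt="x.index_select(0, idx)", globals={"x": x, "idx": idx})
    print(n, t.timeit(1000))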
For N=500 I got the following results:
- native C++/CUDA: 0.5s (CPU), 0.2s (GPU)
- NumPy: 2.3s (CPU)
- TorchScript: 8.8s (CPU), 30s (GPU)
- PyTorch: 11.7s (CPU), 36.3s (GPU)
- CuPy: 69.2s (GPU)
For N=10000 I got the following results:
- native C++/CUDA: 8s (CPU), 0.2s (GPU)
- NumPy: 24s (CPU)
- TorchScript: 30s (CPU), 31s (GPU)
- PyTorch: 35s (CPU), 37s (GPU)
- CuPy: 69.2s (GPU)
Note that C++ with libtorch gave results comparable to TorchScript.
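One caveat about the GPU timings: CUDA kernel launches are asynchronous, so a timed function should really synchronize before returning. A sketch of a wrapper that would account for this in the PyTorch cases (the synced helper is my own, hypothetical addition; the CuPy path would need its own synchronization):

import torch

# Hypothetical wrapper so GPU timings measure completed kernels, not just launches.
def synced(method, device, size):
    result = method(device, size)
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # block until all queued kernels finish
    return result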
Here is the benchmark in case anyone else is interested in reproducing the results:
import torch
from torch.types import Device
from typing import Callable
import pytest
import numpy

# Expected values for the first and last 10 entries of the result
RES_1 = [71, 28, 89, 54, 71.5, 97.5, 54.5, 43, 23.5, 71]
RES_2 = [71.5, 97.5, 54.5, 43, 23.5, 71, 28, 89, 54, 71.5]

DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
    DEVICES.append(torch.device("cuda"))


@pytest.fixture(params=DEVICES, ids=[t.type for t in DEVICES])
def device(request) -> Device:
    return request.param


@pytest.fixture(params=[500, 10000])
def size(request) -> int:
    return request.param
def numpy_cupy_implementation(
    device: Device, N: int = 500, T: int = 24 * 60 * 60
) -> numpy.ndarray:
    # Use CuPy as a drop-in NumPy replacement on the GPU
    if device.type == "cuda":
        import cupy as cunp
    else:
        import numpy as cunp

    indices = cunp.arange(N, dtype=cunp.int64)
    dummy = indices.copy()  # dummy tensor to have more operations
    location = cunp.zeros(N, dtype=cunp.float32)
    length = cunp.arange(N, dtype=cunp.float32)
    length %= 9
    length += 1
    sub_tensor = cunp.empty(N, dtype=cunp.float32)
    sub_tensor.fill(-0.5)
    for _ in range(T):
        location[indices] += length[dummy[dummy[indices]]]
        mask = location > 50
        location[dummy[indices[mask]]] += sub_tensor[mask]
        ind_mul = dummy[indices[location < 10]]
        location[ind_mul] = location[ind_mul] * 2
        location %= 100
    if device.type == "cuda":
        # Copy back to host memory so a numpy.ndarray is returned either way
        location = cunp.asnumpy(location)
    return location
def pytorch_implementation(
    device: Device, N: int = 500, T: int = 24 * 60 * 60
) -> torch.Tensor:
    indices = torch.arange(N, device=device, dtype=torch.int64)
    dummy = indices.clone()  # dummy tensor to have more operations
    location = torch.zeros(N, device=device, dtype=torch.float32)
    length = torch.arange(N, device=device, dtype=torch.float32).fmod_(9).add_(1)
    sub_tensor = torch.empty(N, device=device, dtype=torch.float32).fill_(-0.5)
    for _ in range(T):
        location.index_add_(
            0,
            indices,
            length.index_select(
                0, dummy.index_select(0, dummy.index_select(0, indices))
            ),
        )
        mask = location.gt(50)
        location.index_add_(
            0,
            dummy.index_select(0, indices.masked_select(mask)),
            sub_tensor.masked_select(mask),
        )
        ind_mul = dummy.index_select(0, indices.masked_select(location.lt(10)))
        location.index_copy_(0, ind_mul, location.index_select(0, ind_mul).mul_(2))
        location.fmod_(100)
    return location
@pytest.fixture(
    params=[
        pytorch_implementation,
        torch.jit.script(pytorch_implementation),
        numpy_cupy_implementation,
    ],
    ids=["torch", "TorchScript", "numpy/cupy"],
)
def method(request) -> Callable:
    return request.param
def test_pytorch_implementation_correctness(device):
    location = pytorch_implementation(device)
    numpy.testing.assert_array_equal(location[:10].cpu().numpy(), RES_1)
    numpy.testing.assert_array_equal(location[-10:].cpu().numpy(), RES_2)


def test_cupy_implementation_correctness(device):
    location = numpy_cupy_implementation(device)
    numpy.testing.assert_array_equal(location[:10], RES_1)
    numpy.testing.assert_array_equal(location[-10:], RES_2)
@pytest.mark.benchmark(
    min_time=0.1, max_time=0.5, min_rounds=1, disable_gc=True, warmup=False
)
def test_pytorch_implementation_performance(benchmark, device, method, size):
    benchmark(method, device, size)
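To run it yourself, save the script as e.g. benchmark.py; the benchmark fixture comes from the pytest-benchmark plugin, so that needs to be installed alongside torch and numpy, plus (for the GPU path) a CuPy build matching your CUDA toolkit:

pip install pytest pytest-benchmark
pytest benchmark.py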