How to index by int8 tensor?

I recently learned that PyTorch does not support indexing by anything other than long (int64) or bool tensors. I have an application with multi-GB int8 indexing tensors that no longer fit into my GPU's 12 GB of memory once converted to int64.

For context, this is for an implementation of a volumetric 3D raytracer where I precompute the indices where every ray intersects the voxels and the intersection length.

By precomputing the indices and intersection lengths, the actual raytracing is extremely simple:

result = t.sum(volume3d[i_x, i_y, i_z] * lengths, axis=-1)

i_x, i_y and i_z are currently int64, but I could fit 8x as many rays if they were int8.
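To make the gather-and-sum concrete, here is a tiny self-contained illustration on the CPU (the sizes and values are made up for demonstration; today the index tensors must be long):

```python
import torch as t

# Toy stand-ins for the real data: a 3x3x3 volume whose voxel at (x, y, z)
# holds the value 9x + 3y + z.
volume3d = t.arange(27, dtype=t.float32).reshape(3, 3, 3)
# Two rays, each intersecting two voxels; indices must currently be long.
i_x = t.tensor([[0, 1], [2, 2]])
i_y = t.tensor([[0, 1], [0, 1]])
i_z = t.tensor([[0, 0], [1, 2]])
lengths = t.tensor([[0.5, 1.0], [2.0, 0.25]])

# Gather the voxel value for each intersection, weight by the intersection
# length, and sum along the ray.
result = t.sum(volume3d[i_x, i_y, i_z] * lengths, axis=-1)

# The same computation written out explicitly, one ray at a time:
expected = t.stack([
    sum(volume3d[x, y, z] * l for x, y, z, l in zip(ix, iy, iz, ls))
    for ix, iy, iz, ls in zip(i_x, i_y, i_z, lengths)
])
assert t.allclose(result, expected)
```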

Is there any nice way to achieve the above without resorting to writing CUDA code?
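One partial workaround I can sketch (the helper name `gather_chunked` and the chunk size are my own, not an existing API): keep the indices stored as int8 and upcast only a small chunk of rays to long at a time, so the full int64 copies never exist simultaneously. This trades extra kernel launches for memory:

```python
import torch as t

def gather_chunked(volume3d, i_x, i_y, i_z, lengths, chunk=4):
    """Hypothetical helper: upcast int8 indices to long one chunk of rays
    at a time, so the full int64 index tensors never need to exist."""
    out = t.empty(lengths.shape[:-1], dtype=volume3d.dtype,
                  device=volume3d.device)
    for start in range(0, lengths.shape[0], chunk):
        s = slice(start, start + chunk)
        out[s] = t.sum(
            volume3d[i_x[s].long(), i_y[s].long(), i_z[s].long()] * lengths[s],
            axis=-1,
        )
    return out

# CPU smoke test with small random data (int8 holds indices up to 127,
# which covers this 50-voxel-per-axis volume).
vol = t.rand(50, 50, 50)
ix, iy, iz = (t.randint(50, (8, 13), dtype=t.int8) for _ in range(3))
lens_ = t.rand(8, 13)
full = t.sum(vol[ix.long(), iy.long(), iz.long()] * lens_, axis=-1)
assert t.allclose(gather_chunked(vol, ix, iy, iz, lens_), full)
```

The peak index memory is then three int64 chunks instead of three full int64 tensors, at the cost of a Python-level loop over chunks.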

Bumping this. Here’s a complete example where I have implemented the above using Numba. Unfortunately my naïve Numba code seems much slower than PyTorch:

#!/usr/bin/env python3

from contexttimer import Timer
from numba import cuda, void, int64, float32
import torch as t


spec = {'device': 'cuda'}

# volume being raytraced
shape = 50
d = t.rand((shape, shape, shape), **spec)
# width of detector
num_pix = 512
# maximum number of intersection points of each ray (this is specific to a spherical coord. system)
num_points = 2 * d.shape[0] + 2 * d.shape[1] + d.shape[2]
# voxel indices where rays intersect (placeholder)
ind = t.randint(shape, (num_pix, num_pix, num_points, 3), **spec)
# intersection lengths of rays with voxels (placeholder)
lens = t.rand(num_pix, num_pix, num_points, **spec)

# ----- PyTorch Raytracing -----

r, e, a = ind.moveaxis(-1, 0)

with Timer(prefix='PyTorch'):
    # raytracing inner-product
    # look up voxel indices for each ray and multiply by intersection length, then sum
    result_torch = (d[r, e, a] * lens).sum(axis=-1)

# ----- Numba Raytracing -----

@cuda.jit(void(float32[:, :, :], int64[:, :, :], int64[:, :, :], int64[:, :, :], float32[:, :, :], float32[:, :]))
def raytrace(d, r, e, a, lens, result):
    """Unrolled version of PyTorch inner-product"""
    x, y = cuda.grid(2)
    if x < r.shape[0] and y < r.shape[1]:
        inner_product = 0.0  # float accumulator (avoid starting from an int)
        for i in range(r.shape[2]):
            r_ind = r[x, y, i]
            e_ind = e[x, y, i]
            a_ind = a[x, y, i]
            len_ = lens[x, y, i]
            inner_product += d[r_ind, e_ind, a_ind] * len_
        result[x, y] = inner_product

# copy arrays to GPU
d_c = cuda.to_device(d)
r_c = cuda.to_device(r)
e_c = cuda.to_device(e)
a_c = cuda.to_device(a)
lens_c = cuda.to_device(lens)
result_numba_c = cuda.to_device(t.empty((num_pix, num_pix)))

with Timer(prefix='Numba'):
    # use a single block for each ray
    raytrace[(num_pix, num_pix), (1, 1)](d_c, r_c, e_c, a_c, lens_c, result_numba_c)

result_numba = result_numba_c.copy_to_host()

# ----- Compare Numerical Result -----

print('PyTorch result:', float(result_torch.sum()))
print('Numba result:', result_numba.sum())

PyTorch took 0.001 seconds
Numba took 0.085 seconds
PyTorch result: 16390173.0
Numba result: 16390176.0
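Two caveats about this comparison (both are my own guesses): the `Timer` around the PyTorch call never calls `torch.cuda.synchronize()`, so CUDA's asynchronous execution means the 0.001 s may only measure the kernel launch; and the `(1, 1)` thread block gives the Numba kernel one thread per block, leaving nearly the whole warp idle. A denser launch configuration might close much of the gap (the 16×16 block size is an untuned guess):

```python
from math import ceil

def launch_dims(num_pix, tpb=16):
    """Blocks-per-grid and threads-per-block for a 2D kernel covering a
    num_pix x num_pix image with tpb x tpb threads per block."""
    blocks = ceil(num_pix / tpb)
    return (blocks, blocks), (tpb, tpb)

blocks_per_grid, threads_per_block = launch_dims(512)
# Hypothetical relaunch of the kernel above (requires a CUDA device):
# raytrace[blocks_per_grid, threads_per_block](d_c, r_c, e_c, a_c,
#                                              lens_c, result_numba_c)
```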