Bumping this. Here’s a complete example where I have implemented the above using Numba. Unfortunately my naïve Numba code seems much slower than PyTorch:

```
#!/usr/bin/env python3
import time
from contexttimer import Timer
from numba import cuda, void, int64, float32
import torch as t
t.manual_seed(0)
t.cuda.empty_cache()
spec = {'device': 'cuda'}
# volume being raytraced
shape = 50
d = t.rand((shape, shape, shape), **spec)
# width of detector
num_pix = 512
# maximum number of intersection points of each ray (this is specific to a spherical coord. system)
num_points = 2 * d.shape[0] + 2 * d.shape[1] + d.shape[2]
# voxel indices where rays intersect (placeholder)
ind = t.randint(shape, (num_pix, num_pix, num_points, 3), **spec)
# intersection lengths of rays with voxels (placeholder)
lens = t.rand(num_pix, num_pix, num_points, **spec)
# ----- PyTorch Raytracing -----
r, e, a = ind.moveaxis(-1, 0)
with Timer(prefix='PyTorch'):
# raytracing inner-product
# look up voxel indices for each ray and multiply by intersection length, then sum
result_torch = (d[r, e, a] * lens).sum(axis=-1)
# ----- Numba Raytracing -----
@cuda.jit(void(float32[:, :, :], int64[:, :, :], int64[:, :, :], int64[:, :, :], float32[:, :, :], float32[:, :]))
def raytrace(d, r, e, a, lens, result):
"""Unrolled version of PyTorch inner-product"""
x, y = cuda.grid(2)
if x < r.shape[0] and y < r.shape[1]:
inner_product = 0
for i in range(r.shape[2]):
r_ind = r[x, y, i]
e_ind = e[x, y, i]
a_ind = a[x, y, i]
len_ = lens[x, y, i]
inner_product += d[r_ind, e_ind, a_ind] * len_
result[x, y] = inner_product
# copy arrays to GPU
d_c = cuda.to_device(d)
r_c = cuda.to_device(r)
e_c = cuda.to_device(e)
a_c = cuda.to_device(a)
lens_c = cuda.to_device(lens)
result_numba_c = cuda.to_device(t.empty((num_pix, num_pix)))
with Timer(prefix='Numba'):
# use a single block for each ray
raytrace[(num_pix, num_pix), (1, 1)](d_c, r_c, e_c, a_c, lens_c, result_numba_c)
cuda.synchronize()
result_numba = result_numba_c.copy_to_host()
# ----- Compare Numerical Result -----
print('PyTorch result:', float(result_torch.sum()))
print('Numba result:', result_numba.sum())
```

```
PyTorch took 0.001 seconds
Numba took 0.085 seconds
PyTorch result: 16390173.0
Numba result: 16390176.0
```