Slow NNC cuda code for tensor expressions

I’ve been playing a bit with tensor expressions (NNC) and noticed that the generated code can be very slow on CUDA (I used PyTorch 1.12 with CUDA 11.3). For example, here I reimplemented some basic functionality of torch.gather and benchmarked it.

import torch
import torch.utils.benchmark as benchmark
import torch._C._te as te

def construct_gather(n: int, backend: str, dtype=torch.float32):
    """Build an NNC kernel computing A[i, j] = B[INDEX[i, j], j] for square
    n x n buffers — i.e. the core of ``torch.gather(B, 0, INDEX, out=A)``.

    Args:
        n: side length of the square output, source, and index buffers.
        backend: codegen backend name passed to ``te.construct_codegen``
            (e.g. "llvm" or "cuda").
        dtype: element dtype of the source/output buffers (indices are long).

    Returns:
        A codegen object; invoke it as ``cg.call([A, B, INDEX])`` — output
        tensor first, matching the buffer order given to construct_codegen.
    """
    # Kernel arguments: A = output, B = source values, INDEX = gather indices.
    A = te.BufHandle("A", [n, n], dtype)
    INDEX = te.BufHandle("INDEX", [n, n], torch.long)
    B = te.BufHandle("B", [n, n], dtype)
    # Loop variables for the two output dimensions.
    i = te.VarHandle("i", torch.int)
    j = te.VarHandle("j", torch.int)

    # A[i, j] = B[INDEX[i, j], j].  INDEX is read through a manually
    # flattened 1-D offset (j + i*n); presumably this matches the flat
    # layout prepare_for_codegen() lowers buffers to — TODO confirm.
    store = te.Store.make(A, [i, j], B.load([INDEX.load([j + i*n]), j])) # flatten index manually for now
    # Nest the per-element store inside j (inner) and i (outer) loops.
    for_j = te.For.make(j, te.ExprHandle.int(0), te.ExprHandle.int(n), store)
    for_i = te.For.make(i, te.ExprHandle.int(0), te.ExprHandle.int(n), for_j)

    # The buffer list declares the kernel's argument order for codegen.
    loopnest = te.LoopNest(te.Block([for_i]), [A, B, INDEX])

    # Lower to a codegen-ready form, then simplify the resulting statement.
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    return te.construct_codegen(backend, stmt, [A, B, INDEX])

if __name__ == "__main__":
    # Benchmark a hand-built NNC gather kernel against torch.gather,
    # on CPU (LLVM backend) and GPU (CUDA backend), 2 runs each.
    size = 2000
    torch.manual_seed(42)
    src = torch.randn((size, size))
    idx = (torch.rand((size, size)) * (size - 2)).long()
    out = torch.zeros((size, size))

    # --- CPU baselines ---
    print("Pytorch CPU")
    timer = benchmark.Timer("torch.gather(d, 0, index, out=to)", globals={"d": src, "index": idx, "to": out})
    print(timer.timeit(2))

    print("NNC/LLVM CPU")
    cpu_kernel = construct_gather(size, "llvm")
    timer = benchmark.Timer("gather.call([to, d, index])", globals={"d": src, "index": idx, "to": out, "gather": cpu_kernel})
    print(timer.timeit(2))

    # Move all tensors to the first CUDA device for the GPU runs.
    src = src.to(0)
    idx = idx.to(0)
    out = out.to(0)

    # --- GPU comparison ---
    print("Pytorch GPU")
    timer = benchmark.Timer("torch.gather(d, 0, index, out=to)", globals={"d": src, "index": idx, "to": out})
    print(timer.timeit(2))

    print("NNC/CUDA GPU")
    cuda_kernel = construct_gather(size, "cuda")
    timer = benchmark.Timer("gather.call([to, d, index])", globals={"d": src, "index": idx, "to": out, "gather": cuda_kernel})
    print(timer.timeit(2))

And I get an output like this:

Pytorch CPU
torch.gather(d, 0, index, out=to)
  50.85 ms
  1 measurement, 2 runs , 1 thread
  
NNC/LLVM CPU
gather.call([to, d, index])
  21.75 ms
  1 measurement, 2 runs , 1 thread
  
Pytorch GPU
torch.gather(d, 0, index, out=to)
  10.67 ms
  1 measurement, 2 runs , 1 thread
  
NNC/CUDA GPU
gather.call([to, d, index])
  444.69 ms
  1 measurement, 2 runs , 1 thread

Now, I wouldn’t expect my naive reimplementation to be faster than torch.gather, but should it really be an order of magnitude slower? I’ve played a bit with loopnest.tile to optimise it, but nothing seemed to have a really noticeable effect. Any ideas what I might be missing?