I’ve been playing a bit with PyTorch's tensor-expression (NNC) API and noticed that it can be very slow on CUDA (I used PyTorch 1.12 with CUDA 11.3). For example, here I reimplemented some basic functionality of torch.gather
and benchmarked it.
import torch
import torch.utils.benchmark as benchmark
import torch._C._te as te
def construct_gather(n: int, backend: str, dtype=torch.float32):
    """Build an NNC kernel that mimics ``torch.gather(B, 0, INDEX, out=A)``
    for square ``(n, n)`` buffers.

    Args:
        n: side length of the square buffers.
        backend: codegen backend name forwarded to ``te.construct_codegen``
            (e.g. "llvm" or "cuda").
        dtype: element dtype of the data buffers A and B.

    Returns:
        A codegen object; invoke it as ``cg.call([A, B, INDEX])`` — output
        buffer first, matching the argument order passed to
        ``construct_codegen`` below.
    """
    # A is the output, B the source data, INDEX the per-element row indices.
    A = te.BufHandle("A", [n, n], dtype)
    INDEX = te.BufHandle("INDEX", [n, n], torch.long)
    B = te.BufHandle("B", [n, n], dtype)
    i = te.VarHandle("i", torch.int)
    j = te.VarHandle("j", torch.int)
    # A[i, j] = B[INDEX[i, j], j]  — gather along dim 0.
    # INDEX is addressed with a single flattened offset j + i*n.
    # NOTE(review): assumes row-major layout so that [i, j] == i*n + j;
    # confirm this matches the flattening prepare_for_codegen() applies.
    store = te.Store.make(A, [i, j], B.load([INDEX.load([j + i*n]), j])) # flatten index manually for now
    # Nested loops: i outer, j inner, covering the full (n, n) domain.
    for_j = te.For.make(j, te.ExprHandle.int(0), te.ExprHandle.int(n), store)
    for_i = te.For.make(i, te.ExprHandle.int(0), te.ExprHandle.int(n), for_j)
    loopnest = te.LoopNest(te.Block([for_i]), [A, B, INDEX])
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    # The buffer list here fixes the order the returned kernel's .call()
    # expects: [output, data, index].
    return te.construct_codegen(backend, stmt, [A, B, INDEX])
if __name__ == "__main__":
    # Benchmark torch.gather against the hand-built NNC kernel, first on
    # CPU (LLVM backend), then on GPU (CUDA backend).
    n = 2000
    torch.manual_seed(42)

    src = torch.randn((n, n))
    # Row indices in [0, n-2]; same shape as the data.
    idx = (torch.rand((n, n)) * (n - 2)).long()
    out = torch.zeros((n, n))

    print("Pytorch CPU")
    timer = benchmark.Timer("torch.gather(d, 0, index, out=to)", globals={"d": src, "index": idx, "to": out})
    print(timer.timeit(2))

    print("NNC/LLVM CPU")
    gather_llvm = construct_gather(n, "llvm")
    timer = benchmark.Timer("gather.call([to, d, index])", globals={"d": src, "index": idx, "to": out, "gather": gather_llvm})
    print(timer.timeit(2))

    # Move everything to the first CUDA device and repeat.
    src, idx, out = src.to(0), idx.to(0), out.to(0)

    print("Pytorch GPU")
    timer = benchmark.Timer("torch.gather(d, 0, index, out=to)", globals={"d": src, "index": idx, "to": out})
    print(timer.timeit(2))

    print("NNC/CUDA GPU")
    gather_cuda = construct_gather(n, "cuda")
    timer = benchmark.Timer("gather.call([to, d, index])", globals={"d": src, "index": idx, "to": out, "gather": gather_cuda})
    print(timer.timeit(2))
And I get an output like this:
Pytorch CPU
torch.gather(d, 0, index, out=to)
50.85 ms
1 measurement, 2 runs , 1 thread
NNC/LLVM CPU
gather.call([to, d, index])
21.75 ms
1 measurement, 2 runs , 1 thread
Pytorch GPU
torch.gather(d, 0, index, out=to)
10.67 ms
1 measurement, 2 runs , 1 thread
NNC/CUDA GPU
gather.call([to, d, index])
444.69 ms
1 measurement, 2 runs , 1 thread
Now, I wouldn’t expect my naive reimplementation to be faster than torch.gather,
but I also wouldn’t expect it to be an order of magnitude slower. I’ve played a bit with loopnest.tile
to optimise it, but nothing seemed to have a really noticeable effect. Any ideas what I might be missing?