Consider this code:
import torch
import time

torch.set_float32_matmul_precision('high')

@torch.compile
def product_loop_compiled(A, B, iter):
    Y = torch.zeros_like(A, device='cuda')
    for _ in range(iter):
        X = A @ B
        Y += X
    return Y

def product_loop_normal(A, B, iter):
    Y = torch.zeros_like(A, device='cuda')
    for _ in range(iter):
        X = A @ B
        Y += X
    return Y

def testitest(iter, compile=False, n=50):
    A = torch.randn(n, n, device='cuda')
    B = torch.randn(n, n, device='cuda')
    if compile:
        product_loop_compiled(A, B, 100)  # burn in
    else:
        Y = product_loop_normal(A, B, 100)  # burn in
    torch.cuda.synchronize()
    start = time.time()
    if compile:
        Y = product_loop_compiled(A, B, iter)
    else:
        Y = product_loop_normal(A, B, iter)
    torch.cuda.synchronize()
    duration = time.time() - start
    print(duration / iter * 1000, 'ms')  # average ms per iteration
    return Y

testitest(5000, compile=True)
testitest(5000, compile=False)
Running it gives:
40.336214542388916 ms
0.013336086273193359 ms
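For completeness, here is the quick sanity check I'd use to confirm the two versions actually compute the same result, so the timing comparison is apples to apples (a sketch, reusing the functions above; the tolerance is loose because 'high' matmul precision enables TF32 on this GPU):

A = torch.randn(50, 50, device='cuda')
B = torch.randn(50, 50, device='cuda')
Y1 = product_loop_compiled(A, B, 10)
Y2 = product_loop_normal(A, B, 10)
print(torch.allclose(Y1, Y2, atol=1e-4))  # should print True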
Comparing those timings, the compiled code is roughly 3000x slower per iteration. How is that possible?
I've heard that torch.compile can sometimes make code slower, but a factor of ~3000x makes me think I'm the one doing something wrong.
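To rule out a wall-clock measurement artifact, this is the cross-check I would run with CUDA events instead of time.time() (a sketch only; since the code above already calls torch.cuda.synchronize() around the timed region, I don't expect it to change the numbers):

def time_with_events(fn, A, B, iters):
    # CUDA events time the GPU work directly, independent of host-side overhead
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn(A, B, iters)
    end.record()
    torch.cuda.synchronize()  # make sure both events have been recorded
    print(start.elapsed_time(end) / iters, 'ms')  # elapsed_time() returns milliseconds

A = torch.randn(50, 50, device='cuda')
B = torch.randn(50, 50, device='cuda')
time_with_events(product_loop_compiled, A, B, 5000)
time_with_events(product_loop_normal, A, B, 5000)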
This is with an NVIDIA GeForce RTX 3060 (12 GB), Python 3.11.6, and torch 2.2.0+cu121.
What am I missing?
(Edit: torch 2.3.0 gives the same kind of timing discrepancy.)
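In case it's useful, this is the diagnostic I plan to try next, to see whether Dynamo is recompiling during the timed call (assuming the recompile-logging API works as documented; the same logging can reportedly be enabled by running the script with TORCH_LOGS="recompiles" set in the environment):

import torch

torch._logging.set_logs(recompiles=True)  # log each recompilation and its reason
testitest(5000, compile=True)  # a recompile triggered by iter changing from 100 to 5000 should show up here

I also plan to re-test with the burn-in call using the same iter value as the timed call, in case changing it from 100 to 5000 is what forces the extra work.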