Hi,
I am finding that although torch.jit.script makes CPU code about 8x faster, it makes GPU code about 5x slower, compared to the same code run uncompiled on each device. How do I fix this - am I using the JIT wrong somehow? Test code below.
import torch
import numpy as np

size = 1000000  # data size

def test(device, torchscript):
    def torch_call(x, mask, a, b):
        # some simulated workload...
        x[mask] = 1
        torch.bucketize(a, b)
    if torchscript:
        torch_call = torch.jit.script(torch_call)
    # create data
    x = torch.zeros(size)
    mask = torch.randint(2, (size,)) == 1
    a = torch.from_numpy(np.random.random(size))
    b = torch.from_numpy(np.linspace(0, 1, 1000))
    if device == "cuda":
        x = x.cuda()
        mask = mask.cuda()
        a = a.cuda()
        b = b.cuda()
    # time torch_call
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch_call(x, mask, a, b)  # warmup call
    start.record()
    torch_call(x, mask, a, b)
    end.record()
    torch.cuda.synchronize()
    print(f"{device} {type(torch_call)=} time={start.elapsed_time(end)}")

test("cuda", False)
test("cuda", True)
test("cpu", False)
test("cpu", True)
Results (times from start.elapsed_time(end) are in milliseconds):
cuda type(torch_call)=<class 'function'> time=0.16105599701404572
cuda type(torch_call)=<class 'torch.jit.ScriptFunction'> time=0.7593600153923035
cpu type(torch_call)=<class 'function'> time=83.50847625732422
cpu type(torch_call)=<class 'torch.jit.ScriptFunction'> time=10.096927642822266
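One thing I am unsure about: whether a single warmup call is enough for the scripted version, since as far as I understand the TorchScript profiling executor may keep optimizing over the first few invocations. A variant that gives it extra warmup runs might look like this (a sketch; n_warmup = 10 is an arbitrary choice of mine):

n_warmup = 10  # arbitrary; meant to let the profiling executor settle
for _ in range(n_warmup):
    torch_call(x, mask, a, b)
torch.cuda.synchronize()  # make sure warmup work has finished before timing
start.record()
torch_call(x, mask, a, b)
end.record()
torch.cuda.synchronize()
print(f"{device} time={start.elapsed_time(end)}")

Is that the right fix here, or is a slowdown like this expected for this kind of workload on the GPU?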