torch.jit.script slows down GPU code

Hi

I am finding that while torch.jit.script makes my CPU code about 8x faster, it makes the same code about 5x slower on the GPU, compared to running the uncompiled function on each device.

How do I fix this? Am I using the JIT incorrectly somehow? Test code below.

import torch
import numpy as np

size = 1000000  # data size

def test(device, torchscript):
    def torch_call(x, mask, a, b):
        # some simulated workload: an in-place masked fill plus a bucketize
        x[mask] = 1
        _ = torch.bucketize(a, b)  # result deliberately discarded

    if torchscript:
        torch_call = torch.jit.script(torch_call)

    # create test data (a and b are float64, since NumPy defaults to it)
    x = torch.zeros(size)
    mask = torch.randint(2, (size,)) == 1
    a = torch.from_numpy(np.random.random(size))
    b = torch.from_numpy(np.linspace(0, 1, 1000))

    if device == "cuda":
        x = x.cuda()
        mask = mask.cuda()
        a = a.cuda()
        b = b.cuda()

    # time one call with CUDA events, after a single warmup call
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch_call(x, mask, a, b)  # warmup call
    start.record()
    torch_call(x, mask, a, b)
    end.record()
    torch.cuda.synchronize()  # wait for both events before reading the timer
    print(f"{device} {type(torch_call)=} time={start.elapsed_time(end)}")

test("cuda",False)
test("cuda",True)
test("cpu",False)
test("cpu",True)

Results (times in ms):
cuda type(torch_call)=<class 'function'> time=0.16105599701404572
cuda type(torch_call)=<class 'torch.jit.ScriptFunction'> time=0.7593600153923035
cpu type(torch_call)=<class 'function'> time=83.50847625732422
cpu type(torch_call)=<class 'torch.jit.ScriptFunction'> time=10.096927642822266
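
One thing I'm unsure about: I've read that the TorchScript profiling executor may need a few invocations before it switches to optimized code, so perhaps a single warmup call isn't representative. Below is an untested sketch that adds extra warmup calls and times with torch.utils.benchmark.Timer (which handles warmup and CUDA synchronization itself); the variable names are the ones from test() above. Would this be a fairer measurement?

import torch.utils.benchmark as benchmark

# extra calls, in case the profiling executor needs several
# invocations before it produces optimized code
for _ in range(5):
    torch_call(x, mask, a, b)

t = benchmark.Timer(
    stmt="torch_call(x, mask, a, b)",
    globals={"torch_call": torch_call, "x": x, "mask": mask, "a": a, "b": b},
)
print(t.timeit(100))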