torch.jit.script with v1.12 significantly slower on CUDA GPU than v1.11 for scientific computing (due to nvFuser?)

Hi,

According to the v1.12 changelog:
“In PyTorch 1.12, Torchscript is updating its default fuser (for Volta and later CUDA accelerators) to nvFuser, which supports a wider range of operations and is faster than NNC, the previous fuser for CUDA devices”

I’m using PyTorch for scientific computing (I solve PDEs numerically), not for deep learning, and I have noticed that since version 1.12, torch.jit.script is significantly slower on CUDA GPUs, while it gives the same performance on Intel CPUs for the computations I’m doing (example below).

I tested this on my laptop with an NVIDIA RTX A3000 Laptop GPU and on a server with an NVIDIA GeForce RTX 2080 Ti. On these two machines, PyTorch was installed with conda following the official PyTorch instructions.

Here is an example of the computations I do, with the script weno.py that implements fifth-order WENO reconstruction (roughly speaking, a non-linear 1D interpolation on a five-point stencil designed to prevent Gibbs oscillations):

import torch

# print('With torch.jit.script')
# @torch.jit.script
def weno5(qmm, qm, q0, qp, qpp):
    """
    Fifth-order WENO reconstruction, from:
        Efficient Implementation of Weighted ENO Schemes, Jiang and Shu,
        Journal of Computational Physics 126, 202–228 (1996)
    """
    eps = 1e-6
    qi1 = 1./3.*qmm - 7./6.*qm + 11./6.*q0
    qi2 = -1./6.*qm + 5./6.*q0 + 1./3.*qp
    qi3 = 1./3.*q0 + 5./6.*qp - 1./6.*qpp

    k1, k2 = 13./12., 0.25
    beta1 = k1 * (qmm-2*qm+q0)**2 + k2 * (qmm-4*qm+3*q0)**2
    beta2 = k1 * (qm-2*q0+qp)**2  + k2 * (qm-qp)**2
    beta3 = k1 * (q0-2*qp+qpp)**2 + k2 * (3*q0-4*qp+qpp)**2

    g1, g2, g3 = 0.1, 0.6, 0.3
    w1 = g1 / (beta1+eps)**2
    w2 = g2 / (beta2+eps)**2
    w3 = g3 / (beta3+eps)**2

    qi_weno5 = (w1*qi1 + w2*qi2 + w3*qi3) / (w1 + w2 + w3)

    return qi_weno5


if __name__ == '__main__':
    from time import time as cputime
    device = 'cuda'
    nx, ny = 1024, 1024
    x = torch.DoubleTensor(nx, ny).normal_().to(device)
    # circular padding
    x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1)
    xmm, xm, x0, xp, xpp = x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:]
    x = weno5(xmm, xm, x0, xp, xpp)

    ngridpoints = nx * ny
    n_steps = 1000
    mperf = 0


    for n in range(1, n_steps+1):
        walltime0 = cputime()
        x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1)
        x = weno5(x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:])
        walltime = cputime()
        perf = (walltime-walltime0)/ngridpoints
        mperf += perf
        print(f"\rn={n:4} perf={perf:.2e}, ({mperf/n:.3e}) s", end="")
    print()

I have two conda environments, pt1.11 and pt1.12, with PyTorch versions 1.11 and 1.12 respectively.

On my laptop, when I run the script without torch.jit.script, I get the same average runtimes (in parentheses) with pt1.11:

(pt1.11) ... $ python weno.py 
n=1000 perf=5.56e-09, (5.767e-09) s

and pt1.12 :

(pt1.12) ... $ python weno.py 
n=1000 perf=6.36e-09, (5.770e-09) s

When I use torch.jit.script, I get a 5x speed-up with pt1.11 (as expected):

(pt1.11) ... $ python weno.py 
with torch.jit.script
n=1000 perf=8.77e-10, (1.242e-09) s

but I get a slowdown with pt1.12 (!!):

(pt1.12) ... $ python weno.py 
with torch.jit.script
n=1000 perf=9.18e-09, (9.762e-09) s

I have the same issue on the server with the RTX 2080 Ti GPU, and also with the latest 1.13.1 version.

Note that on CPUs I do not have this problem, neither on my laptop nor on the server.

Is it due to nvFuser? If so, is this expected?

Sincerely,

Louis

Your profiling is wrong since you are using host timers without synchronizing the device.
CUDA operations are executed asynchronously, so you would need to synchronize the code via torch.cuda.synchronize() before starting and stopping the timers.
Also note that proper profiling needs warmup iterations (especially since scripting and code generation add overhead).
Let me know if you still see this issue after fixing the profiling.
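
For illustration, a minimal timing pattern along those lines could look like the sketch below (the time_gpu helper is just for exposition, not a PyTorch function):

import torch
from time import time

def time_gpu(fn, *args, n_warmup=100, n_steps=1000):
    # Warmup iterations trigger JIT compilation / code generation before measuring
    for _ in range(n_warmup):
        fn(*args)
    torch.cuda.synchronize()  # wait for the warmup kernels to finish
    t0 = time()
    for _ in range(n_steps):
        fn(*args)
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the host timer
    return (time() - t0) / n_steps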

Hi, I added the synchronization and warmup and it does not change the conclusion. Here is the updated script:

def weno5(qmm, qm, q0, qp, qpp):
    """
    Fifth-order WENO reconstruction, from:
        Efficient Implementation of Weighted ENO Schemes, Jiang and Shu,
        Journal of Computational Physics 126, 202–228 (1996)
    """
    eps = 1e-6
    qi1 = 1./3.*qmm - 7./6.*qm + 11./6.*q0
    qi2 = -1./6.*qm + 5./6.*q0 + 1./3.*qp
    qi3 = 1./3.*q0 + 5./6.*qp - 1./6.*qpp

    k1, k2 = 13./12., 0.25
    beta1 = k1 * (qmm-2*qm+q0)**2 + k2 * (qmm-4*qm+3*q0)**2
    beta2 = k1 * (qm-2*q0+qp)**2  + k2 * (qm-qp)**2
    beta3 = k1 * (q0-2*qp+qpp)**2 + k2 * (3*q0-4*qp+qpp)**2

    g1, g2, g3 = 0.1, 0.6, 0.3
    w1 = g1 / (beta1+eps)**2
    w2 = g2 / (beta2+eps)**2
    w3 = g3 / (beta3+eps)**2

    qi_weno5 = (w1*qi1 + w2*qi2 + w3*qi3) / (w1 + w2 + w3)

    return qi_weno5



if __name__ == '__main__':
    import torch
    from time import time as cputime
    device = 'cuda'
    nx, ny = 1024, 1024
    x = torch.DoubleTensor(nx, ny).normal_().to(device)

    x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1) # circular padding
    x = weno5(x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:]) # WENO interpolation

    # weno5_jit = torch.jit.script(weno5, (x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:]))
    weno5_jit = torch.jit.trace(weno5, (x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:]))

    ngridpoints = nx * ny
    n_steps = 1000
    n_warmup = 100

    for weno, mess in [
            (weno5, 'weno5 no jit'),
            (weno5_jit, 'weno5 jit'),
        ]:

        print(mess)
        mperf = 0

        for n in range(n_warmup):
            x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1)
            x = weno(x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:])
        torch.cuda.synchronize()

        for n in range(1, n_steps+1):
            torch.cuda.synchronize()
            walltime0 = cputime()
            x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1)
            x = weno(x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:])
            torch.cuda.synchronize()
            walltime = cputime()
            perf = (walltime-walltime0)/ngridpoints
            mperf += perf
            print(f"\rn={n:4} perf={perf:.2e}, ({mperf/n:.3e}) s", end="")
        print()

With PyTorch 1.11

(pt1.11) ... $ python weno.py
weno5 no jit
n=1000 perf=5.80e-09, (5.782e-09) s
weno5 jit
n=1000 perf=1.07e-09, (1.084e-09) s

With PyTorch 1.12

(pt1.12) ... $ python weno.py
weno5 no jit
n=1000 perf=5.56e-09, (5.773e-09) s
weno5 jit
n=1000 perf=9.43e-09, (9.779e-09) s

I confirm that the slowdown is due to nvFuser. I tested the three different fusers using torch.jit.fuser (which, by the way, is not in the docs: TorchScript — PyTorch 1.12 documentation):

>>> help(torch.jit.fuser)

Help on function fuser in module torch.jit._fuser:

fuser(name)
    A context manager that facilitates switching between
    backend fusers.
    
    Valid names:
    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser

Here is the script:

def weno5(qmm, qm, q0, qp, qpp):
    """
    Fifth-order WENO reconstruction, from:
        Efficient Implementation of Weighted ENO Schemes, Jiang and Shu,
        Journal of Computational Physics 126, 202–228 (1996)
    """
    eps = 1e-6
    qi1 = 1./3.*qmm - 7./6.*qm + 11./6.*q0
    qi2 = -1./6.*qm + 5./6.*q0 + 1./3.*qp
    qi3 = 1./3.*q0 + 5./6.*qp - 1./6.*qpp

    k1, k2 = 13./12., 0.25
    beta1 = k1 * (qmm-2*qm+q0)**2 + k2 * (qmm-4*qm+3*q0)**2
    beta2 = k1 * (qm-2*q0+qp)**2  + k2 * (qm-qp)**2
    beta3 = k1 * (q0-2*qp+qpp)**2 + k2 * (3*q0-4*qp+qpp)**2

    g1, g2, g3 = 0.1, 0.6, 0.3
    w1 = g1 / (beta1+eps)**2
    w2 = g2 / (beta2+eps)**2
    w3 = g3 / (beta3+eps)**2

    qi_weno5 = (w1*qi1 + w2*qi2 + w3*qi3) / (w1 + w2 + w3)

    return qi_weno5

if __name__ == '__main__':
    import torch
    from time import time as cputime
    device = 'cuda'
    nx, ny = 1024, 1024
    x = torch.DoubleTensor(nx, ny).normal_().to(device)
    x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1) # circular padding
    x = weno5(x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:]) # WENO interpolation
    ngridpoints = nx * ny
    n_steps = 1000
    n_warmup = 100

    for fuser, fuser_name in [
            ('fuser0', 'legacy'),
            ('fuser1', 'NNC'),
            ('fuser2', 'nvfuser')
        ]:
        print(f'torch.jit with {fuser_name} fuser (torch.jit.fuser("{fuser}"))')

        with torch.jit.fuser(fuser):
            weno5_jit = torch.jit.trace(weno5, (x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:]))
            mperf = 0
            for n in range(-n_warmup, n_steps+1):
                torch.cuda.synchronize()
                walltime0 = cputime()
                x_ = torch.cat([x[:,[-2]], x[:,[-1]], x, x[:,[0]], x[:,[1]]], dim=-1)
                x = weno5_jit(x_[:,:-4], x_[:,1:-3], x_[:,2:-2], x_[:,3:-1], x_[:,4:])
                torch.cuda.synchronize()
                walltime = cputime()
                if n >= 1:
                    perf = (walltime-walltime0)/ngridpoints
                    mperf += perf
                    print(f"\rn={n:4} perf={perf:.2e}, ({mperf/n:.3e}) s", end="")
            print()

and with PyTorch 1.12, the NNC fuser recovers the PyTorch 1.11 timings, while nvFuser remains the slow one:

(pt1.12)... $ python -i weno.py
torch.jit with legacy fuser (torch.jit.fuser("fuser0"))
n=1000 perf=5.50e-09, (5.705e-09) s
torch.jit with NNC fuser (torch.jit.fuser("fuser1"))
n=1000 perf=1.07e-09, (1.079e-09) s
torch.jit with nvfuser fuser (torch.jit.fuser("fuser2"))
n=1000 perf=9.44e-09, (9.872e-09) s

I guess it is not expected…
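
For anyone hitting the same slowdown, one possible workaround (just a sketch based on the torch.jit.fuser context manager above, not an official recommendation) is to pin the NNC fuser explicitly; in my tests this restores the 1.11 timings. The names weno5, xmm, xm, x0, xp, xpp refer to the scripts earlier in this thread:

import torch

# Keep using the NNC fuser ("fuser1") instead of the nvFuser default;
# both the tracing and the calls stay inside the context manager,
# as in the benchmark script above.
with torch.jit.fuser("fuser1"):
    weno5_jit = torch.jit.trace(weno5, (xmm, xm, x0, xp, xpp))
    out = weno5_jit(xmm, xm, x0, xp, xpp)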