Hi,
According to the v1.12 changelog:
“In PyTorch 1.12, Torchscript is updating its default fuser (for Volta and later CUDA accelerators) to nvFuser, which supports a wider range of operations and is faster than NNC, the previous fuser for CUDA devices”
I’m using PyTorch for scientific computing (I numerically solve PDEs), not for deep learning, and I have noticed that since version 1.12, torch.jit.script is significantly slower on CUDA GPUs, whereas runtimes are unchanged on Intel CPUs for the computations I’m doing (example below).
I tested this on my laptop with an NVIDIA RTX A3000 Laptop GPU and on a server with an NVIDIA GeForce RTX 2080 Ti. On both machines, PyTorch was installed with conda following the official PyTorch instructions.
Here is an example of the computations I do, with the script weno.py that implements a fifth-order WENO reconstruction (roughly speaking, a non-linear 1D interpolation on a five-point stencil that prevents Gibbs oscillations):
import torch

# print('With torch.jit.script')
# @torch.jit.script
def weno5(qmm, qm, q0, qp, qpp):
    """
    Fifth-order WENO reconstruction, from:
    Efficient Implementation of Weighted ENO Schemes, Jiang and Shu,
    Journal of Computational Physics 126, 202–228 (1996)
    """
    eps = 1e-6
    # candidate reconstructions on the three 3-point sub-stencils
    qi1 = 1./3.*qmm - 7./6.*qm + 11./6.*q0
    qi2 = -1./6.*qm + 5./6.*q0 + 1./3.*qp
    qi3 = 1./3.*q0 + 5./6.*qp - 1./6.*qpp
    # smoothness indicators
    k1, k2 = 13./12., 0.25
    beta1 = k1 * (qmm-2*qm+q0)**2 + k2 * (qmm-4*qm+3*q0)**2
    beta2 = k1 * (qm-2*q0+qp)**2 + k2 * (qm-qp)**2
    beta3 = k1 * (q0-2*qp+qpp)**2 + k2 * (3*q0-4*qp+qpp)**2
    # non-linear weights
    g1, g2, g3 = 0.1, 0.6, 0.3
    w1 = g1 / (beta1+eps)**2
    w2 = g2 / (beta2+eps)**2
    w3 = g3 / (beta3+eps)**2
    qi_weno5 = (w1*qi1 + w2*qi2 + w3*qi3) / (w1 + w2 + w3)
    return qi_weno5


if __name__ == '__main__':
    from time import time as cputime

    device = 'cuda'
    nx, ny = 1024, 1024
    x = torch.DoubleTensor(nx, ny).normal_().to(device)

    # circular padding along the last dimension
    x_ = torch.cat([x[:, [-2]], x[:, [-1]], x, x[:, [0]], x[:, [1]]], dim=-1)
    xmm, xm, x0, xp, xpp = x_[:, :-4], x_[:, 1:-3], x_[:, 2:-2], x_[:, 3:-1], x_[:, 4:]
    x = weno5(xmm, xm, x0, xp, xpp)

    ngridpoints = nx * ny
    n_steps = 1000
    mperf = 0
    for n in range(1, n_steps+1):
        walltime0 = cputime()
        x_ = torch.cat([x[:, [-2]], x[:, [-1]], x, x[:, [0]], x[:, [1]]], dim=-1)
        x = weno5(x_[:, :-4], x_[:, 1:-3], x_[:, 2:-2], x_[:, 3:-1], x_[:, 4:])
        walltime = cputime()
        perf = (walltime-walltime0)/ngridpoints
        mperf += perf
        print(f"\rn={n:4} perf={perf:.2e}, ({mperf/n:.3e}) s", end="")
    print()
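As an aside, since CUDA kernel launches are asynchronous, one could also time the loop with explicit synchronization; the sketch below shows that variant, but note that all the figures quoted in this post were obtained with the loop exactly as written above:

    # variant of the timing loop with explicit synchronization (sketch only)
    mperf = 0
    for n in range(1, n_steps + 1):
        torch.cuda.synchronize()
        walltime0 = cputime()
        x_ = torch.cat([x[:, [-2]], x[:, [-1]], x, x[:, [0]], x[:, [1]]], dim=-1)
        x = weno5(x_[:, :-4], x_[:, 1:-3], x_[:, 2:-2], x_[:, 3:-1], x_[:, 4:])
        torch.cuda.synchronize()
        walltime = cputime()
        mperf += (walltime - walltime0) / ngridpoints
    print(f"mean time per grid point: {mperf / n_steps:.3e} s")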
I have two conda environments, pt1.11 and pt1.12, with PyTorch versions 1.11 and 1.12 respectively.
On my laptop, when I run the script without torch.jit.script (the decorator left commented out), I get the same average runtime (the value in parentheses) with pt1.11:
(pt1.11) ... $ python weno.py
n=1000 perf=5.56e-09, (5.767e-09) s
and with pt1.12:
(pt1.12) ... $ python weno.py
n=1000 perf=6.36e-09, (5.770e-09) s
When I use torch.jit.script (uncommenting the decorator and the print line), I get a 5x speed-up with pt1.11, as expected:
(pt1.11) ... $ python weno.py
with torch.jit.script
n=1000 perf=8.77e-10, (1.242e-09) s
but I get a slowdown with pt1.12 (!!):
(pt1.12) ... $ python weno.py
with torch.jit.script
n=1000 perf=9.18e-09, (9.762e-09) s
I see the same issue on the server with the RTX 2080 Ti GPU, and also with the latest version, 1.13.1.
Note that on CPUs I do not have this problem, either on my laptop or on the server.
Is this due to nvFuser? If so, is it expected?
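For what it's worth, here is how I would try to isolate the fuser, using the torch.jit.fuser context manager (my understanding is that "fuser1" selects NNC and "fuser2" selects nvFuser; I have not verified that this is the recommended way, so please correct me if it is not). The `from weno import weno5` line assumes the decorator in weno.py is left commented out:

    import torch
    from time import time as cputime
    from weno import weno5  # assumes weno5 is NOT decorated in weno.py

    device = 'cuda'
    nx, ny = 1024, 1024
    x = torch.DoubleTensor(nx, ny).normal_().to(device)
    x_ = torch.cat([x[:, [-2]], x[:, [-1]], x, x[:, [0]], x[:, [1]]], dim=-1)
    args = (x_[:, :-4], x_[:, 1:-3], x_[:, 2:-2], x_[:, 3:-1], x_[:, 4:])

    # "fuser1" = NNC (the pre-1.12 CUDA fuser), "fuser2" = nvFuser
    for name, label in [("fuser1", "NNC"), ("fuser2", "nvFuser")]:
        with torch.jit.fuser(name):
            scripted = torch.jit.script(weno5)
            for _ in range(10):          # warm-up so the JIT profiles and fuses
                scripted(*args)
            torch.cuda.synchronize()
            t0 = cputime()
            for _ in range(100):
                scripted(*args)
            torch.cuda.synchronize()
            print(f"{label}: {(cputime() - t0) / 100:.3e} s per call")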
Sincerely,
Louis