PyTorch 2.0 is 3x slower than 1.11 on a very simple example?

Here is a very simple example of an nn.Module. @dean.p.foster and I ran it on two versions, PyTorch 1.11-cu115 and PyTorch 2.0-cu117 (on two different hosts). I see a 3x slowdown on PyTorch 2.0, and I can't figure out what I'm missing that would make this expected behavior:

import time

import torch
import torch.nn as nn

print(torch.__version__)

class Attn(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.c_attn = nn.Linear(embed_dim, 16 * embed_dim)
        
    def forward(self, x):
        k = self.c_attn(x)
        return k

yy = Attn(512).to('cuda:2')
inp = torch.randn((2048, 512), device='cuda:2')
sm = 0

for i in range(550):
    start = time.time()
    z = yy(inp)                    # forward (kernels are launched asynchronously)
    end = time.time()
    sm += end - start

    z2 = torch.sum(z)

    start = time.time()
    z2.backward()                  # backward
    torch.cuda.synchronize(2)      # wait for the backward kernels to finish
    end = time.time()
    sm += end - start

torch.cuda.synchronize(2)
print(i, ' Total: ', sm)

For PyTorch 2.0, the output is:

2.0.1
549  Total:  1.1321918964385986

For PyTorch 1.11.0, the output is:

1.11.0+cu115
549  Total:  0.3390023708343506

Adding more info:

(torch2) > python -c "import torch;print(torch.__config__.show(), torch.cuda.get_device_properties(0))"
PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.1-Product Build 20220311 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 
 _CudaDeviceProperties(name='NVIDIA A100-SXM4-80GB', major=8, minor=0, total_memory=81228MB, multi_processor_count=108)
(torch2) > uname -r
5.15.0-1038-aws
(torch2) > nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

Running the same on PyTorch master produced the same result:

2.1.0a0+git03c9321
549  Total:  1.1332781314849854
(torch2) > locate libcudnn
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn.so.8.7.0
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn_adv_infer.so.8.7.0
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn_adv_train.so.8.7.0
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn_cnn_infer.so.8.7.0
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn_cnn_train.so.8.7.0
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn_ops_infer.so.8.7.0
/opt/conda/envs/pytorch/lib/python3.10/site-packages/torch/lib/libcudnn_ops_train.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn_adv_infer.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn_adv_train.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn_cnn_infer.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn_cnn_train.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn_ops_infer.so.8.7.0
/opt/conda/pkgs/pytorch-2.0.1-aws_py3.10_cuda11.8_cudnn8.7.0_0/lib/python3.10/site-packages/torch/lib/libcudnn_ops_train.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn_adv_infer.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn_adv_train.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn_cnn_infer.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn_cnn_train.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn_ops_infer.so.8.7.0
/usr/local/cuda-11.8/lib/libcudnn_ops_train.so.8.7.0

Your profiling is invalid: since CUDA kernels are launched asynchronously, you need to synchronize the device before starting and stopping the timers.
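
For reference, here is a minimal sketch of timing GPU work with proper synchronization using torch.cuda.Event (a standard PyTorch API; the shapes and the linear layer below are just the ones from the snippet above):

import torch
import torch.nn as nn

lin = nn.Linear(512, 16 * 512).to('cuda')
inp = torch.randn((2048, 512), device='cuda')

start_ev = torch.cuda.Event(enable_timing=True)
end_ev = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()   # drain any pending work before timing
start_ev.record()          # timestamp recorded on the CUDA stream
out = lin(inp).sum()
out.backward()
end_ev.record()
torch.cuda.synchronize()   # make sure end_ev has actually fired

print(start_ev.elapsed_time(end_ev), 'ms')  # GPU time in milliseconds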

@ptrblck – I did that too and tried a few variants. Curious what I'm misunderstanding here (I fully believe this is likely my error!)

import time

import torch
import torch.nn as nn

print(torch.__version__)

class Attn(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.c_attn = nn.Linear(embed_dim, 16 * embed_dim)
        
    def forward(self, x):
        k = self.c_attn(x)
        return k

yy = Attn(512).to('cuda:3')
inp = torch.randn((2048, 512), device='cuda:3')
sm = 0

for i in range(550):
    torch.cuda.synchronize(3)      # drain pending work before starting the timer
    start = time.time()
    z = yy(inp)                    # forward
    end = time.time()
    sm += end - start

    z2 = torch.sum(z)

    start = time.time()
    z2.backward()                  # backward
    torch.cuda.synchronize(3)      # wait for the backward kernels to finish
    end = time.time()
    sm += end - start

torch.cuda.synchronize(3)
print(i, ' Total: ', sm)

This gives the same result:

2.1.0a0+git03c9321
549  Total:  1.1322011947631836

For 1.11 I get:

1.11.0+cu115
549  Total:  0.3438262939453125

Breakdown by forward and backward pass:

PyTorch 2.0:

549  Forward Total:  0.030092954635620117
549  Backward Total:  1.101416826248169

PyTorch 1.11:

549  Forward Total:  0.05405449867248535
549  Backward Total:  0.300870418548584

Cross-post from here without a proper follow-up:
TF32 was enabled by default in 1.11 but is disabled by default in 2.0, so on the A100 these matmuls run in full FP32 there; enabling TF32 restores the performance.
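
For reference, a minimal sketch of how to turn it back on (these are the documented PyTorch flags; set them at the top of the script, before running the benchmark):

import torch

# Allow TF32 on Ampere+ GPUs for matmuls (the default changed to False after 1.11)
torch.backends.cuda.matmul.allow_tf32 = True
# Allow TF32 inside cuDNN convolutions as well (still True by default, shown for completeness)
torch.backends.cudnn.allow_tf32 = True

# Equivalent newer API for the matmul flag:
torch.set_float32_matmul_precision('high')

With TF32 enabled, the 2.0 timing should drop back to roughly the 1.11 numbers above.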