Why does PyTorch 2.0 make my code slower on a 2080 Ti?

Here is my test code:
```python
import torch
import time

# Set GPU device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create test data
batch_size = 50
image_size = 224
num_classes = 1000
test_images = torch.randn((batch_size, 3, image_size, image_size)).to(device)

# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()

# Define GradScaler object
scaler = torch.cuda.amp.GradScaler()

# Load ViT-Base model
vit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=False).to(device)
vit = torch.compile(vit)
vit.eval()
optimizer_vit = torch.optim.Adam(vit.parameters(), lr=0.001)

# Test ViT-Base performance
start_time = time.time()
for i in range(100):
    optimizer_vit.zero_grad()
    outputs = vit(test_images)
    loss = criterion(outputs, torch.randint(num_classes, (batch_size,), device=device))
    loss.backward()
    optimizer_vit.step()
end_time = time.time()
print(f"ViT-Base: {batch_size * 100 / (end_time-start_time):.2f} images/second")
```

Could you post your environment information via python -m torch.utils.collect_env, along with the measured performance, please?
Also, note that CUDA kernels are executed asynchronously, so you need to synchronize the code via torch.cuda.synchronize() before starting and before stopping the timers.
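A minimal sketch of the synchronization advice, using a hypothetical helper called timed (our name, not a PyTorch API): it calls a sync function before starting and again before stopping the clock, so kernels still queued on the GPU are counted inside the measured interval. The sync function is passed in; on a GPU it would be torch.cuda.synchronize, and the no-op default keeps the sketch runnable without one.

```python
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Iterator


@dataclass
class Elapsed:
    seconds: float = float("nan")  # filled in when the timed block exits


@contextmanager
def timed(sync: Callable[[], None] = lambda: None) -> Iterator[Elapsed]:
    """Hypothetical helper: wall-clock time of the enclosed block.

    `sync` should block until previously launched device work has finished
    (e.g. torch.cuda.synchronize). It runs before the clock starts, so earlier
    work is not billed to this block, and before the clock stops, so work
    launched inside the block is actually finished when we read the time.
    """
    out = Elapsed()
    sync()  # drain work queued before the block
    start = time.perf_counter()
    try:
        yield out
    finally:
        sync()  # wait for asynchronously launched work inside the block
        out.seconds = time.perf_counter() - start
```

For example, wrapping the 100-iteration training loop in "with timed(torch.cuda.synchronize) as t:" and then reading t.seconds gives the synchronized wall-clock time for the loop.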

There are a few caveats:

  • Compilation takes time, so you need to exclude the first iteration from your measurement to see what speedups to expect
  • torch.compile speedups are larger on newer server GPUs such as the A100 or A10G; on a 2080 they won't be as substantial
  • A full repro would help, as would a performance trace of your model; there may be other passes or configs you could pass to torch.compile that would make things faster
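The first caveat can be sketched as a small benchmark helper. This is a hypothetical function (images_per_second is our name, not a PyTorch API), assuming one training step is wrapped in a zero-argument callable: it runs a configurable number of untimed warm-up steps, where torch.compile would do its compilation, synchronizes, and only then times the remaining steps. The default of one warm-up step matches the advice above; raising it costs little.

```python
import time
from typing import Callable


def images_per_second(
    step: Callable[[], None],
    batch_size: int,
    iters: int = 100,
    warmup: int = 1,
    sync: Callable[[], None] = lambda: None,
) -> float:
    """Hypothetical helper: throughput of `step`, excluding warm-up steps."""
    if iters <= 0:
        raise ValueError("iters must be positive")
    for _ in range(warmup):
        step()  # compilation / autotuning / allocator warm-up happens here, untimed
    sync()  # make sure warm-up work has finished before the clock starts
    start = time.perf_counter()
    for _ in range(iters):
        step()
    sync()  # wait for the last timed step to finish on the device
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed
```

On a GPU this would be called as images_per_second(train_step, batch_size=50, sync=torch.cuda.synchronize), where train_step is a hypothetical function holding the body of the repro's loop.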

Thanks for the tip! I removed the first inference from the timing and provide a full repro below. Toggling the lines marked with '^^^' switches between the compiled and the non-compiled model.

import torch
import time

# Set GPU device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create test data
batch_size = 50
image_size = 224
num_classes = 1000
test_images = torch.randn((batch_size, 3, image_size, image_size)).to(device)
# Define loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()

# Define GradScaler object
scaler = torch.cuda.amp.GradScaler()

# Load ViT-Base model
vit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=False).to(device)
# vit = torch.compile(vit) ^^^
# _ = vit(test_images) ^^^
# del _ ^^^
vit.eval()
optimizer_vit = torch.optim.Adam(vit.parameters(), lr=0.001)

# Test ViT-Base performance
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
    optimizer_vit.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = vit(test_images)
        loss = criterion(outputs, torch.randint(num_classes, (batch_size,), device=device))
    scaler.scale(loss).backward()
    scaler.step(optimizer_vit)
    scaler.update()
torch.cuda.synchronize()
end_time = time.time()
print(f"ViT-Base: {batch_size * 100 / (end_time-start_time):.2f} images/second")

This gives 152 images/second with the compiled model vs. 224 images/second without it, on a single 2080 Ti. The compiled model is the slower one.
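Because the repro reports batch_size * 100 / elapsed, the two figures can be turned back into wall-clock time for the 100 timed steps. A quick check of that arithmetic (implied_seconds is a hypothetical helper that just inverts the repro's formula):

```python
def implied_seconds(images_per_second: float, batch_size: int = 50, iters: int = 100) -> float:
    """Invert the repro's formula: throughput = batch_size * iters / elapsed."""
    return batch_size * iters / images_per_second


compiled, eager = 152.0, 224.0  # images/second reported above
print(f"compiled: {implied_seconds(compiled):.1f} s")
print(f"eager:    {implied_seconds(eager):.1f} s")
print(f"compiled/eager throughput: {compiled / eager:.2f}")
```

So the compiled run spent roughly 32.9 s on the 100 steps versus roughly 22.3 s for eager mode, about 68% of the eager throughput.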


Thanks! My environment is:

Collecting environment information...
PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.35

Python version: 3.9.16 (main, Mar  8 2023, 14:00:05)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-35-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA GeForce RTX 2080 Ti
GPU 1: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 515.86.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   43 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          48
On-line CPU(s) list:             0-47
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7402 24-Core Processor
CPU family:                      23
Model:                           49
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       1
Stepping:                        0
Frequency boost:                 enabled
CPU max MHz:                     2800.0000
CPU min MHz:                     1500.0000
BogoMIPS:                        5589.60
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization:                  AMD-V
L1d cache:                       768 KiB (24 instances)
L1i cache:                       768 KiB (24 instances)
L2 cache:                        12 MiB (24 instances)
L3 cache:                        128 MiB (8 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-47
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.24.2
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1
[pip3] torchvision==0.15.1
[conda] numpy                     1.24.2                   pypi_0    pypi
[conda] torch                     2.0.0                    pypi_0    pypi
[conda] torchaudio                2.0.1                    pypi_0    pypi
[conda] torchvision               0.15.1                   pypi_0    pypi

@jsrdcht I tested your code on an A100 after adding the following:

    if i == 1:
        start_time = time.time()

Code:

import time

import torch

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

batch_size = 50
image_size = 224
num_classes = 1000
test_images = torch.randn((batch_size, 3, image_size, image_size)).to(device)

criterion = torch.nn.CrossEntropyLoss()

scaler = torch.cuda.amp.GradScaler()

vit = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=False).to(device)
# vit = torch.compile(vit)
optimizer_vit = torch.optim.Adam(vit.parameters(), lr=0.001)

torch.cuda.synchronize()
for i in range(101):
    if i == 1:
        start_time = time.time()
    optimizer_vit.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = vit(test_images)
        loss = criterion(outputs, torch.randint(num_classes, (batch_size,), device=device))
    scaler.scale(loss).backward()
    scaler.step(optimizer_vit)
    scaler.update()
torch.cuda.synchronize()
end_time = time.time()
print(f"ViT-Base: {batch_size * 100 / (end_time-start_time):.2f} images/second")

Results (without torch.compile)

Using device: cuda:1
Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main
ViT-Base: 687.77 images/second

Results (with torch.compile)

Using device: cuda:1
Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main
ViT-Base: 827.48 images/second

My environment

PyTorch version: 2.0.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.9.5 (default, Nov 23 2021, 15:27:38)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-4.19.93-1.nbp.el7.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 510.73.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.5.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.5.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   43 bits physical, 48 bits virtual
CPU(s):                          64
On-line CPU(s) list:             0-63
Thread(s) per core:              2
Core(s) per socket:              16
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7302 16-Core Processor
Stepping:                        0
CPU MHz:                         3293.654
BogoMIPS:                        5988.96
Virtualization:                  AMD-V
L1d cache:                       1 MiB
L1i cache:                       1 MiB
L2 cache:                        16 MiB
L3 cache:                        256 MiB
NUMA node0 CPU(s):               0-15,32-47
NUMA node1 CPU(s):               16-31,48-63
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd sev ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.2
[pip3] torch==2.0.0+cu117
[pip3] torchaudio==2.0.1+cu117
[pip3] torchinfo==1.7.2
[pip3] torchvision==0.15.1+cu117
[conda] Could not collect

Strangely, on torch 2.0.0 the code gets much slower without AMP, with or without torch.compile():

# torch 2.0.0, without AMP, without torch.compile()
ViT-Base: 145.43 images/second
# torch 2.0.0, without AMP, with torch.compile()
ViT-Base: 147.82 images/second
# torch 1.11.0, without AMP, without torch.compile()
ViT-Base: 455.63 images/second

Environment of torch 1.11.0

Collecting environment information...
PyTorch version: 1.11.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.5 (default, Nov 23 2021, 15:27:38)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-4.19.93-1.nbp.el7.x86_64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.3.109
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 510.73.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.1
[pip3] torch==1.11.0+cu113
[pip3] torchaudio==0.11.0+cu113
[pip3] torchinfo==1.7.2
[pip3] torchvision==0.12.0+cu113
[conda] Could not collect

PyTorch 1.12.0 disabled TF32 for matmuls by default, as noted in the 1.12 release notes. You can set torch.backends.cuda.matmul.allow_tf32 = True in your 2.0.0 environment to re-enable TF32 for matmuls and then re-profile the workload.
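A minimal sketch for the re-profiling step, assuming you want to compare both settings in one process: a hypothetical context manager (tf32_matmuls is our name, not a PyTorch API) that sets torch.backends.cuda.matmul.allow_tf32 for the duration of a block and restores the previous value afterwards. TF32 requires an Ampere-or-newer GPU such as the A100 used above; on a 2080 Ti the flag should make no difference.

```python
import contextlib
from typing import Iterator

import torch


@contextlib.contextmanager
def tf32_matmuls(enabled: bool = True) -> Iterator[None]:
    """Hypothetical helper: temporarily set the cuBLAS TF32 flag, then restore it."""
    previous = torch.backends.cuda.matmul.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = enabled
    try:
        yield
    finally:
        # Restore whatever the process had before, even if the block raised.
        torch.backends.cuda.matmul.allow_tf32 = previous
```

Running the no-AMP benchmark loop once inside "with tf32_matmuls(True):" and once inside "with tf32_matmuls(False):" gives a side-by-side comparison without restarting Python.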

Thanks!

Using torch.set_float32_matmul_precision('high') gives the following:

# torch 2.0.0, without AMP, without torch.compile(), with torch.set_float32_matmul_precision('high')
ViT-Base: 545.56 images/second

I'm not sure which is better, torch.backends.cuda.matmul.allow_tf32 or torch.set_float32_matmul_precision('high'), but the latter seems more general.
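One way to settle the question is to check what each setting does. The torch.set_float32_matmul_precision documentation describes 'high' and 'medium' as equivalent to torch.backends.cuda.matmul.allow_tf32 = True for CUDA float32 matmuls, and 'highest' (the default) as equivalent to False; 'medium' additionally permits lower-precision internal formats such as bfloat16 where available. The hypothetical helper below (tf32_flag_for is our name) verifies that coupling on whatever version is installed rather than taking it on faith, restoring the original precision afterwards.

```python
import torch


def tf32_flag_for(precision: str) -> bool:
    """Set the float32 matmul precision, read back the cuBLAS TF32 flag,
    and restore the previous precision before returning."""
    previous = torch.get_float32_matmul_precision()
    try:
        torch.set_float32_matmul_precision(precision)
        return torch.backends.cuda.matmul.allow_tf32
    finally:
        torch.set_float32_matmul_precision(previous)


for p in ("highest", "high", "medium"):
    print(f"{p:>7} -> allow_tf32={tf32_flag_for(p)}")
```

If the two settings are coupled as documented, either one is enough to enable TF32 for matmuls; the precision API simply exposes an extra level.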

Bro, your name is love_ptrblck and the other person's name is ptrblck. What's your relationship? :sweat:

I just made it up, since ptrblck is like a guardian of this forum.
I sincerely appreciate his help.

Thanks!