When running torch.jit.trace() on a simple model, an RTX 5090 is significantly slower than an RTX 4060 Ti (about 2.5× slower), despite both using the same CUDA runtime and PyTorch version.
Is there any solution to this issue? I'm struggling even though this is a very expensive GPU…
I also set the GPU to the P0 power state, which is the highest performance mode.
PyTorch version: 2.9.0.dev20250707+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Microsoft Windows 11 Pro (10.0.26100 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: N/A
Python version: 3.10.18 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:08:55) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5090
Nvidia driver version: 576.02
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cudnn_ops_train64_8.dll
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Name: AMD Ryzen 7 9800X3D 8-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 4700
MaxClockSpeed: 4700
L2CacheSize: 8192
L2CacheSpeed: None
Revision: 17408
Versions of relevant libraries:
[pip3] mkl_fft==1.3.11
[pip3] mkl_random==1.2.8
[pip3] mkl-service==2.4.0
[pip3] numpy==2.2.5
[pip3] torch==2.9.0.dev20250707+cu128
[pip3] torchaudio==2.8.0.dev20250708+cu128
[pip3] torchvision==0.24.0.dev20250708+cu128
[conda] blas 1.0 mkl
[conda] intel-openmp 2023.1.0 h59b6b97_46320
[conda] mkl 2023.1.0 h6b88ed4_46358
[conda] mkl-service 2.4.0 py310h827c3e9_2
[conda] mkl_fft 1.3.11 py310h827c3e9_0
[conda] mkl_random 1.2.8 py310hc64d2fc_0
[conda] numpy 2.2.5 py310h5f75535_0
[conda] numpy-base 2.2.5 py310h23d94f8_0
[conda] tbb 2021.8.0 h59b6b97_0
[conda] torch 2.9.0.dev20250707+cu128 pypi_0 pypi
[conda] torchaudio 2.8.0.dev20250708+cu128 pypi_0 pypi
[conda] torchvision 0.24.0.dev20250708+cu128 pypi_0 pypi
The following code is an example I wrote to be as close as possible to my actual code, which I can't share for security reasons.
import torch
import torch.nn as nn
import time

# The following model is an example model.
class UNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)

# Example with the same sizes as my model
BATCH_SIZE = 64
H, W = 512, 512

model = UNetLike().half().cuda()
example_input = torch.randn(BATCH_SIZE, 1, H, W, dtype=torch.float16, device="cuda")

# Time the JIT trace
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
traced = torch.jit.trace(model, example_input)
end.record()
torch.cuda.synchronize()

print("Trace time (ms):", start.elapsed_time(end))
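For comparison, the timed region above includes one-time work: the trace itself, the optimization passes that run on the first few calls, and possibly JIT kernel compilation on a new architecture. A steady-state measurement after warm-up might look like the sketch below (not my actual benchmark; sizes are reduced and a CPU fallback is added so it runs anywhere):

```python
import time
import torch
import torch.nn as nn

# Reduced stand-in for the model above so the sketch runs on any machine.
class UNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNetLike().to(device).eval()
x = torch.randn(2, 1, 64, 64, device=device)

traced = torch.jit.trace(model, x)

# Warm-up: the first few calls trigger TorchScript optimization passes
# (and GPU kernel selection), so they are excluded from the timing.
with torch.no_grad():
    for _ in range(10):
        traced(x)
if device == "cuda":
    torch.cuda.synchronize()

# Steady-state timing over many iterations.
iters = 100
t0 = time.perf_counter()
with torch.no_grad():
    for _ in range(iters):
        traced(x)
if device == "cuda":
    torch.cuda.synchronize()
t1 = time.perf_counter()

print(f"steady-state latency: {(t1 - t0) / iters * 1e3:.3f} ms/iter")
```

If the 5090 is still 2.5× slower than the 4060 Ti on the steady-state number, the gap is not just one-time trace/compile overhead.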