Performance regression: torch.jit.trace() significantly slower on RTX 5090 than RTX 4060 (cu128 nightly)

When running torch.jit.trace() on a simple model on an RTX 5090, tracing is significantly slower than on an RTX 4060 Ti (about 2.5× slower), despite both machines running the same CUDA runtime and PyTorch version.

Is there any solution to this issue? I’m struggling with a very expensive GPU…

I also set the GPU’s power state to P0, which is the highest-performance mode.

PyTorch version: 2.9.0.dev20250707+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Pro (10.0.26100, 64-bit)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: N/A

Python version: 3.10.18 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:08:55) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.26100-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5090
Nvidia driver version: 576.02
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cudnn_ops_train64_8.dll
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Name: AMD Ryzen 7 9800X3D 8-Core Processor
Manufacturer: AuthenticAMD
Family: 107
Architecture: 9
ProcessorType: 3
DeviceID: CPU0
CurrentClockSpeed: 4700
MaxClockSpeed: 4700
L2CacheSize: 8192
L2CacheSpeed: None
Revision: 17408

Versions of relevant libraries:
[pip3] mkl_fft==1.3.11
[pip3] mkl_random==1.2.8
[pip3] mkl-service==2.4.0
[pip3] numpy==2.2.5
[pip3] torch==2.9.0.dev20250707+cu128
[pip3] torchaudio==2.8.0.dev20250708+cu128
[pip3] torchvision==0.24.0.dev20250708+cu128
[conda] blas 1.0 mkl
[conda] intel-openmp 2023.1.0 h59b6b97_46320
[conda] mkl 2023.1.0 h6b88ed4_46358
[conda] mkl-service 2.4.0 py310h827c3e9_2
[conda] mkl_fft 1.3.11 py310h827c3e9_0
[conda] mkl_random 1.2.8 py310hc64d2fc_0
[conda] numpy 2.2.5 py310h5f75535_0
[conda] numpy-base 2.2.5 py310h23d94f8_0
[conda] tbb 2021.8.0 h59b6b97_0
[conda] torch 2.9.0.dev20250707+cu128 pypi_0 pypi
[conda] torchaudio 2.8.0.dev20250708+cu128 pypi_0 pypi
[conda] torchvision 0.24.0.dev20250708+cu128 pypi_0 pypi

For confidentiality reasons, the following code is an example I wrote to be as close as possible to my actual problematic code.

import torch
import torch.nn as nn
import time

# Simplified example model
class UNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)

# Same input sizes as my actual model
BATCH_SIZE = 64
H, W = 512, 512

model = UNetLike().half().cuda()
example_input = torch.randn(BATCH_SIZE, 1, H, W, dtype=torch.float16, device="cuda")

# JIT trace
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

traced = torch.jit.trace(model, example_input)

end.record()
torch.cuda.synchronize()
print("Trace time (ms):", start.elapsed_time(end))
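One thing worth noting is that torch.jit.trace() does most of its work on the CPU (it runs the model once and records the executed ops), so CUDA events may not be the clearest way to time it. A minimal, device-agnostic sketch using time.perf_counter with a warm-up forward pass might separate one-time CUDA initialization costs (which can differ between GPU architectures) from the trace itself:

```python
import time

import torch
import torch.nn as nn


class UNetLike(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = UNetLike().to(device)
dtype = torch.float32
if device == "cuda":
    model = model.half()
    dtype = torch.float16

# Small input for illustration; scale up to match the real workload.
example_input = torch.randn(2, 1, 64, 64, dtype=dtype, device=device)

# Warm-up: the first CUDA call pays one-time context-setup and
# module-loading costs, which we don't want inside the measurement.
model(example_input)
if device == "cuda":
    torch.cuda.synchronize()

start = time.perf_counter()
traced = torch.jit.trace(model, example_input)
if device == "cuda":
    torch.cuda.synchronize()
end = time.perf_counter()

print(f"Trace time: {(end - start) * 1000:.1f} ms")
```

If the 5090 still looks much slower after the warm-up, the gap is more likely in the trace/kernel path itself rather than in one-time initialization.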

TorchScript is in maintenance mode and won’t receive any major features or updates anymore. Use torch.compile instead.

I keep reading this recommendation more often now and wanted to ask why torch.compile is recommended so often when it is not yet a replacement for TorchScript. For example, it doesn’t have a C++ runtime, which makes it unsuitable as a drop-in switch in a lot of cases.

torch.compile() does not work on Windows.

torch.compile doesn’t work on Windows. I’m planning to replace my RTX 5090 with a 4090 for torch.jit.trace, because the 5090 is very unsuitable for accelerating jit.trace and a waste of money.
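For anyone who can try torch.compile (its Windows support has historically been limited, though CPU support has been landing in recent releases), the API swap is small. A minimal sketch, with a fallback to eager execution in case the default Inductor backend is unavailable on the platform (e.g. no C++ compiler or Triton installed):

```python
import torch
import torch.nn as nn

# Stand-in for the UNet-like model above.
model = nn.Sequential(
    nn.Conv2d(1, 64, 3, padding=1),
    nn.ReLU(),
).eval()

# torch.compile wraps the module; actual compilation happens on the
# first call with real inputs.
compiled = torch.compile(model)

x = torch.randn(1, 1, 32, 32)
with torch.no_grad():
    try:
        y = compiled(x)
    except Exception:
        # Backend unavailable on this platform: fall back to eager.
        y = model(x)

print(y.shape)
```

Unlike torch.jit.trace, this keeps the model in Python (no standalone C++ runtime), which is exactly the trade-off raised above.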