With cudnn.benchmark=True in PyTorch 2.x, the first backward() takes too long

Hi pytorch guys,

I ran into an issue where, if I set torch.backends.cudnn.benchmark = True in PyTorch 2.x and use a tensor that nearly fills GPU memory, the first backward() call takes far too long. The same code runs fine on PyTorch 1.x.

Below is my test code.

import time

import torch
import torch.backends
import torch.backends.cudnn
from torchvision.models import resnet18

torch.backends.cudnn.benchmark = True

print(f'torch: {torch.__version__}')
print(f'{torch.backends.cudnn.benchmark=}')

device = torch.device('cuda:0')
model = resnet18().to(device)

# Loss function.
criterion = torch.nn.CrossEntropyLoss()

model.train()

# Dummy input batch and labels (the upper bound of randint is exclusive, so all labels are 0).
image = torch.rand((5, 3, 2048, 2048))
labels = torch.randint(0, 1, (5,))

image = image.to(device)
labels = labels.to(device)

# Forward pass.
outputs = model(image)

# Calculate the loss.
loss = criterion(outputs, labels)

# Backpropagation
start = time.time()
loss.backward()
print(f'Backward time: {time.time() - start}')

My GPU is an RTX 3080, and anything larger than a (5, 3, 2048, 2048) tensor causes an OOM error.

PyTorch 1.13.1, benchmark=False

torch: 1.13.1+cu117
torch.backends.cudnn.benchmark=False
Backward time: 0.6044294834136963

PyTorch 1.13.1, benchmark=True

torch: 1.13.1+cu117
torch.backends.cudnn.benchmark=True
Backward time: 1.2438836097717285

PyTorch 2.4.1, benchmark=False

torch: 2.4.1+cu118
torch.backends.cudnn.benchmark=False
Backward time: 0.3900582790374756

PyTorch 2.4.1, benchmark=True

torch: 2.4.1+cu118
torch.backends.cudnn.benchmark=True
Backward time: 17.757819414138794

Have there been any changes in the cudnn benchmark behavior since PyTorch 2.x?

It’s expected to see slow execution in the first iteration for each new workload, since cudnn.benchmark=True profiles multiple kernels internally. In your code you are only measuring the first call and are also not synchronizing, so it’s unclear what the real kernel execution times are.
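A rough timing sketch along those lines (the timed_step helper is just for illustration; it reuses the model, criterion, image, and labels defined above) would separate the one-time profiling cost from the steady-state backward time:

import time

import torch

def timed_step():
    # Make sure all previously queued GPU work has finished before starting the timer.
    torch.cuda.synchronize()
    start = time.time()
    outputs = model(image)
    loss = criterion(outputs, labels)
    model.zero_grad(set_to_none=True)
    loss.backward()
    # Wait for the backward kernels to actually finish before stopping the timer.
    torch.cuda.synchronize()
    return time.time() - start

# The first call includes cudnn's kernel profiling when benchmark=True;
# later calls reuse the algorithms selected during that profiling.
print(f'1st iteration: {timed_step():.3f} s')
print(f'2nd iteration: {timed_step():.3f} s')
print(f'3rd iteration: {timed_step():.3f} s')

With this, only the first printed number should contain the profiling overhead.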

Hi ptrblck,

Thanks for your quick response. I added the synchronization call as shown below and re-measured the numbers.

start = time.time()
loss.backward()
torch.cuda.synchronize()
print(f'Backward time: {time.time() - start}')

Even though I added the synchronization, the results are almost the same as before.

torch: 1.13.1+cu117
torch.backends.cudnn.benchmark=False
Backward time: 0.6931467056274414

torch: 1.13.1+cu117
torch.backends.cudnn.benchmark=True
Backward time: 1.2821173667907715

torch: 2.4.1+cu118
torch.backends.cudnn.benchmark=False
Backward time: 0.5114474296569824

torch: 2.4.1+cu118
torch.backends.cudnn.benchmark=True
Backward time: 19.250582218170166

I understand that the first iteration is slow because multiple kernels are tried. But the 2.4.1 case is far slower than that alone would explain. Could you explain why it is this slow?

Depending on the use case, a very slow kernel could be run during profiling, which slows down the first iteration. After this initial warmup, the slow kernels won’t be selected anymore (unless profiling is triggered again for another workload, in which case a slow kernel could again be run during the profiling stage).
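To make that concrete, here is a hedged sketch (the timed_backward helper and the shapes are just examples; it reuses model, criterion, and device from above) showing that the profiling cost reappears whenever a new input shape is seen:

import time

import torch

def timed_backward(shape):
    # Fresh random batch for the given shape; labels are valid class indices (all 0).
    image = torch.rand(shape, device=device)
    labels = torch.zeros(shape[0], dtype=torch.long, device=device)
    torch.cuda.synchronize()
    start = time.time()
    loss = criterion(model(image), labels)
    model.zero_grad(set_to_none=True)
    loss.backward()
    torch.cuda.synchronize()
    return time.time() - start

print(timed_backward((5, 3, 2048, 2048)))  # slow: profiling runs for this shape
print(timed_backward((5, 3, 2048, 2048)))  # fast: cached algorithm choice is reused
print(timed_backward((5, 3, 1024, 1024)))  # slower again: new shape triggers profiling again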