Performance regression in torch 2.0 with deterministic algorithms

Hi,

I’ve noticed a significant slowdown in torch 2.0 when enabling deterministic algorithms.

Here is a simple example using the diffusers library:

import os
import sys
import time
from datetime import timedelta

import torch
from diffusers import UNet2DModel

# Allow TF32 matmuls (applies to all runs below).
torch.backends.cuda.matmul.allow_tf32 = True

def set_deterministic(mode=True):
    # Turn off cuDNN autotuning and force deterministic kernel selection.
    torch.backends.cudnn.benchmark = not mode
    torch.backends.cudnn.deterministic = mode
    torch.use_deterministic_algorithms(mode, warn_only=True)
    if mode:
        # Required by cuBLAS for deterministic results (see the PyTorch reproducibility notes).
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    else:
        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
    print(f"Deterministic: {mode}")

def go():
    scaler = torch.cuda.amp.GradScaler()
    batch_size = 8
    channels = 3
    sample_size = 64
    n = 20
    device = torch.device("cuda")
    model = UNet2DModel(
        sample_size=sample_size,
        in_channels=channels, out_channels=channels,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        norm_num_groups=32,
        down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D"))

    model = model.to(device=device)
    model.train()
    # Timed training loop: forward + backward only, no optimizer step.
    start = time.time()
    rng = torch.Generator(device="cuda").manual_seed(0)
    for step in range(n):
        input = torch.randn((batch_size, channels, sample_size, sample_size), device=device)
        target = torch.randn((batch_size, channels, sample_size, sample_size), device=device)
        bs = input.shape[0]
        timestep = torch.randint(0, 1000, (bs,), generator=rng, device=device)
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            output = model(input, timestep=timestep)
            loss = torch.nn.functional.mse_loss(output.sample, target, reduction="none").mean()
        scaler.scale(loss).backward()

    duration = timedelta(seconds=time.time() - start)

    print(f"Train duration {duration} ({n/duration.total_seconds():.02f} it/s)")

    # Timed fp16 inference loop.
    model = model.to(dtype=torch.float16)
    model.eval()
    start = time.time()
    with torch.no_grad():
        for i in range(n):
            input = torch.randn((batch_size, channels, sample_size, sample_size), device=model.device, dtype=model.dtype)
            timestep = torch.randint(0, 1000, (batch_size,), device=model.device, dtype=model.dtype)
            output = model(input, timestep=timestep)

    duration = timedelta(seconds=time.time() - start)

    print(f"Eval duration {duration} ({n/duration.total_seconds():.02f} it/s)")

def main(mode):
    print(f"Torch version: {torch.__version__}")
    set_deterministic(mode)
    go()

if __name__ == "__main__":
    main(bool(int(sys.argv[1])))

With pytorch-1.13, performance is roughly equal whether determinism is enabled or not:

Torch version: 1.13.0a0+git49444c3
Deterministic: False
Train duration 0:00:02.445595 (8.18 it/s)
Eval duration 0:00:00.488221 (40.97 it/s)
Torch version: 1.13.0a0+git49444c3
Deterministic: True
Train duration 0:00:02.433920 (8.22 it/s)
Eval duration 0:00:00.484679 (41.26 it/s)

But with pytorch-2.0, performance degrades by 2-4x (or even more in more complex cases):

Torch version: 2.0.0a0+gite9ebda2
Deterministic: False
Train duration 0:00:02.245691 (8.91 it/s)
Eval duration 0:00:00.477144 (41.92 it/s)
Torch version: 2.0.0a0+gite9ebda2
Deterministic: True
Train duration 0:00:05.969440 (3.35 it/s)
Eval duration 0:00:01.809939 (11.05 it/s)

The difference also shows up without mixed precision, but it is especially visible with it. GPU utilization drops from 100% in non-deterministic mode to under 50% in deterministic mode, which makes me suspect that some operations may be falling back to the CPU.
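
In case it helps, here is a minimal torch.profiler sketch of how the CPU/CUDA time split could be checked; it uses a small stand-in conv instead of the UNet above, so it only illustrates the approach, not the actual numbers:

import torch
from torch.profiler import profile, ProfilerActivity

# Enable deterministic mode as in the benchmark above (warn_only avoids hard errors).
torch.use_deterministic_algorithms(True, warn_only=True)

# Stand-in workload; the same pattern applies to the UNet2DModel benchmark.
model = torch.nn.Conv2d(3, 64, 3, padding=1).cuda()
x = torch.randn(8, 3, 64, 64, device="cuda", requires_grad=True)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        model(x).sum().backward()
    torch.cuda.synchronize()

# Sorting by CUDA time shows whether slower deterministic kernels dominate;
# large self-CPU times would instead point at CPU-side work.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))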

Given that determinism did not degrade performance in 1.13, I would expect similar results in 2.0. Did something change in 2.0 to explain this result? Does determinism need to be enabled differently?

This is using cuda-11.2.2 and libcudnn-8.9.4 in both cases.

Thanks,
A.

CUDA operations are executed asynchronously, so you need to synchronize the device before starting and stopping the timers. As it stands, your measurements are not valid.
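
Something along these lines (a minimal sketch of the timing pattern; run_step here is just a placeholder for your forward/backward step):

import time
import torch

n = 20  # number of benchmark iterations (placeholder)

def run_step():
    # Placeholder for the forward/backward work you want to time.
    torch.randn(1024, 1024, device="cuda") @ torch.randn(1024, 1024, device="cuda")

torch.cuda.synchronize()  # finish any pending kernels before starting the timer
start = time.time()
for _ in range(n):
    run_step()
torch.cuda.synchronize()  # wait for the queued kernels before stopping the timer
elapsed = time.time() - start
print(f"{elapsed:.3f}s ({n / elapsed:.2f} it/s)")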

That expectation isn’t valid: forcing deterministic algorithms is expected to potentially reduce performance, because faster non-deterministic kernels are excluded from selection.
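
As a rough illustration (a minimal sketch, not your benchmark: a single convolution timed with and without the deterministic flags; the actual gap depends entirely on the GPU and cuDNN build):

import time
import torch

def time_conv(deterministic, iters=50):
    # Deterministic mode disables cuDNN autotuning and restricts the
    # algorithm choice to deterministic kernels, which can be slower.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic
    conv = torch.nn.Conv2d(256, 256, 3, padding=1).cuda()
    x = torch.randn(8, 256, 64, 64, device="cuda", requires_grad=True)
    for _ in range(5):  # warm-up / algorithm selection
        conv(x).sum().backward()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        conv(x).sum().backward()
    torch.cuda.synchronize()
    return (time.time() - start) / iters

print(f"non-deterministic: {time_conv(False) * 1e3:.2f} ms/iter")
print(f"deterministic:     {time_conv(True) * 1e3:.2f} ms/iter")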

I’ve added torch.cuda.synchronize(device) calls before my time.time() calls, and it made no difference to the timings.

I understand that there are no guarantees, but 1.13 showed no degradation, and torch 2.0 is billed as “faster” and as essentially 1.13 with (optional) compilation on top. It would be great to have at least some explanation of where the degradation is coming from.
