Hi,
I’ve noticed a significant performance slowdown in torch 2.0 when enabling determinism.
Here is a simple example using the diffusers library:
import os
import sys
from datetime import timedelta
import time
import torch
from diffusers import UNet2DModel
import torch
# Set the allow_tf32 flag for CUDA matmuls once at import time, so it applies
# to both the deterministic and the non-deterministic benchmark runs.
torch.backends.cuda.matmul.allow_tf32 = True
def set_deterministic(mode=True):
    """Switch PyTorch between deterministic and non-deterministic execution.

    Args:
        mode: When true, turn off cuDNN benchmarking, request deterministic
            cuDNN and PyTorch algorithms, and export
            ``CUBLAS_WORKSPACE_CONFIG=:4096:8``. When false, revert each of
            those settings and remove the environment variable.
    """
    torch.backends.cudnn.benchmark = not mode
    torch.backends.cudnn.deterministic = mode
    # warn_only=True is passed so the script keeps running; see the
    # torch.use_deterministic_algorithms documentation for the exact semantics.
    torch.use_deterministic_algorithms(mode, warn_only=True)

    # Assign through os.environ instead of calling os.putenv/os.unsetenv:
    # the mapping forwards to putenv/unsetenv itself, and it additionally
    # stays in sync, so Python code that inspects os.environ sees the change.
    if mode:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    else:
        # pop() with a default tolerates the variable not being set.
        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)

    print(f"Deterministic: {mode}")
def go():
    """Benchmark a diffusers ``UNet2DModel`` on CUDA and print iterations/second.

    Two phases of ``n`` iterations each are timed:

    1. Training: forward pass under float16 autocast, MSE loss against a
       random target, and a scaled backward pass. No optimizer is created,
       so gradients are computed but never applied.
    2. Evaluation: the model is cast to float16 and run under ``no_grad``.

    The GPU is synchronized immediately before each timer is started and
    stopped, because CUDA kernels are launched asynchronously and the host
    clock would otherwise not cover the work still queued on the device.
    """
    scaler = torch.cuda.amp.GradScaler()
    batch_size = 8
    channels = 3
    sample_size = 64
    n = 20
    device = torch.device("cuda")
    shape = (batch_size, channels, sample_size, sample_size)

    model = UNet2DModel(
        sample_size=sample_size,
        in_channels=channels,
        out_channels=channels,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        norm_num_groups=32,
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
    model = model.to(device=device)

    # ---- Training phase -------------------------------------------------
    model.train()
    rng = torch.Generator(device="cuda").manual_seed(0)
    # Drain pending work (e.g. the host-to-device copy of the weights) so it
    # is not attributed to the training loop.
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(n):
        sample = torch.randn(shape, device=device)
        target = torch.randn(shape, device=device)
        timestep = torch.randint(
            0, 1000, (sample.shape[0],), generator=rng, device=device
        )
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            output = model(sample, timestep=timestep)
            loss = torch.nn.functional.mse_loss(
                output.sample, target, reduction="none"
            ).mean()
        scaler.scale(loss).backward()
    # Wait for every queued kernel to finish before reading the clock.
    torch.cuda.synchronize(device)
    duration = timedelta(seconds=time.perf_counter() - start)
    print(f"Train duration {duration} ({n/duration.total_seconds():.02f} it/s)")

    # ---- Evaluation phase -----------------------------------------------
    model = model.to(dtype=torch.float16)
    model.eval()
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n):
            sample = torch.randn(shape, device=model.device, dtype=model.dtype)
            timestep = torch.randint(
                0, 1000, (batch_size,), device=model.device, dtype=model.dtype
            )
            output = model(sample, timestep=timestep)
    torch.cuda.synchronize(device)
    duration = timedelta(seconds=time.perf_counter() - start)
    print(f"Eval duration {duration} ({n/duration.total_seconds():.02f} it/s)")
def main(mode):
    """Print the torch version, apply the determinism setting, run the benchmark.

    Args:
        mode: Forwarded unchanged to ``set_deterministic``.
    """
    print(f"Torch version: {torch.__version__}")
    set_deterministic(mode)
    go()
if __name__ == "__main__":
    # The first argument is parsed as an integer: 0 disables determinism,
    # any other integer enables it. Fail with a usage hint instead of a bare
    # IndexError when the argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <0|1>  (non-zero enables deterministic mode)")
    main(bool(int(sys.argv[1])))
With pytorch-1.13, performance is roughly equal whether determinism is enabled or not:
Torch version: 1.13.0a0+git49444c3
Deterministic: False
Train duration 0:00:02.445595 (8.18 it/s)
Eval duration 0:00:00.488221 (40.97 it/s)
Torch version: 1.13.0a0+git49444c3
Deterministic: True
Train duration 0:00:02.433920 (8.22 it/s)
Eval duration 0:00:00.484679 (41.26 it/s)
But with pytorch-2.0, performance degrades by 2-4x (or even worse on more complex cases):
Torch version: 2.0.0a0+gite9ebda2
Deterministic: False
Train duration 0:00:02.245691 (8.91 it/s)
Eval duration 0:00:00.477144 (41.92 it/s)
Torch version: 2.0.0a0+gite9ebda2
Deterministic: True
Train duration 0:00:05.969440 (3.35 it/s)
Eval duration 0:00:01.809939 (11.05 it/s)
The slowdown also occurs without mixed precision, but it is especially pronounced with it. GPU utilization drops from 100% in non-deterministic mode to <50% in deterministic mode, which makes me suspect that some operations are running on the CPU.
Given that determinism did not degrade performance in 1.13, I would expect similar results in 2.0. Did something change in 2.0 to explain this result? Does determinism need to be enabled differently?
This is using cuda-11.2.2 and libcudnn-8.9.4 in both cases.
Thanks,
A.