FP16 and BF16 way slower than FP32 and TF32

I don’t know what I’m doing wrong, but my FP16 and BF16 benchmarks are way slower than the FP32 and TF32 modes. Here are my results on the two GPUs at my disposal (RTX 2060 Mobile, RTX 3090 Desktop):

Benching precision speed on a NVIDIA GeForce RTX 2060

benching FP32…
epoch 0 took 13.9146514s
epoch 1 took 11.6350846s
epoch 2 took 11.867831299999999s

benching FP16…
epoch 0 took 15.745933399999998s
epoch 1 took 16.212830699999998s
epoch 2 took 16.495791399999987s

Benching precision speed on a NVIDIA GeForce RTX 3090

benching FP32…
epoch 0 took 5.7641565s
epoch 1 took 4.0729165s
epoch 2 took 4.0790243s

benching TF32…
epoch 0 took 4.042242200000002s
epoch 1 took 4.0321663s
epoch 2 took 4.080792600000002s

benching FP16…
epoch 0 took 5.053079000000004s
epoch 1 took 5.029029299999998s
epoch 2 took 4.973819899999995s

benching BF16…
epoch 0 took 11.721234800000005s
epoch 1 took 11.542296499999999s
epoch 2 took 11.566654600000007s

And here’s the file I made to generate the benchmark: a simple MNIST classifier that you can easily run on your own machine (I run it with PyTorch 1.12 on Windows and the latest NVIDIA drivers):

#MNIST example inspired by https://github.com/pytorch/examples/blob/main/mnist/main.py
import timeit
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
dataset = torchvision.datasets.MNIST(root='data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, pin_memory=True, shuffle=True)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

device = torch.device("cuda")

def bench(mode, epochs=3):
    print("benching "+mode+"...")
    torch.backends.cuda.matmul.allow_tf32 = (mode == 'TF32')
    scaler = torch.cuda.amp.GradScaler(enabled=(mode == 'FP16'))

    model = Net().to(device)
    optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
    model.train()
    for epoch in range(epochs):
        start_time = timeit.default_timer()
        model.train()
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16 if mode == 'BF16' else torch.float16, enabled=('16' in mode)):
                output = model(data)
                loss = F.nll_loss(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        print("epoch "+str(epoch)+" took "+str(end_time-start_time)+"s")
    print("")

print("Benching precision speed on a "+torch.cuda.get_device_name(0))
print("")
bench(mode="FP32")
if torch.cuda.is_bf16_supported():  # BF16 support doubles as an Ampere+ check, which TF32 also requires
    bench(mode="TF32")
bench(mode="FP16")
if torch.cuda.is_bf16_supported():
    bench(mode="BF16")

I cannot reproduce the slowdown using a recent PyTorch build with CUDA 11.6 and see these results:

benching FP32...
epoch 0 took 5.307978553988505s
epoch 1 took 3.9439379789982922s
epoch 2 took 3.9074656259908807s

benching TF32...
epoch 0 took 4.101563502015779s
epoch 1 took 3.946826447005151s
epoch 2 took 4.008894422004232s

benching FP16...
epoch 0 took 4.420351507025771s
epoch 1 took 4.3506982740073s
epoch 2 took 4.440277656976832s

benching BF16...
epoch 0 took 4.17264149300172s
epoch 1 took 4.073878110008081s
epoch 2 took 4.178816699975869s

A few things for your profiling:

  • You would also need to synchronize before starting the timer, not only before stopping it, in case some kernels are still running (e.g. the transfer of the weights to the GPU). The sketch below shows both sync points.
  • You are profiling a full training run, including the DataLoader and the transfer of the data to the GPU, while using a tiny model. Depending on your system, the actual model runtime might be tiny and the overhead from data loading etc. large, so actual model speedups won’t be directly visible. If you want to compare different numerical precisions and their speed, I would recommend profiling the model in isolation first (see the first sketch after this list).
  • Your model workload is small, so even if lower-precision dtypes give a speedup, the kernel launches, the dispatching, etc. might dominate. If this small model is your real workload, you might want to try CUDA graphs (a rough capture/replay sketch also follows below).
  • To enable bfloat16 computation in conv layers, set os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" or export this env variable. However, bfloat16 is only expected to help on Ampere+, so your Turing GPU might not see any benefits.
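
For reference, here's a minimal sketch of what I mean by profiling the model in isolation: the batch lives on the GPU, the DataLoader is taken out of the loop, and the timed region is wrapped in torch.cuda.synchronize() calls. It reuses the Net class from your script; the batch size, iteration counts, and random inputs are arbitrary placeholders, not a definitive benchmark:

import timeit
import torch
import torch.nn.functional as F

device = torch.device("cuda")
model = Net().to(device)  # the same Net as defined in the script above
model.train()

# static, GPU-resident dummy batch, so data loading and host-to-device copies are excluded
data = torch.randn(64, 1, 28, 28, device=device)
target = torch.randint(0, 10, (64,), device=device)

def run(iters=200, dtype=torch.float16, enabled=True):
    # warmup iterations so cuDNN algorithm selection doesn't land inside the timed region
    for _ in range(10):
        with torch.autocast(device_type='cuda', dtype=dtype, enabled=enabled):
            F.nll_loss(model(data), target).backward()
    torch.cuda.synchronize()  # wait for all pending kernels before starting the timer
    start = timeit.default_timer()
    for _ in range(iters):
        with torch.autocast(device_type='cuda', dtype=dtype, enabled=enabled):
            F.nll_loss(model(data), target).backward()
    torch.cuda.synchronize()  # and again before stopping it
    print("mean iteration time:", (timeit.default_timer() - start) / iters)

run(enabled=False)              # FP32 / TF32 baseline (depending on the allow_tf32 flags)
run(dtype=torch.float16)        # FP16 autocast
run(dtype=torch.bfloat16)       # BF16 autocast (Ampere+)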

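And if you want to try CUDA graphs for such a small, launch-bound workload, the rough capture/replay pattern looks like this (just a sketch following the torch.cuda.CUDAGraph docs; the warmup count, the plain SGD optimizer, and the fixed batch size are placeholder assumptions — you would e.g. need drop_last=True in the DataLoader so every batch matches the captured shapes):

import torch
import torch.nn.functional as F

device = torch.device("cuda")
model = Net().to(device)  # same Net as above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# static tensors that are reused for every replayed iteration
static_data = torch.randn(64, 1, 28, 28, device=device)
static_target = torch.randint(0, 10, (64,), device=device)

# warmup on a side stream before capturing
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss = F.nll_loss(model(static_data), static_target)
        loss.backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture one full training step into a graph
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = F.nll_loss(model(static_data), static_target)
    static_loss.backward()
    optimizer.step()

# training loop: copy each real batch into the static tensors and replay the graph
for data, target in loader:
    static_data.copy_(data)
    static_target.copy_(target)
    g.replay()
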
Thanks for the advice. Indeed, it's a small model/dataset just for testing, so I'm not surprised not to see any dramatic speed-up, yet I'm puzzled by the visible slowdown as soon as 16-bit floats are used; that doesn't sound right. Is autocasting adding significant overhead that could be noticeable on small models?

The benchmarks were done on two different systems, both running Windows 11 21H2, the latest NVIDIA drivers, Python 3.9, and PyTorch 1.12.1 with CUDA 11.6, with an NVMe SSD and a fast CPU/RAM.
I'm a bit surprised you did not see similar results (although FP16/BF16 also seem slightly slower on average on your system). Did you run the benchmark on Windows or Linux? Maybe it's Windows-specific?

I’ve added your suggestions (additional sync + cuDNN flag), which give me different numbers, but 16-bit training is still slower:

Benching precision speed on a NVIDIA GeForce RTX 2060

benching FP32...
epoch 0 took 15.4749659s
epoch 1 took 11.017322s
epoch 2 took 13.260371799999998s

benching FP16...
epoch 0 took 15.714570200000004s
epoch 1 took 15.929676400000005s
epoch 2 took 16.421480399999993s

Benching precision speed on a NVIDIA GeForce RTX 3090

benching FP32...
epoch 0 took 7.703148500000001s
epoch 1 took 3.7395274s
epoch 2 took 3.7481109999999997s

benching TF32...
epoch 0 took 3.752561700000001s
epoch 1 took 3.6327957000000026s
epoch 2 took 3.7228427999999987s

benching FP16...
epoch 0 took 4.7007553999999985s
epoch 1 took 4.666421100000001s
epoch 2 took 4.604346900000003s

benching BF16...
epoch 0 took 4.327719200000004s
epoch 1 took 4.267118599999996s
epoch 2 took 4.287737299999996s

Full updated code below:

#MNIST example inspired by https://github.com/pytorch/examples/blob/main/mnist/main.py
import os
import timeit
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
dataset = torchvision.datasets.MNIST(root='data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=64, pin_memory=True, shuffle=True)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

device = torch.device("cuda")

def bench(mode, epochs=3):
    print("benching "+mode+"...")
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # opt into the cuDNN v8 API so conv layers can run in bfloat16
    torch.backends.cuda.matmul.allow_tf32 = (mode == 'TF32')
    scaler = torch.cuda.amp.GradScaler(enabled=(mode == 'FP16'))

    model = Net().to(device)
    optimizer = torch.optim.Adadelta(model.parameters(), lr=1.0)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
    model.train()
    for epoch in range(epochs):
        torch.cuda.synchronize()
        start_time = timeit.default_timer()
        model.train()
        for batch_idx, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16 if mode == 'BF16' else torch.float16, enabled=('16' in mode)):
                output = model(data)
                loss = F.nll_loss(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        torch.cuda.synchronize()
        end_time = timeit.default_timer()
        print("epoch "+str(epoch)+" took "+str(end_time-start_time)+"s")
    print("")

print("Benching precision speed on a "+torch.cuda.get_device_name(0))
print("")
bench(mode="FP32")
if torch.cuda.is_bf16_supported():
    bench(mode="TF32")
bench(mode="FP16")
if torch.cuda.is_bf16_supported():
    bench(mode="BF16")