torch.cuda.amp slower than TF32 on an NVIDIA A100?

Hi, I tested ViT-Large on an NVIDIA A100 GPU with PyTorch 1.12; the results are listed below:

  1. TF32 enabled
torch.backends.cuda.matmul.allow_tf32=True
torch.backends.cudnn.allow_tf32=True

The average time per iteration is 1 s.

  2. autocast
    with autocast(enabled=amp_enable):
        pred = model(images)
        loss = loss_fn(pred, labels)
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale gradients in-place before clipping
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

The average time per iteration is 1.4 s.

According to the NVIDIA docs, autocast defaults to FP16, which has twice the peak throughput of TF32 on the A100, so I expected autocast to be faster than TF32, but the measured results are the opposite.
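
One way I can double-check that autocast actually casts the matmul-heavy ops to FP16 is to look at the output dtype inside the autocast region (a minimal sketch, reusing `model` and `images` from above):

    from torch.cuda.amp import autocast

    with autocast(enabled=True):
        pred = model(images)
    # the output of the final linear layer should be torch.float16 if autocast is active
    print(pred.dtype)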

Is this reasonable?

Thanks.

Could you post a minimal, executable code snippet which would reproduce these results and also show your profiling?
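
Even something as small as this would help (a rough sketch, assuming the `model`, `images`, and `amp_enable` names from your post):

    from torch.cuda.amp import autocast
    from torch.profiler import profile, ProfilerActivity

    # profile one forward/backward pass and print the most expensive CUDA kernels
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with autocast(enabled=amp_enable):
            pred = model(images)
        pred.sum().backward()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))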

Dear @ptrblck, can you take a look at my problem? I am new to PyTorch.

Hi @ptrblck, thanks. Below is the snippet:

import torch
from timm.models.vision_transformer import VisionTransformer
from timm.models import create_model
import time
from torch.cuda.amp import GradScaler, autocast


def train(model, input, amp_enable=False):
    torch.cuda.synchronize()
    time_start = time.time()
    for i in range(0, 5):
        with autocast(enabled=amp_enable):
            out = model(input)
        out.sum().backward()

    torch.cuda.synchronize()
    print(f'time used: {time.time() - time_start}')


if __name__ == '__main__':

    model = create_model('vit_large_patch16_384',
                         pretrained=False,
                         num_classes=None,
                         drop_rate=0,
                         drop_path_rate=0.3)
    model.cuda().train()
    input = torch.rand(32, 3, 384, 384).cuda()

    # warmup, ignore
    train(model, input)

    print('----train with fp32----')
    train(model, input)
    print('----train with autocast----')
    train(model, input, amp_enable=True)
    print('----train with tf32----')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    train(model, input)

You will have to install timm (GitHub - rwightman/pytorch-image-models) first.

I ran this code with PyTorch 1.12 on an A100-80GB GPU and got:

time used: 17.40752673149109
----train with fp32----
time used: 11.999516010284424
----train with autocast----
time used: 4.917150497436523
----train with tf32----
time used: 3.436387538909912

The results show that TF32 is faster than autocast. Are there any mistakes in my code?

Thanks for sharing it! I cannot reproduce the issue using the torch==1.12.1+cu116 binaries and get:

time used: 13.615819454193115
----train with fp32----
time used: 11.712484359741211
----train with autocast----
time used: 2.4270553588867188
----train with tf32----
time used: 3.359715700149536

on an A100 80GB GPU, which shows a speedup using AMP.
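
As a side note, for comparisons like this it helps to warm up each configuration separately and to time with CUDA events; a rough sketch (`timed_run` is a hypothetical helper reusing the names from your script, not what produced the numbers above):

    import torch
    from torch.cuda.amp import autocast

    def timed_run(model, input, amp_enable=False, iters=5, warmup=3):
        # warm up this particular configuration so one-time overheads are excluded
        for _ in range(warmup):
            with autocast(enabled=amp_enable):
                out = model(input)
            out.sum().backward()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            with autocast(enabled=amp_enable):
                out = model(input)
            out.sum().backward()
        end.record()
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds between the two recorded events
        print(f'time per iter: {start.elapsed_time(end) / iters:.1f} ms')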

Thanks! It seems that we get similar results for TF32 and FP32. I will test this in different software environments.

We tested with PyTorch 1.12.1+cu113, and TF32 is still faster than autocast.
We then tested with PyTorch 1.12.1+cu116 and could reproduce your result (TF32 slower than autocast).
So I guess the difference comes from the CUDA toolkit.
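
For reference, the build versions can be checked like this (a small sketch):

    import torch

    print(torch.__version__)               # e.g. 1.12.1+cu113 vs 1.12.1+cu116
    print(torch.version.cuda)              # CUDA toolkit the binaries were built with
    print(torch.backends.cudnn.version())  # bundled cuDNN version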

Yes, most likely from the CUDA math libraries (cuDNN, cuBLAS, etc.).