Slow performance when running TransformerDecoder with low batch size in fp16

Hey!

I want to understand why this snippet runs slowly for batch sizes of 2/4/8 when using fp16. On my system, with bs=2, the unbatched version takes 0.55 s while the batched one takes 2.8 s.

from torch import nn
import torch
import time
from torch.amp import autocast
device = torch.device('cuda')

hidden_dim = 1024
USE_FP16 = True
BS = 2
TOTAL_INPUT = 250
decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim,
                                           dim_feedforward=hidden_dim * 4,
                                           nhead=8,
                                           batch_first=True)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=3).to(device)

with autocast('cuda', enabled=USE_FP16):
    with torch.no_grad():
        memory = torch.rand((1, 1000, hidden_dim), device=device)
        tgt = torch.rand((1, 36, hidden_dim), device=device)
        out = transformer_decoder(tgt, memory)  # warm-up pass before timing
        s = time.perf_counter()
        for _ in range(TOTAL_INPUT):  # TOTAL_INPUT forward passes with batch size 1
            out = transformer_decoder(tgt, memory)
        print(time.perf_counter() - s)


        memory = torch.rand((BS, 1000, hidden_dim), device=device)
        tgt = torch.rand((BS, 36, hidden_dim), device=device)


        s = time.perf_counter()
        for _ in range(TOTAL_INPUT // BS):  # same total number of samples, in batches of BS
            out = transformer_decoder(tgt, memory)
        print(time.perf_counter() - s)
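
Side note on timing: CUDA kernels launch asynchronously, so perf_counter can stop the clock before the queued GPU work has actually finished. Below is a minimal variant of the timed loop with explicit synchronization, reusing transformer_decoder, tgt, memory, TOTAL_INPUT, BS, and USE_FP16 from the snippet above; the helper name and structure are just a sketch, not part of the original measurement.

import time
import torch
from torch.amp import autocast

def timed_forward(decoder, tgt, memory, iters, use_fp16=True):
    with torch.no_grad(), autocast('cuda', enabled=use_fp16):
        decoder(tgt, memory)          # warm-up pass, excluded from the measurement
        torch.cuda.synchronize()      # flush queued kernels before starting the clock
        s = time.perf_counter()
        for _ in range(iters):
            decoder(tgt, memory)
        torch.cuda.synchronize()      # wait for all kernels before stopping the clock
        return time.perf_counter() - s

# e.g. timed_forward(transformer_decoder, tgt, memory, TOTAL_INPUT // BS, USE_FP16)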

The slowdown also reproduces in Google Colab.
Thanks!

You add overhead by setting batch_first=True. Namely, the model has to transpose dimensions internally, and that extra computation makes the whole thing slower.
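
If you want to check how much that accounts for, you could build the same stack with batch_first=False and feed it (seq, batch, feature) tensors by transposing the inputs. A rough sketch, reusing hidden_dim, device, USE_FP16, tgt, and memory from the snippet above:

decoder_layer_sf = nn.TransformerDecoderLayer(d_model=hidden_dim,
                                              dim_feedforward=hidden_dim * 4,
                                              nhead=8,
                                              batch_first=False)  # default layout: (seq, batch, feature)
decoder_sf = nn.TransformerDecoder(decoder_layer_sf, num_layers=3).to(device)

with torch.no_grad(), autocast('cuda', enabled=USE_FP16):
    # transpose (batch, seq, feature) -> (seq, batch, feature) for the batch_first=False module
    out = decoder_sf(tgt.transpose(0, 1), memory.transpose(0, 1))
    out = out.transpose(0, 1)  # back to (batch, seq, feature) if you need that layout

Timing this variant against your batch_first=True numbers would show whether the extra transposes account for the slowdown you see.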