Mixed-precision training is 2x faster than default precision when the model is just 3 linear layers, but when the model is a single conv layer, mixed-precision training is slower than default precision

OS: Ubuntu 16.04
CUDA: 10.2
PyTorch: 1.8.0
I ran the demo at https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#a-simple-network.
The default-precision training time is 1.760 s and the mixed-precision training time is 0.833 s, a 2x speedup. But when I use torch.nn.Conv2d, AMP training is slower than training the default model. Could you please tell me why?

Could you share the code used to profile these workloads, as well as more information about your setup, please? In particular, information about the GPU you used would be needed.

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torchvision.models import resnet18
from torch.backends.cudnn import benchmark
import time
import gc

# Wall-clock timestamp written by start_timer() and read by end_timer_and_print().
start_time = None
# Let cuDNN benchmark and cache the fastest convolution algorithm per input
# configuration; the first call for each new configuration is therefore slower.
torch.backends.cudnn.benchmark = True

# Benchmark configuration for the MLP experiments.
batch_size, in_size, out_size = 512, 4096, 4096
num_layers = 3
num_batches, epochs = 50, 3

# Inputs and targets are generated directly on the GPU up front, so the timed
# training loops never include host-to-device copies.
data = [torch.randn((batch_size, in_size), device='cuda')
        for _ in range(num_batches)]
targets = [torch.randn((batch_size, out_size), device='cuda')
           for _ in range(num_batches)]

def start_timer():
    """Reset CUDA memory statistics and start the benchmark timer.

    Collects Python garbage and releases cached CUDA blocks, so each run starts
    from a comparable allocator state. Resets the peak-memory counters that
    end_timer_and_print() reports. Synchronizes so that GPU work queued
    earlier is not charged to the next measurement.
    """
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias that forwards to
    # reset_peak_memory_stats(); call the supported API directly.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-benchmark.
    start_time = time.perf_counter()


def end_timer_and_print(local_msg):
    """Stop the timer started by start_timer() and print time and peak memory.

    Args:
        local_msg: header line printed before the statistics.

    Raises:
        RuntimeError: if start_timer() has not been called yet.
    """
    if start_time is None:
        raise RuntimeError(
            'start_timer() must be called before end_timer_and_print()')
    # Wait for all queued GPU work to finish before reading the clock.
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    print('\n' + local_msg)
    print('Total execution time = {:.3f} sec'.format(end_time - start_time))
    print('Max memory used by tensors = {} bytes'.format(
        torch.cuda.max_memory_allocated()))


def make_model(in_size, out_size, num_layers):
    """Build a ``num_layers``-deep MLP and move it to the GPU.

    The first ``num_layers - 1`` layers are ``Linear(in_size, in_size)``, each
    followed by a ReLU. The final layer is ``Linear(in_size, out_size)`` with
    no activation. With ``num_layers <= 1`` only the final layer is created.
    """
    hidden = []
    for _ in range(num_layers - 1):
        hidden.extend((nn.Linear(in_size, in_size), nn.ReLU()))
    head = nn.Linear(in_size, out_size)
    return nn.Sequential(*hidden, head).cuda()

def train_default():
    """Train a fresh MLP in default (FP32) precision and report time/memory.

    Runs ``epochs`` passes over the pre-generated ``data``/``targets`` batches
    with SGD and MSE loss. Only the training loop is timed.
    """
    net = make_model(in_size, out_size, num_layers)
    opt = torch.optim.SGD(net.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss().cuda()
    net.train()

    start_timer()
    for _ in range(epochs):
        for batch, target in zip(data, targets):
            loss_fn(net(batch), target).backward()
            opt.step()
            opt.zero_grad()
    end_timer_and_print('Default precision:')

def train_mixed():
    """Train a fresh MLP with autocast + GradScaler and report time/memory.

    Mirrors train_default(): same model, optimizer, learning rate, data,
    targets and number of epochs. The two timings are therefore comparable.
    """
    net = make_model(in_size, out_size, num_layers)
    opt = torch.optim.SGD(net.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss().cuda()
    net.train()
    scaler = GradScaler()
    start_timer()
    for _ in range(epochs):
        for batch, target in zip(data, targets):
            # Only the forward pass and loss computation run under autocast.
            with autocast():
                output = net(batch)
                # Regress onto the target, as train_default() does. Comparing
                # against the input only type-checks because in_size == out_size.
                loss = loss_fn(output, target)
            # Backward and the optimizer step go through the GradScaler.
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
    end_timer_and_print('Mixed Precision:')

class SingleConv(nn.Module):
    """Model made of exactly one bias-free 3x3 convolution.

    Maps 3 input channels to 64 output channels with stride 2 and padding 1,
    which halves the spatial size (rounding up).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the convolution to ``x`` (3 input channels expected)."""
        return self.conv(x)

def train_single_conv(warmup_iters=10, channels_last=False):
    """Time FP32 vs. autocast (AMP) training of a single conv layer.

    Trains a fresh ``SingleConv`` in each precision on the same five random
    (64, 3, 224, 224) inputs and (64, 64, 112, 112) targets, making five
    passes over them. Prints wall-clock time and peak tensor memory for each
    run.

    Args:
        warmup_iters: untimed training steps run before each timed loop.
            With ``torch.backends.cudnn.benchmark = True``, the first call for
            a new input configuration runs cuDNN's algorithm search. FP32 and
            FP16 are separate configurations, so each run gets its own warmup.
            Pass 0 to reproduce the original, un-warmed measurement.
        channels_last: if True, convert the models and tensors to
            ``torch.channels_last``. In the default NCHW layout, FP16 convs
            may incur extra NCHW<->NHWC permutation kernels inside cuDNN;
            channels-last avoids them.
    """
    memory_format = (torch.channels_last if channels_last
                     else torch.contiguous_format)
    batches = 5
    passes = 5

    default_net = SingleConv().cuda().to(memory_format=memory_format)
    default_net.train()
    default_opt = torch.optim.SGD(default_net.parameters(), lr=0.0001)
    loss_fn = torch.nn.MSELoss().cuda()
    data = [torch.randn(64, 3, 224, 224, device='cuda')
            .contiguous(memory_format=memory_format)
            for _ in range(batches)]
    tar = [torch.randn(64, 64, 112, 112, device='cuda')
           .contiguous(memory_format=memory_format)
           for _ in range(batches)]

    def timed_run(step, msg):
        # Untimed warmup (cycling through the batches), then the timed passes.
        for i in range(warmup_iters):
            step(data[i % batches], tar[i % batches])
        start_timer()
        for _ in range(passes):
            for inp, target in zip(data, tar):
                step(inp, target)
        end_timer_and_print(msg)

    def default_step(inp, target):
        out = default_net(inp)
        loss = loss_fn(out, target)
        loss.backward()
        default_opt.step()
        default_opt.zero_grad()

    timed_run(default_step, 'Default 1 Conv train:')

    mixed_net = SingleConv().cuda().to(memory_format=memory_format)
    mixed_net.train()
    scaler = GradScaler()
    mixed_opt = torch.optim.SGD(mixed_net.parameters(), lr=0.0001)

    def mixed_step(inp, target):
        # Forward pass and loss under autocast; backward runs outside it.
        with autocast():
            out = mixed_net(inp)
            loss = loss_fn(out, target)
        scaler.scale(loss).backward()
        scaler.step(mixed_opt)
        scaler.update()
        mixed_opt.zero_grad()

    timed_run(mixed_step, 'Mixed 1 Conv train:')

The output is:
Default precision:
Total execution time = 1.760 sec
Max memory used by tensors = 1367458816 bytes

Mixed Precision:
Total execution time = 0.833 sec
Max memory used by tensors = 1401006592 bytes

Default 1 Conv train:
Total execution time = 0.090 sec
Max memory used by tensors = 2470273536 bytes

Mixed 1 Conv train:
Total execution time = 0.127 sec
Max memory used by tensors = 2644665344 bytes

I also tested with resnet18: AMP is faster, and the memory usage is also lower.

def train_resnet18(warmup_iters=10, channels_last=False):
    """Time FP32 vs. autocast (AMP) training of a randomly initialized ResNet-18.

    Each run trains a fresh model for 500 steps on one fixed random batch
    (inputs (64, 3, 224, 224), targets (64, 1000)) with SGD and MSE loss.
    It prints wall-clock time and peak tensor memory.

    Args:
        warmup_iters: untimed steps run before each timed loop, so the cuDNN
            benchmark-mode algorithm search for each precision is not
            charged to the measurement. Pass 0 for the original behavior.
        channels_last: if True, convert the models and the input batch to
            ``torch.channels_last``. The 2-D targets are left unchanged.
    """
    memory_format = (torch.channels_last if channels_last
                     else torch.contiguous_format)
    num_steps = 500
    x = torch.randn(64, 3, 224, 224, device='cuda').contiguous(
        memory_format=memory_format)
    y = torch.randn(64, 1000, device='cuda')
    loss_fn = torch.nn.MSELoss().cuda()

    def timed_run(step, msg):
        # Untimed warmup, then the timed training steps.
        for _ in range(warmup_iters):
            step()
        start_timer()
        for _ in range(num_steps):
            step()
        end_timer_and_print(msg)

    default_model = resnet18(pretrained=False).cuda().to(
        memory_format=memory_format)
    default_model.train()
    opt = torch.optim.SGD(default_model.parameters(), lr=1e-3)

    def default_step():
        # MSELoss takes (input, target): prediction first, target second.
        loss = loss_fn(default_model(x), y)
        loss.backward()
        opt.step()
        opt.zero_grad()

    timed_run(default_step, 'default resnet18 training')

    mixed_model = resnet18(pretrained=False).cuda().to(
        memory_format=memory_format)
    mixed_model.train()
    scaler = GradScaler()
    mix_opt = torch.optim.SGD(mixed_model.parameters(), lr=1e-3)

    def mixed_step():
        # Forward pass and loss under autocast; backward runs outside it.
        with autocast():
            pred = mixed_model(x)
            loss = loss_fn(pred, y)
        scaler.scale(loss).backward()
        scaler.step(mix_opt)
        scaler.update()
        mix_opt.zero_grad()

    timed_run(mixed_step, 'mixed resnet18 training')

default resnet18 training
Total execution time = 27.581 sec
Max memory used by tensors = 3741790208 bytes

mixed resnet18 training
Total execution time = 15.231 sec
Max memory used by tensors = 3247211520 bytes

Thanks, but I still don’t know which device you were using.

Sorry, the device is a V100 with 32 GB.

Thanks for the update!
I’ve added some missing warmup iterations: since you are using cuDNN’s benchmark mode, the first iteration will see a slowdown due to the internal profiling.
I can reproduce the slower kernel execution time in FP16 using the default memory format, but see a speedup using memory_format=torch.channels_last, which avoids the internal permutations performed by:

cudnn::ops::nchwToNhwcKernel
cudnn::ops::nhwcToNchwKernel

Used code:

import torch
import torch.nn as nn
import time

# cuDNN benchmark mode profiles conv algorithms on the first call for each new
# configuration, so every (dtype, memory format) combination below gets its
# own untimed warmup before timing.
torch.backends.cudnn.benchmark = True

conv = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False).cuda()
x = torch.randn(64, 3, 224, 224).cuda()

nb_iters = 100
nb_warmup = 10


def run_iters(module, inp, iters, use_amp, grad=None):
    """Run ``iters`` forward/backward passes of ``module`` and return the last output.

    The forward pass runs under autocast when ``use_amp`` is True. Backward is
    called outside the autocast region, as the AMP documentation recommends.
    When ``grad`` is None, a fresh random upstream gradient is drawn for every
    iteration; otherwise ``grad`` is reused.
    """
    out = None
    for _ in range(iters):
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = module(inp)
        out.backward(torch.rand_like(out) if grad is None else grad)
    return out


def time_conv(module, inp, use_amp, label):
    """Warm up, then print the average seconds per forward/backward iteration."""
    out = run_iters(module, inp, nb_warmup, use_amp)
    grad = torch.rand_like(out)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    run_iters(module, inp, nb_iters, use_amp, grad=grad)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    print('{}: {} s/iter'.format(label, (t1 - t0) / nb_iters))


# Time FP32 and AMP in both layouts, so dtype and memory format are not
# confounded in the comparison.
print('Default memory format:')
time_conv(conv, x, use_amp=False, label='FP32')
time_conv(conv, x, use_amp=True, label='AMP')

# Channels-last avoids the internal nchwToNhwc / nhwcToNchw permutation
# kernels that FP16 convs incur in the default layout.
conv.to(memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)
print('Channels-last:')
time_conv(conv, x, use_amp=False, label='FP32')
time_conv(conv, x, use_amp=True, label='AMP')

Default memory format (FP32 first, then AMP):

0.0012196038290858268 s/iter
0.0014858345501124858 s/iter

Channels-last (FP32 first, then AMP):

0.0012191115505993367 s/iter
0.0008443104848265648 s/iter

Thank you for the reply. I’ll try this later.
I don’t clearly understand your earlier reply, “Grouped convolutions might not trigger the FP16 path and thus might not use TensorCores” (linked here). What does “grouped convolution” mean? Could you please give me an example?

By grouped convolution I mean a conv layer using groups != 1, such as in the linked post:

out = F.conv2d(input, self.weight, padding=self.padding, groups=4)

Thank you so much. This helps me a lot.