Autocast and model size

I have a 3090 and a 4090. I implemented autocast with float16 for the forward pass of a CNN model with an fc layer. With a batch size of 16 I fill up the 24 GB of GPU memory. The autocast mode runs a little slower than without it. I didn't add the GradScaler code around loss.backward(), because I didn't know at the time that I should. My test epoch took 1 min 40 s vs. 1 min 30 s on 10k rows of input data.

Should I look at expanding the number of rows in my test so the epoch runs for an hour or more, as in a real-world example? Would autocast end up faster on a bigger data set?
Or
Did the lack of the GradScaler code around loss.backward() make my test slower?
Or
Are there other suggestions for using float16 autocast to improve speed?

GradScaler is needed to avoid underflowing (vanishing) gradients when float16 is used via autocast; it should not give you a speedup by itself.
Could you post your model definition so that we can check if autocast is properly used and why you are not seeing a speedup, please?
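
For reference, this is roughly what the usual autocast + GradScaler training step looks like (model, optimizer, criterion, and loader are placeholders here, not taken from your code):

scaler = torch.cuda.amp.GradScaler()

for data, target in loader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    # run the forward pass and loss computation in mixed precision
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    # scale the loss so the float16 gradients don't underflow in backward
    scaler.scale(loss).backward()
    # unscale the gradients and skip the step if they contain infs/NaNs
    scaler.step(optimizer)
    scaler.update()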

I am using this model: character-based-cnn/model.py (master branch) from ahmedbesbes/character-based-cnn on GitHub.

import json
import torch
import torch.nn as nn

class CharacterLevelCNN(nn.Module):
    def __init__(self, args, number_of_classes):
        super(CharacterLevelCNN, self).__init__()

        # define conv layers

        self.dropout_input = nn.Dropout2d(args.dropout_input)

        self.conv1 = nn.Sequential(
            nn.Conv1d(
                args.number_of_characters + len(args.extra_characters),
                256,
                kernel_size=7,
                padding=0,
            ),
            nn.ReLU(),
            nn.MaxPool1d(3),
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=7, padding=0), nn.ReLU(), nn.MaxPool1d(3)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU()
        )

        self.conv4 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU()
        )

        self.conv5 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU()
        )

        self.conv6 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU(), nn.MaxPool1d(3)
        )

        # compute the output shape after forwarding an input to the conv layers

        input_shape = (
            128,
            args.max_length,
            args.number_of_characters + len(args.extra_characters),
        )
        self.output_dimension = self._get_conv_output(input_shape)

        # define linear layers

        self.fc1 = nn.Sequential(
            nn.Linear(self.output_dimension, 1024), nn.ReLU(), nn.Dropout(0.5)
        )

        self.fc2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.5))

        self.fc3 = nn.Linear(1024, number_of_classes)

        # initialize weights

        self._create_weights()

    # utility private functions

    def _create_weights(self, mean=0.0, std=0.05):
        for module in self.modules():
            if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear):
                module.weight.data.normal_(mean, std)

    def _get_conv_output(self, shape):
        x = torch.rand(shape)
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = x.view(x.size(0), -1)
        output_dimension = x.size(1)
        return output_dimension

    # forward

    def forward(self, x):
        x = self.dropout_input(x)
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

Any idea why this model doesn’t benefit from Autocast mixed precision? @ptrblck

@ptrblck I'm wondering: if I modify my Dataset to produce float16 data directly, would that work? Below I would change dtype=np.float32 to np.float16. Would I also need autocast in that case?

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, texts, labels, args):
        self.texts = texts
        self.labels = labels
        self.length = len(self.texts)

        self.vocabulary = args.alphabet + args.extra_characters
        self.number_of_characters = args.number_of_characters + len(
            args.extra_characters
        )
        self.max_length = args.max_length
        self.preprocessing_steps = args.steps
        self.identity_mat = np.identity(self.number_of_characters)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        raw_text = self.texts[index]

        data = np.array(
            [
                self.identity_mat[self.vocabulary.index(i)]
                for i in list(raw_text)[::-1]
                if i in self.vocabulary
            ],
            dtype=np.float32,
        )
        if len(data) > self.max_length:
            data = data[: self.max_length]
        elif 0 < len(data) < self.max_length:
            data = np.concatenate(
                (
                    data,
                    np.zeros(
                        (self.max_length - len(data), self.number_of_characters),
                        dtype=np.float32,
                    ),
                )
            )
        elif len(data) == 0:
            data = np.zeros(
                (self.max_length, self.number_of_characters), dtype=np.float32
            )

        label = self.labels[index]
        data = torch.Tensor(data)

        return data, label

I realized I don't need to force the inputs to float16. All I have to do is assert the dtype after the autocast statement, and it is indeed float16. So it's just a matter of PyTorch not making this model smaller and faster, and I can't tell why.
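
The check I used is roughly this (model and x stand for my model and one input batch):

with torch.cuda.amp.autocast():
    out = model(x)
    # conv/linear outputs are computed in float16 inside the autocast region
    assert out.dtype == torch.float16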

Sorry for the late reply; this topic fell through the cracks. Let me check your code again and see if I can achieve a speedup using amp.

Using this code I see a speedup with amp, but not a huge one:

import torch
import torch.nn as nn
import time


class CharacterLevelCNN(nn.Module):
    def __init__(self, number_of_classes):
        super(CharacterLevelCNN, self).__init__()

        # define conv layers

        self.dropout_input = nn.Dropout2d(0.5)

        self.conv1 = nn.Sequential(
            nn.Conv1d(
                128,
                256,
                kernel_size=7,
                padding=0,
            ),
            nn.ReLU(),
            nn.MaxPool1d(3),
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=7, padding=0), nn.ReLU(), nn.MaxPool1d(3)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU()
        )

        self.conv4 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU()
        )

        self.conv5 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU()
        )

        self.conv6 = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=3, padding=0), nn.ReLU(), nn.MaxPool1d(3)
        )

        # compute the  output shape after forwarding an input to the conv layers

        input_shape = (
            128,
            128,
            128,
        )
        self.output_dimension = self._get_conv_output(input_shape)

        # define linear layers

        self.fc1 = nn.Sequential(
            nn.Linear(self.output_dimension, 1024), nn.ReLU(), nn.Dropout(0.5)
        )

        self.fc2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.5))

        self.fc3 = nn.Linear(1024, number_of_classes)

        # initialize weights

        self._create_weights()

    # utility private functions

    def _create_weights(self, mean=0.0, std=0.05):
        for module in self.modules():
            if isinstance(module, nn.Conv1d) or isinstance(module, nn.Linear):
                module.weight.data.normal_(mean, std)

    def _get_conv_output(self, shape):
        x = torch.rand(shape)
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = x.view(x.size(0), -1)
        output_dimension = x.size(1)
        return output_dimension

    # forward

    def forward(self, x):
        x = self.dropout_input(x)
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


def profile(model, x, amp, benchmark):
    torch.backends.cudnn.benchmark = benchmark
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    # warmup
    for _ in range(10):
        with torch.cuda.amp.autocast(enabled=amp):
            out = model(x)
            scaler.scale(out.mean()).backward()
    
    nb_iters = 100
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(nb_iters):
        with torch.cuda.amp.autocast(enabled=amp):
            out = model(x)
            scaler.scale(out.mean()).backward()
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    print("amp enabled: {}, benchmark: {}, {}iter/s".format(
        amp, benchmark, nb_iters/(t1-t0)))


device = "cuda"
model = CharacterLevelCNN(10).to(device)
x = torch.randn(128, 128, 128, device=device)

profile(model, x, amp=False, benchmark=False)
profile(model, x, amp=False, benchmark=True)
profile(model, x, amp=True, benchmark=False)
profile(model, x, amp=True, benchmark=True)

# amp enabled: False, benchmark: False, 355.158750122739iter/s
# amp enabled: False, benchmark: True, 370.2653699860738iter/s
# amp enabled: True, benchmark: False, 404.79240530359124iter/s
# amp enabled: True, benchmark: True, 419.78821713531175iter/s

You should also note that cuDNN is allowed to use TF32 by default, which uses Tensor Cores for the convolutions where possible and is therefore already speeding up your model in float32. Disabling it shows the true FP32 performance:

torch.backends.cudnn.allow_tf32 = False
profile(model, x, amp=False, benchmark=False)
profile(model, x, amp=False, benchmark=True)

# amp enabled: False, benchmark: False, 191.57821257885942iter/s
# amp enabled: False, benchmark: True, 193.78960927342436iter/s
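
As a side note, TF32 for matmuls (i.e. the linear layers) is controlled by a separate flag, so you could toggle it independently if you want to isolate the conv behavior:

# TF32 for matmuls is a separate switch from the cuDNN conv flag
torch.backends.cuda.matmul.allow_tf32 = False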

Could you execute my code and check how the model performs on your system?

Thank you. I will verify the results on my end. This was very helpful.