Training with autocast does not improve training speed

I have read the documentation on autocast and how to use it correctly. I am sharing my model so you can check whether any layers prevent autocast from taking effect, because the training speed stays the same with and without it.

import torch

class DCUNet(torch.nn.Module):
    def __init__(self):
        super(DCUNet, self).__init__()
        self.dcBlock1 = DCBlock(in_channels=1, out_channels=32, kernel_size=3, stride=1, apply_batchnorm=True)
        self.resPath1 = ResPath(in_channels=32, out_channels=32, residuals=4, apply_batchnorm=True)
        self.maxPool1 = torch.nn.MaxPool2d(kernel_size=(2,2))
        self.dcBlock2 = DCBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1, apply_batchnorm=True)
        self.resPath2 = ResPath(in_channels=64, out_channels=64, residuals=3, apply_batchnorm=True)
        self.maxPool2 = torch.nn.MaxPool2d(kernel_size=(2,2))
        self.dcBlock3 = DCBlock(in_channels=64, out_channels=128, kernel_size=3, stride=1, apply_batchnorm=True)
        self.resPath3 = ResPath(in_channels=128, out_channels=128, residuals=2, apply_batchnorm=True)
        self.maxPool3 = torch.nn.MaxPool2d(kernel_size=(2,2))
        self.dcBlock4 = DCBlock(in_channels=128, out_channels=256, kernel_size=3, stride=1, apply_batchnorm=True)
        self.resPath4 = ResPath(in_channels=256, out_channels=256, residuals=1, apply_batchnorm=True)
        self.maxPool4 = torch.nn.MaxPool2d(kernel_size=(2,2))
        self.dcBlock5 = DCBlock(in_channels=256, out_channels=512, kernel_size=3, stride=1, apply_batchnorm=True)
        self.cTranBlock1 = torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.dcBlock6 = DCBlock(in_channels=512, out_channels=256, kernel_size=3, stride=1, apply_batchnorm=True)
        self.cTranBlock2 = torch.nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.dcBlock7 = DCBlock(in_channels=256, out_channels=128, kernel_size=3, stride=1, apply_batchnorm=True)
        self.cTranBlock3 = torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.dcBlock8 = DCBlock(in_channels=128, out_channels=64, kernel_size=3, stride=1, apply_batchnorm=True)
        self.cTranBlock4 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
        self.dcBlock9 = DCBlock(in_channels=64, out_channels=32, kernel_size=3, stride=1, apply_batchnorm=True)
        self.lastBlock = torch.nn.Conv2d(in_channels=32, out_channels=1, kernel_size=(1,1))
    def forward(self, x):
        ### -------- Encoder block --------- ###
        dc1Out = self.dcBlock1(x)
        x = self.maxPool1(dc1Out)
        res1 = self.resPath1(dc1Out)
        dc2Out = self.dcBlock2(x)
        x = self.maxPool2(dc2Out)
        res2 = self.resPath2(dc2Out)
        dc3Out = self.dcBlock3(x)
        x = self.maxPool3(dc3Out)
        res3 = self.resPath3(dc3Out)  
        dc4Out = self.dcBlock4(x)
        x = self.maxPool4(dc4Out)
        res4 = self.resPath4(dc4Out)
        ### -------- Bottleneck block --------- ###
        dc5Out = self.dcBlock5(x)
        ### -------- Decoder block --------- ###
        cTran1Out = self.cTranBlock1(dc5Out)
        merge1 = torch.cat([cTran1Out, res4], axis=1)
        dc6Out = self.dcBlock6(merge1)
        cTran2Out = self.cTranBlock2(dc6Out)
        merge2 = torch.cat([cTran2Out, res3], axis=1)
        dc7Out = self.dcBlock7(merge2)
        cTran3Out = self.cTranBlock3(dc7Out)
        merge3 = torch.cat([cTran3Out, res2], axis=1)
        dc8Out = self.dcBlock8(merge3)
        cTran4Out = self.cTranBlock4(dc8Out)
        merge4 = torch.cat([cTran4Out, res1], axis=1)
        dc9Out = self.dcBlock9(merge4)
        out = self.lastBlock(dc9Out)
        return out

class DCBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, apply_batchnorm=True):
        super(DCBlock, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.lx_conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=round(out_channels/6), kernel_size=kernel_size, stride=stride, padding='same')
        self.act1_lx = torch.nn.ReLU()
        self.lx_conv2 = torch.nn.Conv2d(in_channels=round(out_channels/6), out_channels=round(out_channels/3), kernel_size=kernel_size, stride=stride, padding='same')
        self.act2_lx = torch.nn.ReLU()
        self.lx_conv3 = torch.nn.Conv2d(in_channels=round(out_channels/3), out_channels=round(out_channels/2), kernel_size=kernel_size, stride=stride, padding='same')
        self.act3_lx = torch.nn.ReLU()
        self.rx_conv1 = torch.nn.Conv2d(in_channels=in_channels, out_channels=round(out_channels/6), kernel_size=kernel_size, stride=stride, padding='same')
        self.act1_rx = torch.nn.ReLU()
        self.rx_conv2 = torch.nn.Conv2d(in_channels=round(out_channels/6), out_channels=round(out_channels/3), kernel_size=kernel_size,stride=stride, padding='same')
        self.act2_rx = torch.nn.ReLU()
        self.rx_conv3 = torch.nn.Conv2d(in_channels=round(out_channels/3), out_channels=round(out_channels/2), kernel_size=kernel_size,stride=stride, padding='same')
        self.act3_rx = torch.nn.ReLU()
        if self.apply_batchnorm:
            self.bn1 = torch.nn.BatchNorm2d(num_features=round(out_channels))
            self.bn2 = torch.nn.BatchNorm2d(num_features=round(out_channels))
            self.bn3 = torch.nn.BatchNorm2d(num_features=round(out_channels))
        self.actFinal = torch.nn.ReLU()

    def forward(self, inp):
        # left path
        x = self.lx_conv1(inp)
        lx_1 = self.act1_lx(x)
        x = self.lx_conv2(lx_1)
        lx_2 = self.act2_lx(x)
        x = self.lx_conv3(lx_2)
        lx_3 = self.act3_lx(x)
        # right path
        x = self.rx_conv1(inp)
        rx_1 = self.act1_rx(x)
        x = self.rx_conv2(rx_1)
        rx_2 = self.act2_rx(x)
        x = self.rx_conv3(rx_2)
        rx_3 = self.act3_rx(x)
        # concatenation
        conc_lx = torch.cat([lx_1, lx_2, lx_3], axis=1)
        if self.apply_batchnorm:
            conc_lx = self.bn1(conc_lx)
        conc_rx = torch.cat([rx_1, rx_2, rx_3], axis=1)
        if self.apply_batchnorm:
            conc_rx = self.bn2(conc_rx)
        # summing up
        x = conc_lx.add(conc_rx)
        x = self.actFinal(x)
        if self.apply_batchnorm:
            x = self.bn3(x)
        return x

class ResPath(torch.nn.Module):
    def __init__(self, in_channels, out_channels, residuals, apply_batchnorm=True):
        super(ResPath, self).__init__()
        self.apply_batchnorm = apply_batchnorm
        self.residuals = residuals
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv3_list = torch.nn.ModuleList()
        self.conv1_list = torch.nn.ModuleList()
        self.bn_list = torch.nn.ModuleList()
        self.conv3_list.append(torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels,kernel_size=3, padding='same'))
        self.conv3_list.append(torch.nn.ReLU())
        self.conv1_list.append(torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, padding='same'))
        self.conv1_list.append(torch.nn.ReLU())
        if self.apply_batchnorm:
            self.bn_list.append(torch.nn.BatchNorm2d(num_features=out_channels))
        for i in range(self.residuals-1):
            self.conv3_list.append(torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=3, padding='same'))
            self.conv3_list.append(torch.nn.ReLU())
            self.conv1_list.append(torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1, padding='same'))
            self.conv1_list.append(torch.nn.ReLU())
            if self.apply_batchnorm:
                self.bn_list.append(torch.nn.BatchNorm2d(num_features=out_channels))
        self.act = torch.nn.ReLU()         
    def forward(self, x):       
        for i in range(self.residuals):
            conv = self.conv3_list[i](x)
            res = self.conv1_list[i](x)
            summed = conv.add(res)
            x = self.act(summed)
            if self.apply_batchnorm:
                x = self.bn_list[i](x)
        return x

and in the training loop:

inputs = inputs.to(device)
masks = masks.to(device)
# Zero the gradients for every batch
#optimizer.zero_grad()
# Make predictions for this batch
with torch.cuda.amp.autocast():
    outputs = model(inputs)
    #Compute the loss and its gradients
    loss = criterion(outputs.float(), masks)
scaler.scale(loss).backward()
# Adjust learning weights
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
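
A quick sanity check I could add inside the autocast region (assuming the tensors are on a CUDA device) would be to verify that autocast is actually active:

with torch.cuda.amp.autocast():
    outputs = model(inputs)
    # conv/linear outputs should come back in half precision when autocast is active
    print(outputs.dtype)  # expected: torch.float16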

Do you have any suggestions about which part of the code I should inspect to understand the problem? Thank you!

AMP: 210 seconds
SP (single precision): 221 seconds

Hi @Matt2, what GPU are you using? Speedups from mixed precision are most evident on Tensor Core GPUs (e.g. Volta, Turing, Ampere).
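
You can check the compute capability of your device directly; Tensor Cores are available from compute capability 7.0 (Volta) onwards:

import torch

print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_capability(0))  # (major, minor); >= (7, 0) means Tensor Cores are available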

I don’t know which shapes you are using but I get these results:

FP32/TF32 - 0.04413114159979159
AMP - 0.03283868491998874
AMP, channels-last - 0.0275552004398196

in seconds per iteration on an RTX 3090 using:

import time
import torch

model = DCUNet().cuda()
x = torch.randn(16, 1, 224, 224).cuda()
nb_iters = 25

# warmup
for _ in range(10):
    out = model(x)

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(nb_iters):
    out = model(x)
torch.cuda.synchronize()
t1 = time.perf_counter()
print((t1 - t0)/nb_iters)


# convert the model and input to the channels-last memory format
model.to(memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)

# warmup
for _ in range(10):
    with torch.cuda.amp.autocast():
        out = model(x)

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(nb_iters):
    with torch.cuda.amp.autocast():
        out = model(x)
torch.cuda.synchronize()
t1 = time.perf_counter()
print((t1 - t0)/nb_iters)

I'm using an RTX 2080 Ti, so it should be OK!

Hi! I'm using 512x512 images with a batch size of 8 and a single channel, and my times are measured over one entire epoch of 1100 batches; sorry, I forgot to specify this. However, I tested AMP with the example provided in the documentation and got a 2x speedup, so I don't know what is causing the problem.
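
For reference, one batch of my inputs corresponds to something like this (assuming the usual NCHW layout):

x = torch.randn(8, 1, 512, 512).cuda()  # batch of 8 single-channel 512x512 images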

I reran my code using PyTorch 1.11.0+cu113 on an RTX 2080 Ti and needed to disable gradient calculations globally, as I was running out of memory otherwise.
The results still show a speedup:

FP32: 0.15081167608499527
AMP, channels-last: 0.10359331108629703

Hi, what do you mean by disabling all gradient computations? I reduced the batch size because I was running out of memory too, but I still don't know how to track down this problem. I'm sharing a larger portion of my training code below. Could the slowdown come from external .py files loaded in Jupyter, or from other sources of bottlenecks? How do I find the main bottleneck in my pipeline? Could it be my custom loss function (DiceLoss + BCELoss)? A generic sketch of such a loss is shown after the training loop below.

    with tqdm(dataloader, unit="batch") as tepoch:
        for inputs, masks in tepoch:
            tepoch.set_description(f"Epoch {epoch_index}")
            # Move to device
            inputs = inputs.to(device)
            masks = masks.to(device)
            # Zero the gradients for every batch
            #optimizer.zero_grad()
            # Make predictions for this batch
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                #Compute the loss and its gradients
                loss = criterion(outputs.float(), masks)
            scaler.scale(loss).backward()
            # Adjust learning weights
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            # Gather data and report
            curr_loss = loss.item()
            running_loss += curr_loss
            # Update epoch loss and metrics
            #precision_score = precision(masks, outputs)
            #recall_score = recall(masks, outputs)
            tepoch.set_postfix({'loss':curr_loss, 'precision':0, 'recall':0})
            batch_num += 1
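
For reference, here is a generic sketch of the kind of Dice + BCE combination I mean; the exact smoothing term and weighting in my implementation may differ:

import torch

class DiceBCELoss(torch.nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
        self.bce = torch.nn.BCEWithLogitsLoss()

    def forward(self, logits, targets):
        # BCE part, computed on the raw logits
        bce_loss = self.bce(logits, targets)
        # soft Dice part, computed on the probabilities
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2 * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)
        return bce_loss + dice_loss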

I used torch.set_grad_enabled(False), as I was otherwise running out of memory with your input sizes.
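
I.e. roughly like this, so that no autograd state is stored during the forward-only benchmark:

torch.set_grad_enabled(False)   # disable gradient tracking globally
# ... run the warmup and timing loops from above ...
torch.set_grad_enabled(True)    # re-enable it afterwards if needed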

You could profile your workload using the PyTorch profiler or e.g. Nsight Systems. Once you have the profiles you can check the timeline and see where the bottleneck is coming from.
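
A minimal sketch with the PyTorch profiler could look like this (the sort key and row limit are just examples); Nsight Systems would instead be launched from the command line around your training script:

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True) as prof:
    for inputs, masks in dataloader:
        inputs = inputs.to(device)
        masks = masks.to(device)
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs.float(), masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        break  # profile a single training iteration

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))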

Did you move the channels to the last position for a particular reason, e.g. for computational speedups?

Yes, my previous post compares both memory formats and shows that channels-last is faster if you are using AMP. This is also described in the performance guide.
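
In your training loop that would roughly mean converting the model once and each input batch to channels-last (a sketch, assuming 4D NCHW tensors):

model = model.to(memory_format=torch.channels_last)  # one-time conversion of the conv weights

for inputs, masks in dataloader:
    inputs = inputs.to(device, memory_format=torch.channels_last)
    masks = masks.to(device)
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = criterion(outputs.float(), masks)
    # backward, scaler.step, scaler.update, zero_grad as before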