The performance gap between torch.cuda.amp and nvidia-apex

In the past I used nvidia-apex in O1 mode to train my model, and torch.cuda.amp seems to save more memory. I trained with apex and with torch.cuda.amp for 1000 iterations each and found a gap of 8 mIoU on the val set (semantic segmentation): the model trained with apex reached 30 mIoU and the one trained with torch.cuda.amp reached 22 mIoU. Is this normal?
Here are the key parts of the implementation:

def __init__(self):
    self.scaler = GradScaler()

def step(self, X, label):
    with autocast(enabled=True):
        _, output = self.forward(input=X)
        outputUp = F.interpolate(output, size=X.size()[2:], mode='bilinear', align_corners=True)
        loss = self.loss_source(input=outputUp, target=label)
    self.amp_backward(loss, self.BaseOpti)

And the amp_backward function was:

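def amp_backward(self, loss, optim):
    # Scale the loss before backward so small fp16 gradients do not underflow.
    self.scaler.scale(loss).backward()
    # scaler.step() unscales the gradients and calls optim.step() internally.
    self.scaler.step(optim)
    self.scaler.update()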

No, an accuracy loss is unexpected. Which PyTorch version are you using and which model?

I use PyTorch 1.9.0 and torchvision 0.10.0. But I realized where I went wrong: after the amp_backward function, I call optim.step() again, although scaler.step(optim) has already called optim.step() implicitly. I’ll fix it and try again.

Thanks for your reply. It works well now. The extra optim.step() call was what caused the accuracy loss.
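For reference, the corrected update can be sketched as follows. This is a minimal sketch, not the original training code: the tiny linear model, the random data, and the SGD settings are hypothetical stand-ins, and scaling and autocast are switched off when no GPU is present so that the snippet also runs on CPU.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

# Hypothetical tiny model and data, only to make the step runnable.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = nn.Linear(8, 2).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
# Scaling is disabled when no GPU is present, so scaler.step() reduces to optim.step().
scaler = GradScaler(enabled=use_cuda)

X = torch.randn(4, 8, device=device)
label = torch.randint(0, 2, (4,), device=device)

old_w = model.weight.detach().clone()

optim.zero_grad()
with autocast(enabled=use_cuda):
    loss = nn.functional.cross_entropy(model(X), label)
scaler.scale(loss).backward()
scaler.step(optim)   # calls optim.step() internally (skipped if gradients contain inf/NaN)
scaler.update()
# No extra optim.step() here: it would apply the same gradients a second time.

new_w = model.weight.detach().clone()
```

With plain SGD, a single update moves each weight by lr times its gradient, so comparing new_w with old_w is a quick way to check that the optimizer stepped only once.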

But there is another problem: on GPUs with Tensor Cores, both native amp and apex (O1 mode) give a large speedup. On the 1080 Ti (which doesn’t have Tensor Cores), however, native amp is 2x slower than apex (O1 mode). When I change the apex mode to O2, it also becomes 2x slower. It seems clear that the speed of native amp and apex (O2 mode) drops noticeably because of the fp16 throughput bottleneck of the 1080 Ti. Is there any way to change the mode of native amp to remove some of the fp16 casts? The O1 mode works well on the 1080 Ti (the speed didn’t improve, but the memory footprint was much lower). I’d really like to use native amp since it requires far fewer changes to the original code.

As described in the docs, you are not supposed to call optimizer.step() yourself when using the scaler, so use the described approach.

I don’t know where the slowdown is coming from, but in case you are using CNNs you could try to use torch.backends.cudnn.benchmark = True and see if it speeds up your workload again.

When I enable the benchmark, it speeds up a bit but is still 2x slower, since the 1080 Ti doesn’t have Tensor Cores. What I want is the effect of the apex O1 mode: unchanged speed and a low memory footprint. Right now autocast() behaves like the apex O2 mode, which is also 2x slower on the 1080 Ti.
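Native amp has no opt_level switch, but autocast can be disabled locally for part of the forward pass. The sketch below only illustrates that mechanism: it assumes torch.autocast (available from PyTorch 1.10), the two linear layers are hypothetical, and whether keeping selected layers in fp32 restores O1-like speed on a 1080 Ti has not been verified here.

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hypothetical two-layer model; fc2 stands in for the part to keep in fp32.
fc1 = nn.Linear(16, 16).to(device)
fc2 = nn.Linear(16, 4).to(device)
x = torch.randn(2, 16, device=device)

with torch.autocast(device_type=device.type):
    # autocast picks the lower-precision dtype here (fp16 on CUDA, bf16 on CPU).
    h = fc1(x)
    with torch.autocast(device_type=device.type, enabled=False):
        # Nothing is cast automatically inside the disabled region,
        # so the input has to be converted back to fp32 by hand.
        out = fc2(h.float())

print(h.dtype, out.dtype)
```

In a model like the one discussed here, the inner block would wrap the forward call of the submodule that should stay in fp32.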

You shouldn’t use apex.amp anymore as it’s deprecated and the native implementation via torch.cuda.amp is the right approach. Could you post the model definition as well as the input shapes, please?

The model is DeepLabV3+ with a ResNet backbone. Here is the main code.
When I run this code, the output is as follows:

Time cost with autocast: 2511.72 ms
Time cost without autocast: 1397.27 ms

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None, use_in=False):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               dilation=dilation, padding=dilation, bias=False)
        self.bn2 = BatchNorm(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm(planes*4)
        self.relu = nn.ReLU(inplace=True)
        self.in2 = nn.InstanceNorm2d(planes * 4, affine=True) if use_in else None
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        if self.in2 is not None:
            out = self.in2(out)
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, output_stride, BatchNorm, pretrained=False, version='resnet101', use_in=False, use_pool=False):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.version = version
        blocks = [1, 2, 4]
        if output_stride == 16:
            strides = [1, 2, 2, 1]
            dilations = [1, 1, 1, 2]
        elif output_stride == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 4]
        else:
            raise NotImplementedError

        # Modules
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
        self.bn1 = BatchNorm(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm, use_in=use_in)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm, use_in=use_in)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) if use_pool else None
        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None, use_in=False):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm, use_in=False if i < blocks-1 else use_in))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
                            downsample=downsample, BatchNorm=BatchNorm))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1,
                                dilation=blocks[i]*dilation, BatchNorm=BatchNorm))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        low_level_feat = x
        x = self.layer2(x)
        mid_level_feat = x = self.layer3(x)
        x = self.layer4(x)
        if self.avgpool is not None: # for byol
            return self.avgpool(x)
        return x, low_level_feat, mid_level_feat

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        pretrain_dict = None
        if self.version == 'resnet101':
            pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
        elif self.version == 'resnet50':
            pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)

def ResNet101(output_stride, BatchNorm, pretrained=False, use_in=False, use_pool=False):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm=BatchNorm, pretrained=pretrained, version='resnet101', use_in=use_in, use_pool=use_pool)
    return model

def build_backbone(backbone, output_stride, BatchNorm, use_in=False):
    return ResNet101(output_stride, BatchNorm, use_in=use_in)



class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        inplanes = 2048
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        out_chanels = 256
        ASPPModule = _ASPPModule

        self.aspp1 = ASPPModule(inplanes, out_chanels, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = ASPPModule(inplanes, out_chanels, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = ASPPModule(inplanes, out_chanels, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = ASPPModule(inplanes, out_chanels, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(inplanes, out_chanels, 1, stride=1, bias=False),
                                             BatchNorm(out_chanels),
                                             nn.ReLU())
        self.conv1 = nn.Conv2d(out_chanels*5, out_chanels, 1, bias=False)
        self.bn1 = (BatchNorm(out_chanels))
        self.relu = nn.ReLU()
        self._init_weight()

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return x 

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def build_aspp(output_stride, BatchNorm):
    return ASPP(output_stride, BatchNorm)

class Decoder(nn.Module):
    def __init__(self, num_classes, BatchNorm):
        super(Decoder, self).__init__()

        low_level_inplanes = 256
        out_chanels = 256
        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()
        self.last_conv = nn.Sequential(
                                        nn.Conv2d(out_chanels + 48, out_chanels, kernel_size=3, stride=1, padding=1, bias=False),
                                        BatchNorm(out_chanels),
                                        nn.ReLU(),
                                        nn.Conv2d(out_chanels, out_chanels, kernel_size=3, stride=1, padding=1, bias=False),
                                        BatchNorm(out_chanels),
                                        nn.ReLU(),
                                        nn.Dropout2d(0.1),
                                        nn.Conv2d(out_chanels, num_classes, kernel_size=1, stride=1))
        self._init_weight()


    def forward(self, x, low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.bn1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_feat), dim=1)

        for i in range(len(self.last_conv)-1):
            x = self.last_conv[i](x)
        out = self.last_conv[-1](x)
        return x, out

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_decoder(num_classes, BatchNorm):
    return Decoder(num_classes, BatchNorm)

class DeepLabV3_plus(nn.Module):
    def __init__(self, backbone='resnetmg101', output_stride=16, num_classes=21,
                    bn_backbone='bn', bn_aspp='bn', bn_decoder='bn', use_in=False):
        super(DeepLabV3_plus, self).__init__()
        self.best_iou = 0

        bn_dict = {'bn' : nn.BatchNorm2d,}

        self.backbone = build_backbone(backbone, output_stride, bn_dict[bn_backbone], use_in)
        self.aspp = build_aspp(output_stride, bn_dict[bn_aspp])
        self.decoder = build_decoder(num_classes, bn_dict[bn_decoder])

        
    def _load_pretrained_model(self):
        if hasattr(self.backbone, '_load_pretrained_model'):
            self.backbone._load_pretrained_model()


    def forward(self, input, detach_backbone=False):
        x, low_level_feat, mid_level_feat = self.backbone(input)
        if detach_backbone:
            x, low_level_feat, mid_level_feat = x.detach(), low_level_feat.detach(), mid_level_feat.detach()
        x = self.aspp(x)
        cls_feat, out = self.decoder(x, low_level_feat)
        return cls_feat, out

def get_deeplab_v3_plus(backbone, num_classes, bn_backbone='bn', bn_aspp='bn', bn_decoder='bn', use_in=False):
    return DeepLabV3_plus(
        backbone=backbone,
        output_stride=16,
        num_classes=num_classes,
        bn_backbone=bn_backbone, 
        bn_aspp=bn_aspp, 
        bn_decoder=bn_decoder, 
        use_in=use_in,
    )


if __name__=='__main__':
    from torch.cuda.amp import GradScaler, autocast
    device = torch.device('cuda')
    model = get_deeplab_v3_plus('resnet', num_classes=19).to(device)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    input_data = torch.randn((4,3,512,896)).to(device)
    label = torch.zeros(4,512,896).long().to(device)

    start.record()
    with autocast(enabled=True):
        _, output_data = model(input_data)
        output_data_upsample =  F.interpolate(output_data, size=(512,896), mode='bilinear', align_corners=True) 
        loss = nn.CrossEntropyLoss()(output_data_upsample, label)
    loss.backward()

    end.record()
    torch.cuda.synchronize()
    print("Time cost with autocast: {:.2f}".format(start.elapsed_time(end)), "ms")

    start.record()
    _, output_data = model(input_data)
    output_data_upsample =  F.interpolate(output_data, size=(512,896), mode='bilinear', align_corners=True) 
    loss = nn.CrossEntropyLoss()(output_data_upsample, label)
    loss.backward()
    end.record()
    torch.cuda.synchronize()
    print("Time cost without autocast: {:.2f}".format(start.elapsed_time(end)), "ms")
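One thing worth ruling out with the script above: the autocast run is timed first and only once, so it may also absorb one-time start-up costs (for example cuDNN initialization and the first memory allocations). A minimal sketch of a timing loop with warm-up iterations follows; the small convolutional model and the input shape are hypothetical stand-ins for the real ones, torch.autocast assumes PyTorch 1.10 or newer, and on a CPU-only machine autocast is left off, so both numbers then measure fp32.

```python
import time
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Hypothetical small stand-in; swap in the real model and input shape.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 2, 1)
).to(device)
x = torch.randn(2, 3, 32, 32, device=device)
label = torch.zeros(2, 32, 32, dtype=torch.long, device=device)

def time_step(use_amp, warmup=3, iters=5):
    """Average forward+backward time in ms, excluding the warm-up iterations."""
    for i in range(warmup + iters):
        if i == warmup:
            # Start the clock only after the warm-up work has finished.
            if use_cuda:
                torch.cuda.synchronize()
            t0 = time.perf_counter()
        model.zero_grad()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = nn.functional.cross_entropy(model(x), label)
        loss.backward()
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) * 1000 / iters

t_amp = time_step(use_amp=use_cuda)
t_fp32 = time_step(use_amp=False)
print("with autocast: {:.2f} ms, without autocast: {:.2f} ms".format(t_amp, t_fp32))
```

If the gap persists after warm-up and averaging, the slowdown comes from the fp16 kernels themselves rather than from start-up overhead.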

Thanks for the update! build_backbone is unfortunately not defined. Could you update the code and make it executable, so that I could try to reproduce the issue?

Sure, I have updated the code in the reply above @ptrblck. It can now be run directly. My CUDA version is 10.2 and torch is 1.10.1.

Hello, I have updated the code. Could you please help me find the problem that leads to the slowdown? :blush: