F.conv2d runs 15x slower in mixed-precision/half-precision mode

System info:

  • OS: Ubuntu 18.04
  • CUDA: 10.1
  • pytorch: 1.4.0 installed from conda
  • NVIDIA apex: latest installed from pip
  • GPU: Titan RTX 24GB
  • Driver: 430.64
  • cuDNN version: 7603 (i.e. 7.6.3)

I tested my system with the following code:

import torch
import torch.nn as nn
from torch import optim
from apex import amp
from torch.nn import functional as F


def main(opt_level):

    assert torch.backends.cudnn.enabled, 'Cudnn is not enabled!'
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # input - torch.Size([1, 512, 256, 256])
    # weight - torch.Size([512, 128, 3, 3])
    class net(nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 128, 3, 3))
            self.padding = 3 // 2

        def forward(self, input):
            out = F.conv2d(input, self.weight, padding=self.padding, groups=4)
            return out


    input = torch.randn(1, 512, 256, 256).float().cuda()
    model = net().cuda()
    optimizer = optim.Adam(model.parameters(), lr=0.01, betas=(0, 0.999))
    model, _ = amp.initialize(model, optimizer, keep_batchnorm_fp32=None if opt_level == 'O1' else True,
                              loss_scale='dynamic', opt_level=opt_level, num_losses=1)

    for i in range(10):
        start.record()
        out = model(input)
        end.record()
        torch.cuda.synchronize()
        print('--- %.4f' % start.elapsed_time(end))

if __name__ == '__main__':
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('-opt_level', type=str, required=True,
                    choices=['O0', 'O1', 'O2', 'O3'])
    opt = vars(p.parse_args())
    main(opt['opt_level'])

and the results are:

$ python test_conv2d.py -opt_level O0

Selected optimization level O0:  Pure FP32 training.

Defaults for this optimization level are:
enabled                : True
opt_level              : O0
cast_model_type        : torch.float32
patch_torch_functions  : False
keep_batchnorm_fp32    : None
master_weights         : False
loss_scale             : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O0
cast_model_type        : torch.float32
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : False
loss_scale             : dynamic
--- 378.2777
--- 5.2879
--- 4.7411
--- 4.6719
--- 4.6950
--- 4.6825
--- 4.7071
--- 4.6776
--- 4.6902
--- 4.6961

$ python test_conv2d.py -opt_level O1
$ python test_conv2d.py -opt_level O2
$ python test_conv2d.py -opt_level O3

Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.

Defaults for this optimization level are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O2
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : True
loss_scale             : dynamic
--- 446.0503
--- 63.7256
--- 62.7091
--- 62.7143
--- 61.9786
--- 60.7601
--- 60.6553
--- 60.5181
--- 60.6509
--- 60.7128

Here is a summary of the opt_levels from the apex documentation:
Opt_levels
O0: float32
O1/O2: mixed-precision
O3: half-precision

Has anyone encountered a similar issue?

Grouped convolutions might not trigger the FP16 path and thus might not use Tensor Cores.
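To check whether the slowdown is specific to grouped convolutions, one could time the same layer with `groups=1` and `groups=4` in FP32 and FP16. This is a minimal sketch assuming a recent PyTorch; the helper name `time_conv` and the reduced tensor sizes are illustrative, and FP16 is only tried when a CUDA device is present.

```python
import time

import torch
import torch.nn.functional as F


def time_conv(groups, dtype, device, channels=64, size=32, iters=5):
    """Return the mean wall-clock time (ms) of one F.conv2d call."""
    x = torch.randn(1, channels, size, size, device=device, dtype=dtype)
    # Weight shape is (out_channels, in_channels // groups, kH, kW).
    w = torch.randn(channels, channels // groups, 3, 3, device=device, dtype=dtype)
    F.conv2d(x, w, padding=1, groups=groups)  # warm-up call, not timed
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        F.conv2d(x, w, padding=1, groups=groups)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) * 1000 / iters


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# FP16 convolutions are only attempted on the GPU.
dtypes = [torch.float32] + ([torch.float16] if device.type == "cuda" else [])
results = {}
for dtype in dtypes:
    for groups in (1, 4):
        results[(dtype, groups)] = time_conv(groups, dtype, device)
        print(f"{dtype}, groups={groups}: {results[(dtype, groups)]:.3f} ms")
```

If the FP16/FP32 ratio is much worse for `groups=4` than for `groups=1`, that would point at the grouped-convolution kernels rather than at apex.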

Thanks, problem solved.

Hi, what can I do to trigger the FP16 path for convolutions with multiple groups?
Should I modify the PyTorch source code?
I found that an alternative implementation (applying each sample’s slice of the weight separately to the corresponding sample in the batch) is slower than conv2d with groups (5.3 ms vs. 4.6 ms). So the FP16 speedup of this workaround (3.2 ms) is not that significant: FP16 usually gives a 2-3x speedup, which here would mean under 2.3 ms.
I guess I’d be better off modifying the PyTorch source so that conv2d fully supports FP16 with multiple groups.
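The alternative described above amounts to splitting the input channels and the weight into `groups` blocks and running one ordinary convolution per block. A small sketch that checks this loop against the single grouped call; the sizes are reduced for a quick check (the original post uses 128 channels per group).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
groups, c_in, c_out = 4, 16, 16  # channels per group (128 in the original benchmark)
x = torch.randn(1, groups * c_in, 8, 8)
w = torch.randn(groups * c_out, c_in, 3, 3)

# One grouped call: input-channel block g only sees weight block g.
grouped = F.conv2d(x, w, padding=1, groups=groups)

# Equivalent loop: split input and weight into `groups` blocks and convolve each pair.
x_blocks = x.split(c_in, dim=1)
w_blocks = w.split(c_out, dim=0)
looped = torch.cat(
    [F.conv2d(xb, wb, padding=1) for xb, wb in zip(x_blocks, w_blocks)], dim=1
)

print(grouped.shape, torch.allclose(grouped, looped, atol=1e-4))
```

The loop launches `groups` separate kernels, each with `groups=1`, which is why it can hit different (possibly Tensor Core) cuDNN kernels in FP16 while paying extra launch and concatenation overhead.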

Unfortunately, there is not much you can do at the moment besides using torch.backends.cudnn.benchmark = True.
I profiled the calls, and all options call into the cuDNN path here.
Once cuDNN provides faster kernels, you will see the potential speedup.
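Enabling the flag is a one-liner. With it set, the first call for each new input shape also pays for cuDNN’s algorithm search, on top of the one-time setup cost already visible in the first iteration of the logs above, so warm-up calls should be excluded from the measurement. A minimal sketch; the helper `benchmark_forward` and its parameters are illustrative, and CUDA events are used only when a GPU is present.

```python
import time

import torch
import torch.nn.functional as F

# Let cuDNN try several convolution algorithms per input shape and cache the fastest.
# This only affects cuDNN convolutions on CUDA; it has no effect on CPU.
torch.backends.cudnn.benchmark = True


def benchmark_forward(fn, warmup=3, iters=10):
    """Mean time in ms of fn(), excluding warm-up calls (setup, autotuning)."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()  # make sure both events have completed
        return start.elapsed_time(end) / iters
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) * 1000 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 32, 32, 32, device=device)
w = torch.randn(32, 8, 3, 3, device=device)  # groups=4 -> 32 // 4 = 8 inputs per group
ms = benchmark_forward(lambda: F.conv2d(x, w, padding=1, groups=4))
print(f"{ms:.3f} ms per call")
```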

So it depends on a cuDNN update …
Because my current model (a PyTorch port of StyleGAN2 [Karras]) uses many custom conv layers that must process convolutions by groups, the O1/O2 settings make little difference compared with O0. Hopefully NVIDIA releases faster kernels soon.
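For context on why such layers need groups: in StyleGAN2 each sample gets its own modulated weight, and a common trick is to fold the batch into the channel dimension and run a single grouped convolution with `groups=batch`. The benchmark above (512 input channels, a 512 x 128 x 3 x 3 weight, `groups=4`) appears to match this layout for a batch of 4 with 128 channels each. This is a simplified sketch of the general technique (modulation and demodulation only, no upsampling, noise, or bias), not the poster’s code; the function name `modulated_conv2d` is hypothetical.

```python
import torch
import torch.nn.functional as F


def modulated_conv2d(x, weight, styles, demodulate=True, eps=1e-8):
    """x: (N, C_in, H, W), weight: (C_out, C_in, k, k), styles: (N, C_in)."""
    n, c_in, h, w = x.shape
    c_out, _, k, _ = weight.shape
    # Per-sample weights: scale each input channel by that sample's style.
    wts = weight.unsqueeze(0) * styles.view(n, 1, c_in, 1, 1)  # (N, C_out, C_in, k, k)
    if demodulate:
        # Normalize each output filter of each sample to unit L2 norm.
        d = torch.rsqrt(wts.pow(2).sum(dim=(2, 3, 4), keepdim=True) + eps)
        wts = wts * d
    # Fold the batch into channels so group i applies sample i's weights to sample i.
    x = x.reshape(1, n * c_in, h, w)
    wts = wts.reshape(n * c_out, c_in, k, k)
    out = F.conv2d(x, wts, padding=k // 2, groups=n)
    return out.reshape(n, c_out, out.shape[-2], out.shape[-1])


torch.manual_seed(0)
x = torch.randn(4, 8, 16, 16)
weight = torch.randn(6, 8, 3, 3)
styles = torch.randn(4, 8)
out = modulated_conv2d(x, weight, styles)
# Reference: one ordinary (groups=1) convolution per sample with that sample's weights.
ref = torch.cat([modulated_conv2d(x[i:i + 1], weight, styles[i:i + 1]) for i in range(4)])
print(out.shape)  # torch.Size([4, 6, 16, 16])
```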

I also found that some PyTorch functions (e.g. Tensor.repeat, torch.cat) are slow in O1 mode (where functions are patched to run in FP16). In O2 these functions are as fast as in O0 (FP32), but overall backpropagation is slower than in O1.

I guess the NHWC (channels-last) memory format may matter …
https://devtalk.nvidia.com/default/topic/1071156/jetson-nano/poor-group-convolution-performance-in-fp16/
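For reference, the channels-last (NHWC) layout is exposed in PyTorch as `torch.channels_last`, introduced around PyTorch 1.5 (so not available in the 1.4.0 setup above). A minimal sketch that converts a grouped conv module and its input and checks that the result is unchanged; whether cuDNN then selects faster FP16 kernels depends on the GPU, the cuDNN version and the layer shape.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=4, bias=False)
x = torch.randn(2, 32, 16, 16)

ref = conv(x)  # default NCHW (contiguous) layout

# Same logical shape, but the strides now store channels innermost (NHWC in memory).
x_cl = x.contiguous(memory_format=torch.channels_last)
conv_cl = conv.to(memory_format=torch.channels_last)  # converts 4-D params in place
out = conv_cl(x_cl)

print(x_cl.shape, x_cl.stride())
print(torch.allclose(out, ref, atol=1e-5))
```

Only the memory layout changes; indexing and shapes stay in NCHW order, so model code does not need to be rewritten.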

Since it’s still in progress, I guess I could try it later once PyTorch is up to date.
With these PyTorch updates, I can run my models with NHWC kernels without any apex updates, right?

Some PRs have already landed, so you could try it out on some vision models.

Yeah, I found an updated example in apex.
It seems the latest apex targets pytorch-nightly and requires a few additional arguments related to memory format.
For now this looks like the best option, but I still hope the memory format could be set once, globally, as part of the environment setup.
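One way to keep the memory-format handling in a single place, rather than threading extra arguments through the training script, is to wrap the model once and let the wrapper convert 4-D inputs on every forward call. A sketch under that assumption; `ChannelsLast` is a hypothetical helper, not part of apex or PyTorch, and it assumes PyTorch 1.5 or newer.

```python
import torch
import torch.nn as nn


class ChannelsLast(nn.Module):
    """Hypothetical wrapper: converts the wrapped model once, and 4-D inputs per call."""

    def __init__(self, model):
        super().__init__()
        # Module.to(memory_format=...) converts 4-D parameters and buffers.
        self.model = model.to(memory_format=torch.channels_last)

    def forward(self, x):
        if x.dim() == 4:
            x = x.contiguous(memory_format=torch.channels_last)
        return self.model(x)


model = ChannelsLast(nn.Sequential(nn.Conv2d(8, 8, 3, padding=1, groups=2), nn.ReLU()))
x = torch.randn(2, 8, 12, 12)
y = model(x)
print(y.shape)  # torch.Size([2, 8, 12, 12])
```

The rest of the script can then pass ordinary NCHW tensors; the conversion happens once for the weights and once per batch for the inputs.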

I’ve tested it with pytorch-nightly (1.5.0) and the latest apex, but this time it was 1000x slower in all modes (O0, O1, O2, O3) … (with channels last)
They said it’s for Volta, and I am wondering whether it also works on Turing …