How to reduce latency in ConvTranspose2d?

Question: How can I reduce the latency (computational cost) of ConvTranspose2d?

I compared the latency of nn.Conv2d, Upsample(mode=nearest), ConvTranspose2d, and Upsample(mode=bilinear) at different batch sizes. It looks like the deconvolution operation takes a lot of time at a large batch size (64). Is there a way to reduce the computation time of deconvolution?

Thank you

Comparison Table

[------------------------------------- upsample module comparison ------------------------------------]
                        |  nearest(scale=2)  |  conv(kernel=3)  |  deconv(scale=2)  |  bilinear(scale=2)
4 threads: --------------------------------------------------------------------------------------------
      [8, 2048, 8, 8]   |        30.4        |       94.6      |        134.6      |        4838.2     
      [16, 2048, 8, 8]  |        58.4        |       94.3      |        249.4      |        9682.3     
      [32, 2048, 8, 8]  |       111.6        |      119.1      |        480.1      |       19434.2     
      [64, 2048, 8, 8]  |       219.9        |      206.1      |      12399.6      |       38844.8     

Times are in microseconds (us).

Comparison Code

import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from itertools import product


def get_upsample_module(mode, upsample, ch):
    # 'deconv': depthwise (groups=ch) transposed convolution followed by BatchNorm;
    # otherwise a parameter-free nn.Upsample with the given interpolation mode
    if mode == 'deconv':
        return nn.Sequential(
            torch.nn.ConvTranspose2d(ch, ch, upsample, stride=upsample, dilation=1, groups=ch, bias=False),
            torch.nn.BatchNorm2d(ch)
        ).cuda()
    else:
        return nn.Upsample(scale_factor=upsample, mode=mode).cuda()

def get_downsample_module(kernel, ch):
    # depthwise convolution + BatchNorm; with kernel=3, stride=1, padding=1 it
    # preserves the spatial size (a plain-conv baseline, not an upsampler)
    return nn.Sequential(
        torch.nn.Conv2d(ch, ch, kernel, stride=1, padding=1, dilation=1, groups=ch, bias=False),
        torch.nn.BatchNorm2d(ch)
    ).cuda()


results = []

batch_size = [8, 16, 32, 64]
channel_size = [2048]
image_size = [8]
mode = ['nearest', 'conv', 'deconv', 'bilinear']
scale_factors = [2]

for b, c, n in product(batch_size, channel_size, image_size):
    label = 'upsample module comparison'
    sub_label = f'[{b}, {c}, {n}, {n}]'
    x = torch.rand((b, c, n, n)).cuda()

    for method, scale in product(mode, scale_factors):
        if method == 'conv':
            # the conv baseline keeps the spatial size (kernel=scale+1=3, stride=1, padding=1)
            kernel = scale + 1
            model = get_downsample_module(kernel, c)
            description = f'{method}(kernel={kernel})'
        else:
            model = get_upsample_module(method, scale, c)
            description = f'{method}(scale={scale})'

        # autocast is still active while blocked_autorange executes the statement
        with torch.cuda.amp.autocast():
            results.append(benchmark.Timer(
                stmt='model(x)',
                globals={'model': model, 'x': x},
                label=label,
                sub_label=sub_label,
                description=description,
                num_threads=4
            ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.colorize()
compare.print()

My Environment:

  • OS: Ubuntu 18.04
  • Python: 3.7.11
  • PyTorch: 1.10.0
  • CUDA: 11.3

You could add torch.backends.cudnn.benchmark = True to let cuDNN profile different kernels and select the fastest one. Since you are using amp, you could also use to(memory_format=torch.channels_last), which could give another speedup if permutations are avoided.
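Something along these lines (a minimal sketch reusing the depthwise deconv block and shapes from your benchmark, not a drop-in replacement for the whole script):

import torch
import torch.nn as nn

# let cuDNN benchmark candidate algorithms and cache the fastest one
# (pays off when the input shapes stay constant between iterations)
torch.backends.cudnn.benchmark = True

# depthwise transposed conv + BatchNorm, as in the benchmark script
deconv = nn.Sequential(
    nn.ConvTranspose2d(2048, 2048, 2, stride=2, groups=2048, bias=False),
    nn.BatchNorm2d(2048),
).cuda().to(memory_format=torch.channels_last)  # convert the weights to NHWC

# the input must be channels-last as well, otherwise permutations are reintroduced
x = torch.rand(64, 2048, 8, 8, device='cuda').to(memory_format=torch.channels_last)

with torch.cuda.amp.autocast():  # amp, as in your setup
    out = deconv(x)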


Thank you @ptrblck, I solved the latency problem by following your suggestion!

For those having a similar problem, I recommend using to(memory_format=torch.channels_last).
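You can quickly verify that a tensor really ends up in channels-last layout (a minimal sketch; the shape is just one from my benchmark):

import torch

x = torch.rand(8, 2048, 16, 16, device='cuda')
x_cl = x.to(memory_format=torch.channels_last)

# same logical shape, but the strides now make channels the fastest-moving dimension
print(x_cl.shape)                                             # torch.Size([8, 2048, 16, 16])
print(x_cl.stride())                                          # (524288, 1, 32768, 2048)
print(x_cl.is_contiguous(memory_format=torch.channels_last))  # True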

Benchmark, AMP, memory_format experiment results

Latency results for nn.Conv2d, Upsample(mode=nearest), ConvTranspose2d, and Upsample(mode=bilinear) under different settings. Faster/slower in the headings below are relative to setting 1.

  1. Default (Benchmark=False, AMP=False, memory_format=contiguous)

    [-------------------------------------- upsample module comparison -------------------------------------]
                              |  nearest(scale=2)  |  conv(kernel=3) |  deconv(scale=2)  |  bilinear(scale=2)
    4 threads: ----------------------------------------------------------------------------------------------
          [8, 2048, 16, 16]   |       110.4        |       144.6     |        492.8      |        4891.5     
          [16, 2048, 16, 16]  |       216.6        |       276.8     |       1128.5      |        9821.6     
          [32, 2048, 16, 16]  |       433.9        |       541.6     |       2255.2      |       19672.0     
          [64, 2048, 16, 16]  |       864.3        |      1214.2     |       4535.2      |       39559.5     
    
    Times are in microseconds (us).
    
  2. No change from cuDNN benchmark alone (Benchmark=True, AMP=False, memory_format=contiguous)

    [-------------------------------------- upsample module comparison -------------------------------------]
                              |  nearest(scale=2)  |  conv(kernel=3) |  deconv(scale=2)  |  bilinear(scale=2)
    4 threads: ----------------------------------------------------------------------------------------------
          [8, 2048, 16, 16]   |       110.5        |       145.4     |        494.8      |        4906.8     
          [16, 2048, 16, 16]  |       216.9        |       277.9     |       1131.3      |        9825.4     
          [32, 2048, 16, 16]  |       434.4        |       543.0     |       2266.0      |       19766.6     
          [64, 2048, 16, 16]  |       865.5        |      1211.6     |       4539.3      |       39544.8     
    
    Times are in microseconds (us).
    
  3. Conv faster, Deconv slower (Benchmark=True, AMP=True, memory_format=contiguous)

    [-------------------------------------- upsample module comparison -------------------------------------]
                              |  nearest(scale=2)  |  conv(kernel=3) |  deconv(scale=2)  |  bilinear(scale=2)
    4 threads: ----------------------------------------------------------------------------------------------
          [8, 2048, 16, 16]   |       110.5        |      162.3      |        443.0      |        4907.9     
          [16, 2048, 16, 16]  |       217.0        |      169.6      |      11440.2      |        9828.0     
          [32, 2048, 16, 16]  |       434.5        |      320.2      |      11440.9      |       19671.5     
          [64, 2048, 16, 16]  |       865.7        |      619.6      |      14415.1      |       39541.3     
    
    Times are in microseconds (us).
    
  4. Nearest slower, Conv faster, Deconv slower, Bilinear faster (Benchmark=True, AMP=True, memory_format=channels_last)

    [-------------------------------------- upsample module comparison -------------------------------------]
                              |  nearest(scale=2)  |  conv(kernel=3) |  deconv(scale=2)  |  bilinear(scale=2)
    4 threads: ----------------------------------------------------------------------------------------------
          [8, 2048, 16, 16]   |        247.2       |      154.8      |        825.5      |         336.3     
          [16, 2048, 16, 16]  |        492.0       |      170.3      |       1677.0      |         675.8     
          [32, 2048, 16, 16]  |        979.6       |      327.0      |       3498.6      |        1347.3     
          [64, 2048, 16, 16]  |       1971.4       |      675.1      |       6861.0      |        2703.2     
    
    Times are in microseconds (us).
    
  5. Nearest slower, Conv faster, Deconv faster, Bilinear faster (Benchmark=True, AMP=False, memory_format=channels_last)

    [-------------------------------------- upsample module comparison -------------------------------------]
                              |  nearest(scale=2)  |  conv(kernel=3) |  deconv(scale=2)  |  bilinear(scale=2)
    4 threads: ----------------------------------------------------------------------------------------------
          [8, 2048, 16, 16]   |        246.2       |      104.8      |        495.6      |         332.0     
          [16, 2048, 16, 16]  |        486.4       |      197.8      |        973.1      |         665.7     
          [32, 2048, 16, 16]  |        963.8       |      406.1      |       1820.7      |        1328.6     
          [64, 2048, 16, 16]  |       1954.1       |      834.5      |       3646.1      |        2682.6     
    
    Times are in microseconds (us).
    

For those who want to experiment with other modules

This is the script I used to compare latency.

import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from itertools import product


def get_upsample_module(mode, upsample, ch):
    # 'deconv': depthwise (groups=ch) transposed convolution followed by BatchNorm;
    # otherwise a parameter-free nn.Upsample with the given interpolation mode
    if mode == 'deconv':
        return nn.Sequential(
            torch.nn.ConvTranspose2d(ch, ch, upsample, stride=upsample, dilation=1, groups=ch, bias=False),
            torch.nn.BatchNorm2d(ch)
        ).cuda()
    else:
        return nn.Upsample(scale_factor=upsample, mode=mode).cuda()

def get_downsample_module(kernel, ch):
    # depthwise convolution + BatchNorm; with kernel=3, stride=1, padding=1 it
    # preserves the spatial size (a plain-conv baseline, not an upsampler)
    return nn.Sequential(
        torch.nn.Conv2d(ch, ch, kernel, stride=1, padding=1, dilation=1, groups=ch, bias=False),
        torch.nn.BatchNorm2d(ch)
    ).cuda()

# toggle these three flags to reproduce settings 1-5 above
use_benchmark = True
amp = False
channel_last = True

torch.backends.cudnn.benchmark = use_benchmark

results = []

batch_size = [8, 16, 32, 64]
channel_size = [2048]
image_size = [16]
mode = ['nearest', 'conv', 'deconv', 'bilinear']
scale_factors = [2]

for b, c, n in product(batch_size, channel_size, image_size):
    label = 'upsample module comparison'
    sub_label = f'[{b}, {c}, {n}, {n}]'
    x = torch.rand((b, c, n, n)).cuda()

    for method, scale in product(mode, scale_factors):
        if method == 'conv':
            # the conv baseline keeps the spatial size (kernel=scale+1=3, stride=1, padding=1)
            kernel = scale + 1
            model = get_downsample_module(kernel, c)
            description = f'{method}(kernel={kernel})'
        else:
            model = get_upsample_module(method, scale, c)
            description = f'{method}(scale={scale})'

        if channel_last:
            x = x.to(memory_format=torch.channels_last)
            model = model.to(memory_format=torch.channels_last)

        # enabled must be passed by keyword: the second positional
        # argument of torch.autocast is dtype, not enabled
        with torch.autocast('cuda', enabled=amp):
            results.append(benchmark.Timer(
                stmt='model(x)',
                globals={'model': model, 'x': x},
                label=label,
                sub_label=sub_label,
                description=description,
                num_threads=4
            ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.colorize()
compare.print()