Channels_last format convolution is slower than normal NCHW

Hi, I’m experimenting with different memory layouts based on these two documents:

Convolutional Layers User Guide (from NVIDIA)
Channels Last Memory Format in PyTorch (from the official PyTorch docs)

I tried to compare an NCHW model with an NHWC model using the following script:

from time import time

import torch
import torch.nn as nn


def time_layer(layer, feature, num_iter):
    """Time the total time used in num_iter forwarding."""
    tic = time()
    for _ in range(num_iter):
        _  = layer(feature)
    print(time() - tic, "seconds")


N, C, H, W, K = 32, 1024, 7, 7, 1024    # params from the NVIDIA doc
# NCHW tensor & layer
a = torch.empty(N, C, H, W, device="cuda:0")
conv_nchw = nn.Conv2d(C, K, 3, 1, 1).to("cuda:0")

# NHWC tensor & layer
b = torch.empty(N, C, H, W, device="cuda:0", memory_format=torch.channels_last)
conv_nhwc = nn.Conv2d(C, K, 3, 1, 1).to("cuda:0", memory_format=torch.channels_last)

# NCHW kernel & NCHW tensor
time_layer(conv_nchw, a, 1000)

# NCHW kernel & NHWC tensor
time_layer(conv_nchw, b, 1000)

# NHWC kernel & NHWC tensor
time_layer(conv_nhwc, b, 1000)

# NHWC kernel & NCHW tensor
time_layer(conv_nhwc, a, 1000)

And I got the following output (results looked similar in many repeated runs):

0.9735202789306641 seconds       # NCHW kernel & NCHW tensor
2.213291645050049 seconds        # NCHW kernel & NHWC tensor
2.3461294174194336 seconds       # NHWC kernel & NHWC tensor
2.7654671669006348 seconds       # NHWC kernel & NCHW tensor

I’m using a TITAN RTX GPU, which is supposed to have Tensor Cores, and PyTorch 1.7.0+cu101, which supports the channels_last format. So it’s surprising that the fastest timing comes from the NCHW kernel & NCHW tensor combination (this wouldn’t be surprising if my GPU had no Tensor Cores, since I guess NCHW is the format that gets optimized in that case). The NCHW kernel & NHWC tensor and NHWC kernel & NCHW tensor combinations being slow is less surprising, since mixing up the formats is certainly no good for the computation. But why isn’t NHWC kernel & NHWC tensor, which is supposed to be the combination most optimized for Tensor Cores, the fastest one?

Am I doing the layout optimization correctly? Am I missing anything?
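In case it helps, here is the quick sanity check I would use to confirm the layouts are what I intend (just illustrative prints, reusing the definitions from the script above; not part of the timing):

# Check which tensors/weights are actually laid out as channels_last.
print(a.is_contiguous(memory_format=torch.channels_last))                 # expect False
print(b.is_contiguous(memory_format=torch.channels_last))                 # expect True
print(conv_nchw.weight.is_contiguous(memory_format=torch.channels_last))  # expect False
print(conv_nhwc.weight.is_contiguous(memory_format=torch.channels_last))  # expect True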

Follow-up question: instead of running all 4 benchmarks in a script, I executed the 4 lines interactively in the Python interpreter, line by line, and got (results looked similar in many repeated runs):

>>> time_layer(conv_nchw, a, 1000)       # NCHW kernel & NCHW tensor
0.9541912078857422 seconds
>>> time_layer(conv_nchw, b, 1000)       # NCHW kernel & NHWC tensor
2.034724235534668 seconds
>>> time_layer(conv_nhwc, b, 1000)       # NHWC kernel & NHWC tensor
1.7101032733917236 seconds
>>> time_layer(conv_nhwc, a, 1000)       # NHWC kernel & NCHW tensor
1.9565918445587158 seconds

Why are the latter 3 timings shorter than the ones from the script? The only thing I can think of is that, in the interactive interpreter, there were noticeable time gaps between executions, whereas the script had no such gaps. Are there any nuances related to this?

If you could answer, I’d really appreciate the help!

@CDhere For benchmarking, you should probably place a torch.cuda.synchronize(device="cuda:0") before you print the elapsed time, since CUDA kernel launches are asynchronous. Also, the channels_last speedups for convs are most relevant for float16; you might want to set a, b, and the convs to dtype=torch.float16.
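Something along these lines, just a sketch reusing the names from your script (the _synced / _half names are made up for illustration):

def time_layer_synced(layer, feature, num_iter):
    """Time num_iter forward passes, including the asynchronous GPU work."""
    torch.cuda.synchronize(device="cuda:0")  # finish any previously queued work
    tic = time()
    for _ in range(num_iter):
        _ = layer(feature)
    torch.cuda.synchronize(device="cuda:0")  # wait for the queued kernels to finish
    print(time() - tic, "seconds")


# float16 copies of the channels_last tensor and conv, where Tensor Cores matter most
b_half = b.half()                  # .half() preserves the channels_last layout
conv_nhwc_half = conv_nhwc.half()
time_layer_synced(conv_nhwc_half, b_half, 1000)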

If there is odd float32 behaviour like you’re seeing, it could be a regression in the 1.7 builds; I saw several issues that have been fixed in more recent NGC (20.12) container releases, which build against newer versions of cuDNN/CUDA.