NVIDIA Tensor Cores not being used for nn.Conv3d (3D Convolutions)

Hi,
TL;DR: “Conv2d uses tensor cores, Conv3d doesn’t, when using apex AMP or FP16.”

According to NVIDIA, cuDNN 8 does support tensor core operations for 3D convolutions.
To verify whether tensor cores are really being used (HMMA instructions), I am checking with the NVIDIA profiler with the sm__inst_executed_pipe_tensor_op_hmma.sum metric activated (see https://developer.nvidia.com/blog/using-nsight-compute-nvprof-mixed-precision-deep-learning-models/).
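Instead of reading that metric in the GUI, the per-kernel values can also be checked programmatically. A minimal sketch, assuming the report has been exported as CSV (for example with Nsight Compute's --csv option) and that the export contains "Kernel Name", "Metric Name" and "Metric Value" columns. Column names can differ between profiler versions, so treat them as assumptions; the helper names are hypothetical.

```python
import csv
import io

# The metric used in this thread to detect tensor core (HMMA) instructions.
METRIC = "sm__inst_executed_pipe_tensor_op_hmma.sum"


def tensor_core_counts(csv_text, metric=METRIC):
    """Map each kernel name to its summed HMMA instruction count.

    Assumes a CSV export with "Kernel Name", "Metric Name" and
    "Metric Value" columns; adjust these if your version differs.
    """
    counts = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        if row.get("Metric Name") != metric:
            continue
        # Values may be printed with thousands separators, e.g. "1,024".
        value = float(row["Metric Value"].replace(",", ""))
        name = row["Kernel Name"]
        counts[name] = counts.get(name, 0.0) + value
    return counts


def uses_tensor_cores(csv_text, metric=METRIC):
    """True if any profiled kernel executed at least one HMMA instruction."""
    return any(v > 0 for v in tensor_core_counts(csv_text, metric).values())
```

A kernel whose count is 0 (for example an FP32 "sgemm"-style kernel) did not touch the tensor pipe; any nonzero count means tensor cores were used by that kernel.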

I am fully up to date, having built PyTorch from source:

My setup:
PyTorch: 1.7.0a
cudnn: 8.0.1
CUDA: 10.2
GPU: RTX 2080 Ti

Simple test script:

import torch
from torch import nn
import torch.nn.functional as F
from time import time
import logging
from apex import amp

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv3d(64, 64, 5, padding=2)
        self.pool = nn.MaxPool3d(2, 2)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv(x)))
        return x


if __name__ == '__main__':
    net = Net().to("cuda")
    amp_net = amp.initialize(net, opt_level="O1")

    input_image = torch.rand((8, 64, 64, 64, 64), device="cuda")
    gpu_net = amp_net

    gpu_net(input_image)

Behavior:
If I change the network to 2D inputs and 2D convolutions, the metrics show that tensor cores are used, and there is a speed-up.
For 3D they are not being used.

EDIT: Instead of using AMP, I also tried converting the net and inputs to FP16, which also worked in the 2D case but not in 3D.

Maybe @ptrblck can help?

Thanks for the code snippet. I’ll try to grab a node tomorrow and profile it with the latest cudnn8 version, as the release candidate (8.0.1) might not pick the right kernels.

Thanks, sounds great. The profiler tells me that it’s using the following kernel for convolution:
void implicit_convolveNd_sgemm

Quick update:
I verified that TensorCores are used for this workload in the latest cudnn version on a V100.
Unfortunately, I couldn’t get an RTX2080Ti node yet and will try it tomorrow again, sorry. :confused:

OK, that’s weird. Can you let me know which exact versions of cuDNN, torch, and CUDA you used?
Was it the same test code?

Btw: The final production code will run on a T4 card. I haven’t tried to verify tensor-op use on that GPU though.

In the meantime, I tried to use cudnn from C++ directly to verify Tensor Cores work in general on this computer and setup.
I was able to run cudnnConvolutionForward with a 3D image and filter, and the profiler confirms that it uses the xmma new::gemm::kernel, which, as the metrics show, makes extensive use of tensor cores.
So in general it should work; I just haven’t managed to get it working from PyTorch yet…

I used internal CUDA and cudnn versions for this test.
Could you post the complete model architecture (if possible) or at least some more Conv3d workloads so that I could profile them and make sure that TensorCores are used?

Hi,
I cannot share the complete (and rather complex) model but I extended the minimal example from above a bit to make the effect more obvious (it’s attached below).

I am not sure about the padding, though. I assume the input dimensions need to be divisible by 8, so I first apply an unpadded conv to the 64x64 input and then use the padded conv.
But I also tried other combinations…
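For reference, the spatial sizes in this test follow the output-size formula from the PyTorch Conv2d/Conv3d documentation. The sketch below (hypothetical helper names) computes them for the example model and also checks the channel counts, since NVIDIA's FP16 tensor core guidance, as I understand it, is phrased mainly in terms of input/output channels being multiples of 8 rather than spatial sizes. Treat that check as a heuristic, not an exact cuDNN kernel-selection rule.

```python
def conv_out_size(size, kernel, padding=0, stride=1, dilation=1):
    # Formula from the PyTorch Conv2d/Conv3d docs, applied per spatial dim:
    # out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def channels_tc_friendly(in_channels, out_channels, multiple=8):
    # Heuristic only: channel counts that are multiples of 8 (FP16).
    return in_channels % multiple == 0 and out_channels % multiple == 0


def spatial_sizes(input_size, n_padded=5, kernel=5):
    # Mirrors the test model: one unpadded conv, then n_padded "same" convs.
    sizes = [input_size, conv_out_size(input_size, kernel)]
    for _ in range(n_padded):
        sizes.append(conv_out_size(sizes[-1], kernel, padding=kernel // 2))
    return sizes


if __name__ == "__main__":
    print(spatial_sizes(64))             # [64, 60, 60, 60, 60, 60, 60]
    print(channels_tc_friendly(64, 64))  # True
```

So the unpadded 5x5(x5) conv takes each spatial dimension from 64 to 60 and the padded convs keep it at 60, while the 64 input/output channels pass the multiple-of-8 check in both the 2D and the 3D case.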

Running the code in 2D mode I get:

2020-07-27 08:12:32,971 - Executing 2D Test
2020-07-27 08:12:36,583 - FP32 duration: 2.420s
2020-07-27 08:12:37,398 - FP16 duration: 0.814s
2020-07-27 08:12:38,229 - AMP duration: 0.828s

When profiling, I set the rounds to 1 to make the report more readable (and because profiling is slow).
The profiler shows tensor core usage for the FP16 and AMP parts.

Running in 3D mode:

2020-07-27 08:25:13,797 - Executing 3D Test
2020-07-27 08:25:17,523 - FP32 duration: 2.203s
2020-07-27 08:25:21,319 - FP16 duration: 3.790s
2020-07-27 08:25:25,137 - AMP duration: 3.811s

For 3D, the profiler doesn’t show any tensor core usage.

I used the nv-nsight-cu GUI profiler and configured it so that it calls the command line profiler with the following arguments: /opt/nvidia/nsight-compute/2019.5.0/target/linux-desktop-glibc_2_11_3-x64/nv-nsight-cu-cli --export "report" --force-overwrite --target-processes all --kernel-regex-base function --launch-skip-before-match 0 --section LaunchStats --section Occupancy --section SpeedOfLight --sampling-interval auto --sampling-max-passes 5 --sampling-buffer-size 33554432 --profile-from-start 1 --cache-control all --clock-control base --apply-rules --metrics sm__inst_executed_pipe_tensor_op_hmma.sum "conda/envs/pytorchcudnn8/bin/python" torch-try-tensorcores_minimal2d3d.py

The code:

import torch
from torch import nn
import torch.nn.functional as F
from time import time
import logging
from apex import amp


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn_conv(64, 64, 5)
        self.conv_padded = nn_conv(64, 64, 5, padding=2)
        self.pool = nn_pool(2, 2)

    def forward(self, x):
        x = F.relu(self.conv(x))
        for i in range(5):
            x = F.relu(self.conv_padded(x))
        return x


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.DEBUG)
    # === CONFIGURATION ===
    dimensionality = 2  # 2 for 2D, 3 for 3D

    # === TEST CODE ===
    if dimensionality == 2:
        logging.info(f"Executing 2D Test")
        rounds = 500
        nn_conv = nn.Conv2d
        nn_pool = nn.MaxPool2d
        image_dims = (8, 64, 64, 64)
    else:
        logging.info(f"Executing 3D Test")
        rounds = 2
        nn_conv = nn.Conv3d
        nn_pool = nn.MaxPool3d
        image_dims = (8, 64, 64, 64, 64)

    input_image = torch.rand(image_dims, device="cuda")

    # Run in FP32 mode
    net = Net().to("cuda").to(torch.float32)
    start = time()
    for i in range(rounds):
        net(input_image)
    torch.cuda.synchronize()
    logging.info(f"FP32 duration: {time() - start:.03f}s")

    # Run in FP16 mode
    net = Net().to("cuda").to(torch.float32)
    fp16_net = net.to(torch.float16)
    fp16_input = input_image.to(torch.float16)
    start = time()
    for i in range(rounds):
        fp16_net(fp16_input)
    torch.cuda.synchronize()
    logging.info(f"FP16 duration: {time() - start:.03f}s")

    # Run with AMP
    net = Net().to("cuda").to(torch.float32)
    amp_net = amp.initialize(net, opt_level="O1", verbosity=0)
    start = time()
    for i in range(rounds):
        amp_net(input_image)
    torch.cuda.synchronize()
    logging.info(f"AMP duration: {time() - start:.03f}s")
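The timing loops above start the clock before the very first call. A minimal sketch of a slightly more careful timing helper (the name benchmark is hypothetical), assuming you pass torch.cuda.synchronize as sync for GPU work: it runs untimed warm-up calls first, so one-off first-call costs (for example cuDNN autotuning when torch.backends.cudnn.benchmark = True) are excluded, and it synchronizes before reading the clock on both ends.

```python
import time


def benchmark(fn, rounds, warmup=1, sync=None):
    """Return the wall-clock seconds for `rounds` calls of fn().

    `warmup` untimed calls run first. For CUDA work, pass
    sync=torch.cuda.synchronize so queued kernels have finished
    before each clock reading.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(rounds):
        fn()
    if sync is not None:
        sync()
    return time.perf_counter() - start
```

For example, benchmark(lambda: fp16_net(fp16_input), rounds, sync=torch.cuda.synchronize), ideally inside a with torch.no_grad(): block so no autograd graph is kept between calls.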

If you have been able to make it work for 3D, let me know how so I can try to reproduce it :wink:

That’s helpful! Thanks for the code. I’ll profile it and check whether TCs are used or not.

@ptrblck: I debugged into the PyTorch source code and traced the cause down to the cuDNN call where the convolution descriptor is populated:

aten/src/ATen/cudnn/Descriptors.h:170

cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, CUDNN_CROSS_CORRELATION, mathType);

(https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cudnn/Descriptors.h#L170)

When I change the mode from CUDNN_CROSS_CORRELATION to CUDNN_CONVOLUTION, it works and uses tensor cores. I assume this is a bug inside cuDNN, since both modes should work equally well, just like they do in 2D.
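For context on why the two modes should behave alike: in cuDNN, CUDNN_CROSS_CORRELATION slides the filter as-is (which is what nn.Conv2d/nn.Conv3d compute), while CUDNN_CONVOLUTION flips the filter first. The 1D sketch below (pure Python, hypothetical function names) shows that the arithmetic is identical apart from that flip. One consequence: switching the mode in Descriptors.h without also flipping the weights (along all three spatial axes for Conv3d, e.g. torch.flip(w, dims=(2, 3, 4))) changes the outputs, so it is a useful diagnostic rather than a drop-in fix.

```python
def cross_correlate(x, w):
    # "Valid" 1D cross-correlation, no filter flip:
    # y[i] = sum_k x[i + k] * w[k]
    n = len(x) - len(w) + 1
    return [sum(x[i + k] * w[k] for k in range(len(w))) for i in range(n)]


def convolve(x, w):
    # "Valid" 1D true convolution: same sliding window, flipped filter:
    # y[i] = sum_k x[i + k] * w[K - 1 - k]
    K = len(w)
    n = len(x) - K + 1
    return [sum(x[i + k] * w[K - 1 - k] for k in range(K)) for i in range(n)]


if __name__ == "__main__":
    x, w = [1, 2, 3, 4], [1, 0, -1]
    print(cross_correlate(x, w))  # [-2, -2]
    print(convolve(x, w))         # [2, 2]
    # Convolution equals cross-correlation with the reversed filter.
    print(convolve(x, w) == cross_correlate(x, w[::-1]))  # True
```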

I verified this effect using my completely separate C++ code, which uses cuDNN directly. Just by changing the convolution mode, I can turn tensor core usage off or on while keeping all other settings unchanged.

EDIT: When I used my custom PyTorch build with CUDNN_CONVOLUTION, I still needed to modify the example code above slightly, because the convolution workspace was too large. If you change the batch size from 8 to 1, it should work.
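To put the batch-size change in perspective, here is a back-of-the-envelope sketch (hypothetical helper tensor_bytes) of the memory taken by the 3D input tensor alone. It does not compute cuDNN's workspace, which depends on the chosen algorithm; it only shows that the tensors themselves shrink 8x when the batch goes from 8 to 1.

```python
def tensor_bytes(shape, bytes_per_element):
    # Number of elements times element size; ignores allocator overhead.
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


if __name__ == "__main__":
    MiB = 1024 ** 2
    for batch in (8, 1):
        shape = (batch, 64, 64, 64, 64)
        # 2 bytes per element in FP16, 4 in FP32.
        print(batch, tensor_bytes(shape, 2) / MiB, "MiB in FP16")
```

With batch size 8 the FP16 input is 256 MiB (512 MiB in FP32); with batch size 1 it is 32 MiB.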

Solution:

I verified via an official bug report that this is a bug in cuDNN 8.0.1.
Building torch from source with cudnn 8.0.2 fixes the problem.
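Since the fix came down to the cuDNN version, it may be worth checking which cuDNN a given PyTorch build actually loads before profiling. torch.backends.cudnn.version() returns an integer (or None if cuDNN is unavailable); the sketch below (hypothetical helper names) decodes it, assuming the encoding used in cuDNN's headers: major*1000 + minor*100 + patch for cuDNN 7 and 8 (so 8.0.1 is 8001), and, as far as I know, major*10000 + minor*100 + patch from cuDNN 9 on.

```python
def decode_cudnn_version(v):
    """Split an integer cuDNN version into (major, minor, patch)."""
    if v >= 10000:  # cuDNN 9+: major*10000 + minor*100 + patch
        return v // 10000, (v % 10000) // 100, v % 100
    # cuDNN 7/8: major*1000 + minor*100 + patch
    return v // 1000, (v % 1000) // 100, v % 100


def is_cudnn_8_0_1(v):
    # The release identified in this thread as not using tensor cores for Conv3d.
    return decode_cudnn_version(v) == (8, 0, 1)
```

On the setup from the original post, decode_cudnn_version(torch.backends.cudnn.version()) should give (8, 0, 1); after rebuilding against cuDNN 8.0.2 it should give (8, 0, 2).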