NVIDIA Tensor Cores not being used for nn.Conv3d (3D Convolutions)

Hi,
TL;DR: “Conv2d uses tensor cores, Conv3d doesn’t when using apex AMP or FP16”

According to NVIDIA, cudnn 8 does support tensor core operations for 3D convolutions.
To verify whether tensor cores (HMMA instructions) are really being used, I check with the NVIDIA profiler with the sm__inst_executed_pipe_tensor_op_hmma.sum metric enabled (see https://developer.nvidia.com/blog/using-nsight-compute-nvprof-mixed-precision-deep-learning-models/).

I am fully up to date, since I built PyTorch from source:

My setup:
PyTorch: 1.7.0a
cudnn: 8.0.1
CUDA: 10.2
GPU: RTX 2080 Ti

Simple test script:

import torch
from torch import nn
import torch.nn.functional as F
from time import time
import logging
from apex import amp

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv3d(64, 64, 5, padding=2)
        self.pool = nn.MaxPool3d(2, 2)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv(x)))
        return x


if __name__ == '__main__':
    net = Net().to("cuda")
    amp_net = amp.initialize(net, opt_level="O1")

    input_image = torch.rand((8, 64, 64, 64, 64), device="cuda")  # N, C, D, H, W
    gpu_net = amp_net

    gpu_net(input_image)

Behavior:
If I change the network to 2D inputs and 2D convolutions, I can see in the metrics that tensor cores are used and that there is a speed-up.
For 3D they are not being used.

EDIT: Instead of using AMP, I also tried converting the net and the inputs to FP16, which also used tensor cores in the 2D case but not in 3D.
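
For reference, the plain-FP16 variant I mean is roughly this (a minimal sketch based on the Net class above, not the exact code I ran):

# Convert both the model parameters and the input to float16 (no apex involved)
fp16_net = Net().to("cuda").half()
fp16_input = torch.rand((8, 64, 64, 64, 64), device="cuda", dtype=torch.float16)
fp16_net(fp16_input)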

Maybe @ptrblck can help?

Thanks for the code snippet. I’ll try to grab a node tomorrow and profile it with the latest cudnn8 version, as the release candidate (8.0.1) might not pick the right kernels.


Thanks, sounds great. The profiler tells me that it’s using the following kernel for convolution:
void implicit_convolveNd_sgemm


Quick update:
I verified that TensorCores are used for this workload in the latest cudnn version on a V100.
Unfortunately, I couldn’t get an RTX 2080 Ti node yet and will try again tomorrow, sorry. :confused:

OK, that’s weird. Can you let me know which exact versions of cudnn, torch, and CUDA you used?
Was it the same test code?

Btw: The final production code will run on a T4 card. I haven’t tried to verify tensor-op use on that GPU though.

In the meantime, I used cudnn directly from C++ to verify that tensor cores work in general on this machine and setup.
I was able to run the cudnnConvolutionForward cudnn method with a 3D image and filter, and the profiler confirms it’s using the xmma_new::gemm::kernel, which makes extensive use of tensor cores, as I can see from the metrics.
So in general it should work. I haven’t managed to get it working from PyTorch though…

I used internal CUDA and cudnn versions for this test.
Could you post the complete model architecture (if possible) or at least some more Conv3d workloads so that I could profile them and make sure that TensorCores are used?

Hi,
I cannot share the complete (and rather complex) model, but I extended the minimal example from above a bit to make the effect more obvious (it’s attached below).

I am not sure about the padding though; I would assume the input shape dimensions need to be divisible by 8, hence I first use a normal (unpadded) conv on the 64x64 input and then use padded convs.
But I also tried other combinations…
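
As a side note, here is a quick shape check (a hypothetical sketch for the 2D case) of what the convolutions actually see: the unpadded kernel-5 conv reduces 64 to 60, and the padded convs keep it there.

# Quick shape check (sketch): unpadded kernel-5 conv: 64 -> 60, padded conv keeps 60
x = torch.rand(1, 64, 64, 64, device="cuda")
conv = nn.Conv2d(64, 64, 5).to("cuda")
conv_padded = nn.Conv2d(64, 64, 5, padding=2).to("cuda")
print(conv(x).shape)               # torch.Size([1, 64, 60, 60])
print(conv_padded(conv(x)).shape)  # torch.Size([1, 64, 60, 60])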

Running the code in 2D mode I get:

2020-07-27 08:12:32,971 - Executing 2D Test
2020-07-27 08:12:36,583 - FP32 duration: 2.420s
2020-07-27 08:12:37,398 - FP16 duration: 0.814s
2020-07-27 08:12:38,229 - AMP duration: 0.828s

When profiling, I set the rounds to 1 to make the report more readable (and because profiling is slow).
The profiler shows tensor core usage for the FP16 and AMP parts.

Running in 3D mode:

2020-07-27 08:25:13,797 - Executing 3D Test
2020-07-27 08:25:17,523 - FP32 duration: 2.203s
2020-07-27 08:25:21,319 - FP16 duration: 3.790s
2020-07-27 08:25:25,137 - AMP duration: 3.811s

For 3D, the profiler doesn’t show any tensor core usage.

I used the nv-nsight-cu GUI profiler and configured it so that it calls the command line profiler with the following arguments: /opt/nvidia/nsight-compute/2019.5.0/target/linux-desktop-glibc_2_11_3-x64/nv-nsight-cu-cli --export "report" --force-overwrite --target-processes all --kernel-regex-base function --launch-skip-before-match 0 --section LaunchStats --section Occupancy --section SpeedOfLight --sampling-interval auto --sampling-max-passes 5 --sampling-buffer-size 33554432 --profile-from-start 1 --cache-control all --clock-control base --apply-rules --metrics sm__inst_executed_pipe_tensor_op_hmma.sum "conda/envs/pytorchcudnn8/bin/python" torch-try-tensorcores_minimal2d3d.py

The code:

import torch
from torch import nn
import torch.nn.functional as F
from time import time
import logging
from apex import amp


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # nn_conv / nn_pool are bound to the 2D or 3D variants in the __main__ block below
        self.conv = nn_conv(64, 64, 5)
        self.conv_padded = nn_conv(64, 64, 5, padding=2)
        self.pool = nn_pool(2, 2)

    def forward(self, x):
        x = F.relu(self.conv(x))
        for i in range(5):
            x = F.relu(self.conv_padded(x))
        return x


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.DEBUG)
    # === CONFIGURATION ===
    dimensionality = 2  # 2 for 2D, 3 for 3D

    # === TEST CODE ===
    if dimensionality == 2:
        logging.info(f"Executing 2D Test")
        rounds = 500
        nn_conv = nn.Conv2d
        nn_pool = nn.MaxPool2d
        image_dims = (8, 64, 64, 64)
    else:
        logging.info(f"Executing 3D Test")
        rounds = 2
        nn_conv = nn.Conv3d
        nn_pool = nn.MaxPool3d
        image_dims = (8, 64, 64, 64, 64)

    input_image = torch.rand(image_dims, device="cuda")

    # Run in FP32 mode
    net = Net().to("cuda").to(torch.float32)
    start = time()
    for i in range(rounds):
        net(input_image)
    torch.cuda.synchronize()
    logging.info(f"FP32 duration: {time() - start:.03f}s")

    # Run in FP16 mode
    net = Net().to("cuda").to(torch.float32)
    fp16_net = net.to(torch.float16)
    fp16_input = input_image.to(torch.float16)
    start = time()
    for i in range(rounds):
        fp16_net(fp16_input)
    torch.cuda.synchronize()
    logging.info(f"FP16 duration: {time() - start:.03f}s")

    # Run with AMP
    net = Net().to("cuda").to(torch.float32)
    amp_net = amp.initialize(net, opt_level="O1", verbosity=0)
    start = time()
    for i in range(rounds):
        amp_net(input_image)
    torch.cuda.synchronize()
    logging.info(f"AMP duration: {time() - start:.03f}s")

If you have been able to make it work for 3D, let me know how so I can try to reproduce it :wink:

That’s helpful! Thanks for the code. I’ll profile it and check if TCs are used or not.

@ptrblck: I debugged into pytorch source code to find the reason and I traced it down to the cudnn statement where the convolution descriptor is populated:

aten/src/ATen/cudnn/Descriptors.h:170

cudnnSetConvolutionNdDescriptor(mut_desc(), dim, pad, stride, upscale, CUDNN_CROSS_CORRELATION, mathType);

(https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cudnn/Descriptors.h#L170)

When I change the mode from CUDNN_CROSS_CORRELATION to CUDNN_CONVOLUTION, it works and tensor cores are used. I assume this is a bug inside cudnn, since both modes should work equally well, just like they do in 2D.
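
(For context, a rough sketch in PyTorch terms: the two modes differ only by a spatial flip of the kernel, so there is no obvious reason why one should be able to use tensor cores and the other not.)

# Sketch: the "convolution" mode is just the cross-correlation mode with the kernel
# flipped along its spatial dimensions, so both should map to the same kind of kernels.
w = torch.randn(64, 64, 5, 5, 5, device="cuda", dtype=torch.float16)
x = torch.randn(1, 64, 16, 16, 16, device="cuda", dtype=torch.float16)
xcorr = F.conv3d(x, w, padding=2)                              # what PyTorch / CUDNN_CROSS_CORRELATION computes
conv  = F.conv3d(x, torch.flip(w, dims=(2, 3, 4)), padding=2)  # CUDNN_CONVOLUTION, emulated by flipping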

I verified this effect using my completely separate C++ code which uses cudnn directly. Just by changing the convolution mode, I can see tensor cores turning off or on, keeping all other settings.

EDIT: When I used my custom PyTorch version with CUDNN_CONVOLUTION, I still needed to modify the above example code slightly, because the convolution workspace was too large. But if you change the batch size from 8 to 1, it should work.
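
In the script above, that corresponds to changing just one line for the 3D case, roughly:

image_dims = (1, 64, 64, 64, 64)  # batch size reduced from 8 to 1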

Solution:

Verified via an official bug report that this is a bug inside cudnn 8.0.1.
Building torch from source with cudnn 8.0.2 fixes the problem.
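
To double-check which cudnn version a given PyTorch build actually uses, something like this can be printed:

import torch
print(torch.backends.cudnn.version())       # e.g. 8002 for cudnn 8.0.2
print(torch.__version__, torch.version.cuda)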
