Using optimised depthwise convolutions

Hi all,

Following #3057 and #3265, I was excited to try out depthwise separable convolutions, but I’m having a hard time activating these optimised code paths. I’m currently getting no speedup over default convolutions.

Here are the two layer types that make up the bulk of my network:

# Depthwise: one filter group per input channel (channel multiplier k)
nn.Conv2d(in_chans, in_chans * k, kernel_size, groups=in_chans)

# Normal (pointwise 1x1) convolution that mixes channels
nn.Conv2d(in_chans * k, out_chans, 1)

If I profile the network’s execution, I get the following (trimmed):

-------------------------  ------------  ------------  ------------
Name                           CPU time     CUDA time         Calls
-------------------------  ------------  ------------  ------------
conv2d                        130.737us    1016.190us            14
cudnn_convolution             160.005us     843.504us             8
thnn_conv_depthwise2d          58.475us    1246.438us             6

My concern is that the depthwise convolutions are being handled by THNN rather than THCUNN, where the new optimisations live. I have a second network that replaces normal convolutions with multiple depthwise convolutions, and it runs much slower despite performing far less computation.

Am I missing something obvious?

>>> torch.__version__
'0.4.0a0+82e995e'
>>> torch.version.cuda
'8.0.61'
>>> torch.backends.cudnn.version()
7005

Are you profiling correctly?

If you are using Python timing calls, you have to insert CUDA synchronize calls to get correct timings:

import time
import torch

torch.cuda.synchronize()   # wait for any pending kernels before starting the clock
now = time.time()
# my cuda calls
torch.cuda.synchronize()   # wait for the kernels above to finish before reading the clock
total_time = time.time() - now

I’m using the following, from the docs:

with torch.cuda.profiler.profile():
    model(input_var) # Warmup CUDA memory allocator and profiler
    with torch.autograd.profiler.emit_nvtx():
        output = model(input_var)

Hmmm, okay. What are your layer and input sizes?

I have those two layers wrapped up as a drop-in replacement for nn.Conv2d, and they're used in CIFAR-10- and ImageNet-scale networks. The particular network that I profiled looks like this:

Input: 256x3x32x32

Conv2d(  3, 128, kernel_size=3, padding=1)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)
        
MaxPool2d(kernel_size=3, stride=2)
# Channels now sized 15x15
        
Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)
        
MaxPool2d(kernel_size=3, stride=2)
# Channels now sized 7x7
        
Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)
        
AvgPool2d(kernel_size=7)
# Channels now sized 1x1
        
Conv2d(128,  10, kernel_size=1, padding=0)

I’ve also tried the wrapped layers in copies of AlexNet, SqueezeNet and ResNet from the torchvision module. I haven’t profiled those. All trained to a reasonable accuracy, but each iteration was slower than their default counterparts.

Here’s a simple test script I wrote which executes a single conv layer.
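The script itself isn't reproduced here, but roughly it does something like the following (a sketch only; the argument handling and padding are my assumptions based on the invocations below, and --groups=True is taken to mean depthwise, i.e. one group per input channel):

import argparse

import torch
import torch.nn as nn
from torch.autograd import Variable  # PyTorch 0.3-style API

parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=256)
parser.add_argument('--img-size', type=int, default=50)
parser.add_argument('--in-channels', type=int, default=256)
parser.add_argument('--out-channels', type=int, default=256)
parser.add_argument('--kernel-size', type=int, default=3)
parser.add_argument('--groups', type=lambda s: s == 'True', default=False)
args = parser.parse_args()

# --groups=True is interpreted as depthwise: one group per input channel
groups = args.in_channels if args.groups else 1

conv = nn.Conv2d(args.in_channels, args.out_channels, args.kernel_size,
                 padding=args.kernel_size // 2, groups=groups).cuda()
x = Variable(torch.randn(args.batch, args.in_channels,
                         args.img_size, args.img_size).cuda())

conv(x)  # warm up the CUDA allocator before profiling

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    conv(x)
print(prof)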

I had to drop back to PyTorch v0.3, so I can no longer see which backend is active, but I haven't yet found a configuration where grouped convolution is markedly faster than ordinary convolution, e.g.:

$ python3 group_conv.py --batch=256 --img-size=50 --in-channels=256 --out-channels=256 --kernel-size=3 --groups=True
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                    CPU time        CUDA time            Calls        CPU total       CUDA total
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
ConvForward             77.271us      16688.654us                1         77.271us      16688.654us

$ python3 group_conv.py --batch=256 --img-size=50 --in-channels=256 --out-channels=256 --kernel-size=3 --groups=False
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                    CPU time        CUDA time            Calls        CPU total       CUDA total
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
ConvForward             72.029us      15207.838us                1         72.029us      15207.838us
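For context, here's a back-of-envelope count of the multiply-accumulates for the two runs above, again assuming --groups=True means depthwise (one group per input channel) and ignoring bias and edge effects:

# Rough MAC count per forward pass for the configurations above
batch, img, c_in, c_out, k = 256, 50, 256, 256, 3

def conv_macs(groups):
    # each output element sums over (c_in / groups) * k * k input values
    return batch * c_out * img * img * (c_in // groups) * k * k

print(conv_macs(groups=1))      # ordinary conv:  ~3.8e11 MACs
print(conv_macs(groups=c_in))   # depthwise conv: ~1.5e9 MACs, i.e. 256x fewer

So the depthwise run should need roughly 256x less arithmetic, yet the measured CUDA times above are nearly identical, which is what makes it look like the grouped path isn't compute-bound (or isn't hitting an optimised kernel).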

I’ve seen similar results when working with 3D ResNeXt models, where grouped convolution is used extensively (with a group size of 32). In 3D, the ResNeXt model actually has only around 50% of the parameters (and thus FLOPs) of its ResNet counterpart, but the forward/backward pass is 25%/60% slower at batch size 1, and 3%/10% slower at batch size 128 split across 4 GPUs.

I’ve had some offline discussions about this. Here are my findings:

  • thnn_conv_depthwise2d, the internal function being called, does run on CUDA, but it is not a CuDNN kernel.
  • CuDNN 7’s implementation of grouped/depthwise convolution is up to 3x quicker in the forward pass, but always slower in the backward pass (a timing sketch follows below).
  • Choosing when to dispatch to CuDNN and when not to is very difficult to do in a maintainable way: it depends on many of the conv layer’s parameters, on the CuDNN version, and probably on the GPU used as well.

I think it’s probably a case of waiting until CUDA/CuDNN provide consistent benefits in more situations.
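For anyone who wants to reproduce the forward/backward asymmetry themselves, here's a minimal timing sketch along the lines of the synchronize advice earlier in the thread (the layer and input sizes are just illustrative):

import time

import torch
import torch.nn as nn

def time_fwd_bwd(conv, x, iters=20):
    # Warm-up pass so allocator and algorithm selection don't pollute the timings
    conv(x).sum().backward()
    torch.cuda.synchronize()

    fwd = bwd = 0.0
    for _ in range(iters):
        torch.cuda.synchronize()
        t0 = time.time()
        out = conv(x)
        torch.cuda.synchronize()
        t1 = time.time()
        out.sum().backward()
        torch.cuda.synchronize()
        t2 = time.time()
        fwd += t1 - t0
        bwd += t2 - t1
    return fwd / iters, bwd / iters

x = torch.randn(256, 128, 32, 32, device='cuda', requires_grad=True)
ordinary  = nn.Conv2d(128, 128, 3, padding=1).cuda()
depthwise = nn.Conv2d(128, 128, 3, padding=1, groups=128).cuda()

print('ordinary  fwd/bwd (s):', time_fwd_bwd(ordinary, x))
print('depthwise fwd/bwd (s):', time_fwd_bwd(depthwise, x))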

NVIDIA said they will integrate dedicated kernels, similar to PyTorch’s grouped-convolution kernels, into the next CuDNN version and dispatch to them when they’re faster.

Would the next version of CuDNN mean a major/minor version update like 8.0/7.1, or a patch-level update like 7.0.6? Did NVIDIA give an estimate of the release date? Thanks.

Hello, based on the NVIDIA release notes, some depthwise separable convolution improvements have now made it into CuDNN:

https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_730.html#rel_730

Is this now integrated with pytorch?

@smth, @Ziju_Feng? I had a look at the source code & it seems the THNN implementation is still being used?

@yaysummeriscoming I’ve checked the code as well, and yes, even in the depthwise case we are still dispatching to the THNN depthwise kernels: https://github.com/pytorch/pytorch/blob/51f1c4fea5e066e3cbb7a7251114de4381d2a926/aten/src/ATen/native/Convolution.cpp#L330

We haven’t used the CuDNN implementation because those kernels aren’t very optimized, according to an NVIDIA employee: https://github.com/pytorch/pytorch/issues/1708#issuecomment-384015108

Once the situation changes and CuDNN has faster, optimized depthwise convolution routines, we will switch to dispatching to CuDNN.

I’ve opened an issue to track this: https://github.com/pytorch/pytorch/issues/15513
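For reference, one way to check what a given build dispatches to, without reading the C++ source, is to look at the op names in the profiler output. A quick sketch (exact op names vary between versions):

import torch
import torch.nn as nn

conv = nn.Conv2d(128, 128, 3, padding=1, groups=128).cuda()
x = torch.randn(32, 128, 32, 32, device='cuda')

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    conv(x)

# Look for thnn_conv_depthwise2d vs cudnn_convolution among the listed ops
print(prof.key_averages())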

Even when I compare the performance of depthwise separable conv3d against normal conv3d without a GPU, I find that the latter still outperforms the former.

I think the problem might be memory bandwidth.

If you implement depthwise separable convolution something like this:

import torch.nn as nn

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # Depthwise: one filter per input channel (groups=in_channels)
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                               padding, dilation, groups=in_channels, bias=bias)
        # Pointwise: ordinary 1x1 convolution that mixes channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)      # writes an in_channels-sized intermediate to memory
        x = self.pointwise(x)  # writes the final out_channels-sized result
        return x

Then your layer winds up writing roughly twice as much activation memory as a typical conv2d does.

This memory traffic shouldn’t be necessary, though. If the two steps were fused into a single CUDA kernel, the depthwise and pointwise work for each output pixel could be done together, and only the final result would need to be written to VRAM. That would likely give a substantial speedup for this operation.
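As a rough back-of-envelope, assuming float32 activations and one of the 128-channel, 32x32 blocks from the network earlier in the thread (illustrative numbers only):

# Activation bytes written per image by one depthwise-separable block,
# compared with a hypothetical fused kernel that only writes the final output
c, h, w = 128, 32, 32
bytes_per_value = 4  # float32

depthwise_out = c * h * w * bytes_per_value  # intermediate written by the depthwise conv
pointwise_out = c * h * w * bytes_per_value  # final output written by the pointwise conv

print('two separate kernels write:', depthwise_out + pointwise_out, 'bytes')  # ~1.0 MB
print('a fused kernel would write:', pointwise_out, 'bytes')                  # ~0.5 MB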