Using optimised depthwise convolutions


#1

Hi all,

Following #3057 and #3265, I was excited to try out depthwise separable convolutions, but I’m having a hard time activating these optimised code paths. I’m currently getting no speedup over default convolutions.

Here are the two layer types that make up the bulk of my network:

# Depthwise
nn.Conv2d(in_chans, in_chans * k, kernel_size, groups=in_chans)

# Normal
nn.Conv2d(in_chans * k, out_chans, 1)

If I profile the network’s execution, I get the following (trimmed):

-------------------------  ------------  ------------  ------------
Name                           CPU time     CUDA time         Calls
-------------------------  ------------  ------------  ------------
conv2d                        130.737us    1016.190us            14
cudnn_convolution             160.005us     843.504us             8
thnn_conv_depthwise2d          58.475us    1246.438us             6

My concern is that the depthwise convolutions are being handled by THNN rather than THCUNN, where the new optimisations live. I also have a second network that replaces each normal convolution with multiple depthwise convolutions, and it runs much slower despite performing much less computation.
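
To put rough numbers on "much less computation", here is a back-of-the-envelope count of multiplies per output position for a 128 → 128 channel, 3x3 layer (my own illustrative arithmetic, not measurements):

# Multiplies per output position, 128 -> 128 channels, 3x3 kernel
standard  = 128 * 128 * 3 * 3              # 147,456
depthwise = 128 * 3 * 3                    # 1,152
pointwise = 128 * 128                      # 16,384
print(standard / (depthwise + pointwise))  # ~8.4x fewer multiplies for the separable pair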

Am I missing something obvious?

>>> torch.__version__
'0.4.0a0+82e995e'
>>> torch.version.cuda
'8.0.61'
>>> torch.backends.cudnn.version()
7005

#2

are you profiling correctly?

If you are timing with Python calls, you have to insert torch.cuda.synchronize() calls to get correct timings:

import time
import torch

torch.cuda.synchronize()
now = time.time()
# ... my cuda calls ...
torch.cuda.synchronize()
total_time = time.time() - now

#3

I’m using the following, from the docs:

with torch.cuda.profiler.profile():
    model(input_var) # Warmup CUDA memory allocator and profiler
    with torch.autograd.profiler.emit_nvtx():
        output = model(input_var)
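
For a quick check without nvprof, the built-in autograd profiler also reports CUDA times and shows which convolution kernel each call dispatched to. A minimal sketch, reusing model and input_var from above:

import torch

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    output = model(input_var)
print(prof)  # prints a table of CPU/CUDA times per function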

#4

hmmm, okay. what are your layer and input sizes?


#5

I have those two layers wrapped up to create a drop-in replacement for nn.Conv2d, and they're used in CIFAR-10- and ImageNet-scale networks. The particular network that I profiled looks like this:

Input: 256x3x32x32

Conv2d(  3, 128, kernel_size=3, padding=1)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)
        
MaxPool2d(kernel_size=3, stride=2)
# Channels now sized 15x15
        
Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)
        
MaxPool2d(kernel_size=3, stride=2)
# Channels now sized 7x7
        
Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)

Conv2d(128, 128, kernel_size=3, padding=1, groups=128)
Conv2d(128, 128, kernel_size=1, padding=0)
        
AvgPool2d(kernel_size=7)
# Channels now sized 1x1
        
Conv2d(128,  10, kernel_size=1, padding=0)

I’ve also tried the wrapped layers in copies of AlexNet, SqueezeNet and ResNet from the torchvision module. I haven’t profiled those. All trained to a reasonable accuracy, but each iteration was slower than their default counterparts.


#6

Here’s a simple test script I wrote which executes a single conv layer.

I had to drop back to PyTorch v0.3, so I can no longer see which backend is active, but I haven't yet found a configuration where grouped convolution is markedly faster than ordinary convolution. For example:

$ python3 group_conv.py --batch=256 --img-size=50 --in-channels=256 --out-channels=256 --kernel-size=3 --groups=True
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                    CPU time        CUDA time            Calls        CPU total       CUDA total
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
ConvForward             77.271us      16688.654us                1         77.271us      16688.654us

$ python3 group_conv.py --batch=256 --img-size=50 --in-channels=256 --out-channels=256 --kernel-size=3 --groups=False
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                    CPU time        CUDA time            Calls        CPU total       CUDA total
---------------  ---------------  ---------------  ---------------  ---------------  ---------------
ConvForward             72.029us      15207.838us                1         72.029us      15207.838us
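
For reference, the core of the script is roughly the following (reconstructed from the flags above, not the exact code; it uses the 0.3-era Variable API):

import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable

parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=256)
parser.add_argument('--img-size', type=int, default=50)
parser.add_argument('--in-channels', type=int, default=256)
parser.add_argument('--out-channels', type=int, default=256)
parser.add_argument('--kernel-size', type=int, default=3)
parser.add_argument('--groups', type=lambda s: s == 'True', default=False)
args = parser.parse_args()

# groups == in_channels gives a depthwise convolution; groups == 1 is a normal one
groups = args.in_channels if args.groups else 1
conv = nn.Conv2d(args.in_channels, args.out_channels, args.kernel_size,
                 groups=groups).cuda()
x = Variable(torch.randn(args.batch, args.in_channels,
                         args.img_size, args.img_size).cuda())

conv(x)  # warm up
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    conv(x)
print(prof)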

(Ziju Feng) #7

I’ve seen similar results when working with 3D ResNext models where group convolution is used extensively (in group size of 32). In 3D, the ResNext model actually only has around 50% of parameters (and thus flops) compared to the Resnet counterpart but the forward/backward pass is 25%/60% slower with batch size 1 and 3% and 10% slower with batch size 128 split to 4 GPUs.


#8

I’ve had some offline discussions about this. Here are my findings:

  • thnn_conv_depthwise2d, the internal function being called, does run on CUDA, but it does not use CuDNN.
  • CuDNN 7's implementation of grouped/depthwise convolution is up to 3x quicker in the forward pass, but always slower in the backward pass.
  • Deciding when to dispatch to CuDNN and when not to is very difficult to express in a maintainable way: it depends on many of the conv layer's parameters, on the CuDNN version in use, and probably on the GPU as well.

I think it’s probably a case of waiting until CUDA/CuDNN provide consistent benefits in more situations.


#9

NVIDIA said they will integrate dedicated kernels, similar to PyTorch's grouped convolution ones, into the next cuDNN version, and will choose them when they are faster.


(Ziju Feng) #10

Would the next version of cuDNN mean a major/minor version update like 8.0/7.1, or a patch-level update like 7.0.6? Did NVIDIA give an estimate of the release date? Thanks.


#11

Hello, based on the NVIDIA release notes, some depthwise separable convolution improvements have now made it into cuDNN:

https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_730.html#rel_730

Is this now integrated into PyTorch?


#12

@smth, @Ziju_Feng? I had a look at the source code and it seems the THNN implementation is still being used?


#13

@yaysummeriscoming I've checked the code as well, and yes, even in the depthwise case we are still dispatching to the THNN depthwise kernels: https://github.com/pytorch/pytorch/blob/51f1c4fea5e066e3cbb7a7251114de4381d2a926/aten/src/ATen/native/Convolution.cpp#L330

We haven't used the CuDNN implementation because those kernels aren't very optimized, according to an NVIDIA employee: https://github.com/pytorch/pytorch/issues/1708#issuecomment-384015108

Once the situation changes and CuDNN has faster, optimized depthwise convolution routines, we will switch to dispatching to CuDNN.

I've opened an issue to track this: https://github.com/pytorch/pytorch/issues/15513


(Nilay Shrivastava) #14

Even when I compare the performance of a depthwise separable conv3d against a normal conv3d without a GPU, I find that the latter still outperforms the former.
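
A minimal sketch of the kind of comparison I mean (sizes are illustrative, not from any particular model):

import time
import torch
import torch.nn as nn

x = torch.randn(1, 64, 16, 32, 32)

standard = nn.Conv3d(64, 64, kernel_size=3, padding=1)
separable = nn.Sequential(
    nn.Conv3d(64, 64, kernel_size=3, padding=1, groups=64),  # depthwise
    nn.Conv3d(64, 64, kernel_size=1),                        # pointwise
)

for name, layer in [('standard', standard), ('separable', separable)]:
    layer(x)  # warm up
    start = time.time()
    for _ in range(10):
        layer(x)
    print(name, '%.1f ms/iter' % ((time.time() - start) / 10 * 1e3))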