Performing convolutions in groups (but not grouped convolution!)

Not quite. This profile definitely suggests it should, but in the actual network I am running (a fairly simple 16-layer encoder/decoder), a batch with cudnn benchmark enabled still takes ~3.6 s on average after hundreds of batches have been run through. That is, even after the algorithm selection should have settled, it is still ~3.6 s / batch (forward + backward). My implementation drops that down to ~1.6 s / batch (forward + backward). I still can't account for the difference. This is with an input size of 3 x 20480 x 3 x 4 x 4 (B x P x C x H x W).
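
(For context, the "rolled-into-batch" variant from the benchmarks amounts to roughly the following; the layer parameters here are illustrative, not taken from the network:)

import torch
import torch.nn as nn

# B x P x C x H x W, with every patch sharing the same conv weights
x = torch.randn(3, 20480, 3, 4, 4, device='cuda')
conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1).cuda()

B, P, C, H, W = x.shape
out = conv(x.reshape(B * P, C, H, W))   # every patch becomes a batch element
out = out.reshape(B, P, -1, H, W)       # kernel 3 / padding 1 keeps the 4 x 4 size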

I have yet to try the unfold method in the network, so it's also possible that it will end up winning overall. I'm going to try that now.
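
(By the unfold method I mean, roughly, im2col via F.unfold followed by a single batched matmul against the shared weights; a sketch along these lines, with bias and dilation omitted:)

import torch
import torch.nn.functional as F

def patch_conv_via_unfold(x, weight, stride=1, padding=1):
    # x: (B, P, C, H, W); weight: (C_out, C, K, K)
    B, P, C, H, W = x.shape
    C_out, _, K, _ = weight.shape
    OH = (H + 2 * padding - K) // stride + 1
    OW = (W + 2 * padding - K) // stride + 1
    # im2col over every patch at once: (B*P, C*K*K, OH*OW)
    cols = F.unfold(x.reshape(B * P, C, H, W),
                    kernel_size=K, padding=padding, stride=stride)
    # one batched GEMM with the shared weights: (B*P, C_out, OH*OW)
    out = weight.reshape(C_out, C * K * K) @ cols
    return out.reshape(B, P, C_out, OH, OW)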

All in all, it seems odd that these methods run so much less efficiently over sequential calls in the full network, given that they all profile so similarly on an individual basis.

If you're feeling adventurous (or particularly bored), feel free to see what kind of speed you can get. I'd be curious to know if you discover a better way of doing it. I've pasted my network below (mind you, it currently uses PatchConvolution with my custom CUDA code, but that can easily be swapped out for any of the methods we've evaluated). I am training with Adam (lr=1e-4) and MSELoss, if that makes any difference (it shouldn't).

import torch
import torch.nn as nn
import torch.nn.functional as F

from patch_convolution import *
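
# NOTE: xavier_init is referenced below in Network.__init__ via self.apply(),
# but its definition was not included in this paste. A minimal placeholder:
def xavier_init(m):
    '''Xavier-initialize any module that carries a (>= 2-D) weight tensor.'''
    weight = getattr(m, 'weight', None)
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        nn.init.xavier_uniform_(weight)
    bias = getattr(m, 'bias', None)
    if isinstance(bias, torch.Tensor):
        nn.init.constant_(bias, 0.0)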


##-----------------------------------------------------------------------------
class Network(nn.Module):

    def __init__(self):
        # Call the super constructor
        super(Network, self).__init__()

        # 1x
        self.encoder0_0 = ConvELUBlock(in_channels=3,
                                       out_channels=32,
                                       kernel_size=3,
                                       padding=1)
        self.encoder0_1 = ConvELUBlock(in_channels=32,
                                       out_channels=64,
                                       kernel_size=3,
                                       padding=1)

        # 2x
        self.encoder1_0 = ConvELUBlock(in_channels=64,
                                       out_channels=128,
                                       kernel_size=3,
                                       padding=1,
                                       stride=2)
        self.encoder1_1 = ConvELUBlock(in_channels=128,
                                       out_channels=128,
                                       kernel_size=3,
                                       padding=1)
        self.encoder1_2 = ConvELUBlock(in_channels=128,
                                       out_channels=128,
                                       kernel_size=3,
                                       padding=1)

        # 4x
        self.encoder2_0 = ConvELUBlock(in_channels=128,
                                       out_channels=256,
                                       kernel_size=3,
                                       padding=1,
                                       stride=2)
        self.encoder2_1 = ConvELUBlock(in_channels=256,
                                       out_channels=256,
                                       kernel_size=3,
                                       padding=1)
        self.encoder2_2 = ConvELUBlock(in_channels=256,
                                       out_channels=256,
                                       kernel_size=3,
                                       padding=1)

        # 4x
        self.encoder3_0 = ConvELUBlock(in_channels=256,
                                       out_channels=512,
                                       kernel_size=3,
                                       padding=1)
        self.encoder3_1 = ConvELUBlock(in_channels=512,
                                       out_channels=512,
                                       kernel_size=3,
                                       padding=1)

        # 2x
        self.decoder0_0 = ConvTransposeELUBlock(in_channels=512,
                                                out_channels=256,
                                                kernel_size=4,
                                                padding=1,
                                                stride=2)
        self.decoder0_1 = ConvELUBlock(in_channels=256,
                                       out_channels=256,
                                       kernel_size=3,
                                       padding=1)

        # 1x
        self.decoder1_0 = ConvTransposeELUBlock(in_channels=256,
                                                out_channels=128,
                                                kernel_size=4,
                                                padding=1,
                                                stride=2)
        self.decoder1_1 = ConvELUBlock(in_channels=128,
                                       out_channels=128,
                                       kernel_size=3,
                                       padding=1)
        self.decoder1_2 = ConvELUBlock(in_channels=128,
                                       out_channels=64,
                                       kernel_size=1)

        # Prediction
        self.prediction = ConvELUBlock(in_channels=64,
                                       out_channels=1,
                                       kernel_size=3,
                                       padding=1)

        self.apply(xavier_init)

    def forward(self, x):

        encoder0_0_out = self.encoder0_0(x)
        encoder0_1_out = self.encoder0_1(encoder0_0_out)

        encoder1_0_out = self.encoder1_0(encoder0_1_out)
        encoder1_1_out = self.encoder1_1(encoder1_0_out)
        encoder1_2_out = self.encoder1_2(encoder1_1_out)

        encoder2_0_out = self.encoder2_0(encoder1_2_out)
        encoder2_1_out = self.encoder2_1(encoder2_0_out)
        encoder2_2_out = self.encoder2_2(encoder2_1_out)

        encoder3_0_out = self.encoder3_0(encoder2_2_out)
        encoder3_1_out = self.encoder3_1(encoder3_0_out)

        decoder0_0_out = self.decoder0_0(encoder3_1_out)
        decoder0_1_out = self.decoder0_1(decoder0_0_out)

        decoder1_0_out = self.decoder1_0(decoder0_1_out)
        decoder1_1_out = self.decoder1_1(decoder1_0_out)
        decoder1_2_out = self.decoder1_2(decoder1_1_out)

        pred = self.prediction(decoder1_2_out)

        return pred


##-----------------------------------------------------------------------------
class ConvELUBlock(nn.Module):
    '''Convenience module combining a patch convolution with an ELU activation'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1):
        super(ConvELUBlock, self).__init__()

        self.conv = PatchConvolution(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation)

    def forward(self, x):
        return F.elu(self.conv(x), inplace=True)


##-----------------------------------------------------------------------------
class ConvTransposeELUBlock(nn.Module):
    '''Convenience module combining a transposed patch convolution with an ELU activation'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1):
        super(ConvTransposeELUBlock, self).__init__()

        self.conv = TransposedPatchConvolution(in_channels=in_channels,
                                               out_channels=out_channels,
                                               kernel_size=kernel_size,
                                               stride=stride,
                                               padding=padding,
                                               dilation=dilation)

    def forward(self, x):
        return F.elu(self.conv(x), inplace=True)
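
(For completeness, a minimal sketch of the kind of training/timing loop described above; the target shape is an assumption based on the architecture, not my actual data:)

import time

net = Network().cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
criterion = nn.MSELoss()

inputs = torch.randn(3, 20480, 3, 4, 4, device='cuda')   # B x P x C x H x W
targets = torch.randn(3, 20480, 1, 4, 4, device='cuda')  # assumed prediction shape (1 channel, same patch size)

torch.backends.cudnn.benchmark = True   # toggled on/off for the comparison

torch.cuda.synchronize()
start = time.time()
for _ in range(100):
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
print('avg seconds / batch:', (time.time() - start) / 100)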

Oh,

You actually use transposed convolutions!
You might want to benchmark those as well; they tend to hit performance pitfalls for some inputs. Let me know what the numbers look like for the transposed versions.

Interesting… I hadn't considered that, since it's only two of the layers. I'll take a look there too.

I'm also curious how your benchmark runs are so much faster than mine on my machine. I've tried it repeatedly now and can't come close. I thought the GTX 1080 Ti was somewhat faster than the Quadro GP100.

What are the cuda / cudnn versions you’re using?

I’m using CUDA 10.1. I have whatever cudnn version shipped with the conda torch install from the PyTorch website.

>>> import torch
>>> torch.backends.cudnn.version()
7600

I'm still working out how to do transposed convolutions with fold/unfold, so I haven't tested those yet, but here's the benchmark of my custom code vs. nn.ConvTranspose2d. The transposed convolution is definitely 2-3x slower than the regular convolution.

######################### CUDNN FALSE


Various patch convolution methods:
  Rolled-into-batch avg. time: 0.411205 seconds
  Custom-patch-im2col avg. time: 0.109755 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.098622 seconds
######################### CUDNN TRUE


Various patch convolution methods:
  Rolled-into-batch avg. time: 0.167382 seconds
  Custom-patch-im2col avg. time: 0.110492 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.074942 seconds
######################### CUDNN TRUE BENCH


Various patch convolution methods:
  Rolled-into-batch avg. time: 0.145036 seconds
  Custom-patch-im2col avg. time: 0.110708 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.097421 seconds
######################### CUDNN TRUE BENCH


Various patch convolution methods:
  Rolled-into-batch avg. time: 0.130080 seconds
  Custom-patch-im2col avg. time: 0.111018 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.074998 seconds

@albanD: So it seems that PyTorch 1.3.0 has some big speed improvements. I just upgraded, recompiled my custom layers, and ran the benchmark again. It's substantially faster, and it measures up closer to your runs. Though I now get a cuda_runtime_error(700): illegal memory access on the backward pass through my custom layer, which is weird because it still passes my unit tests and runs in my network :confused:.
With CUDA 10.1, PyTorch 1.3.0, and cuDNN 7603:

Forward pass of standard conv:

######################### CUDNN FALSE


Various patch convolution methods:
  Using-unfold avg. time: 0.030516 seconds
  Rolled-into-batch avg. time: 1.412198 seconds
  Custom-patch-im2col avg. time: 0.017160 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.017581 seconds
######################### CUDNN TRUE


Various patch convolution methods:
  Using-unfold avg. time: 0.019811 seconds
  Rolled-into-batch avg. time: 0.015584 seconds
  Custom-patch-im2col avg. time: 0.017159 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.016928 seconds
######################### CUDNN TRUE BENCH


Various patch convolution methods:
  Using-unfold avg. time: 0.019813 seconds
  Rolled-into-batch avg. time: 0.018662 seconds
  Custom-patch-im2col avg. time: 0.017177 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.019624 seconds
######################### CUDNN TRUE BENCH


Various patch convolution methods:
  Using-unfold avg. time: 0.019854 seconds
  Rolled-into-batch avg. time: 0.015525 seconds
  Custom-patch-im2col avg. time: 0.017159 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.016979 seconds

Backward pass of standard conv:

######################### CUDNN FALSE


Various patch convolution methods:
  Using-unfold avg. time: 0.016350 seconds
  Rolled-into-batch avg. time: 0.289672 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.017821 seconds
######################### CUDNN TRUE


Various patch convolution methods:
  Using-unfold avg. time: 0.015732 seconds
  Rolled-into-batch avg. time: 0.046712 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.060034 seconds
######################### CUDNN TRUE BENCH


Various patch convolution methods:
  Using-unfold avg. time: 0.015789 seconds
  Rolled-into-batch avg. time: 0.026437 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.023651 seconds
######################### CUDNN TRUE BENCH


Various patch convolution methods:
  Using-unfold avg. time: 0.015781 seconds
  Rolled-into-batch avg. time: 0.021847 seconds


Compare to traditional convolution with B x C x H x W inputs with similar number (actually more) of elements:
  nn.Conv2d: 0.017086 seconds

Looks like unfold may be the way to go then. Still working on that transposed convolution though…
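
(For anyone following along, one way to express a transposed convolution with fold is the col2im identity: multiply by the transposed weight matrix, then scatter the columns back with F.fold. A rough sketch, ignoring output_padding, dilation, and groups:)

import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_transpose2d_via_fold(x, weight, bias=None, stride=1, padding=0):
    # weight uses the nn.ConvTranspose2d layout: (C_in, C_out, K, K)
    B, C_in, H, W = x.shape
    C_out, K = weight.shape[1], weight.shape[2]
    OH = (H - 1) * stride - 2 * padding + K
    OW = (W - 1) * stride - 2 * padding + K
    # (C_out*K*K, C_in) @ (B, C_in, H*W) -> (B, C_out*K*K, H*W)
    cols = weight.reshape(C_in, C_out * K * K).t() @ x.reshape(B, C_in, H * W)
    # col2im: scatter-add every column back into the output image
    out = F.fold(cols, output_size=(OH, OW), kernel_size=K,
                 padding=padding, stride=stride)
    if bias is not None:
        out = out + bias.view(1, -1, 1, 1)
    return out

# quick sanity check against the built-in layer
m = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1).cuda()
x = torch.randn(2, 256, 8, 8, device='cuda')
print((conv_transpose2d_via_fold(x, m.weight, m.bias, stride=2, padding=1) - m(x)).abs().max())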

Hi,

I would bet that the biggest speed improvement comes from the upgrade to cudnn 7603 :wink:

For the illegal memory access, you want to be careful to follow the strides of the input Tensors when working with them in your custom kernel.
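
(If the kernel indexes the data as if it were densely packed, the quick workaround on the Python side is to force contiguity before calling into it; a rough sketch, where patch_conv_cuda stands in for whatever your extension exposes:)

import torch

class PatchConvFunction(torch.autograd.Function):
    # patch_conv_cuda is a placeholder name for the compiled extension module
    @staticmethod
    def forward(ctx, input, weight):
        input = input.contiguous()          # kernels that ignore stride() need a dense layout
        ctx.save_for_backward(input, weight)
        return patch_conv_cuda.forward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # grad_output frequently arrives non-contiguous, which is a classic source
        # of illegal memory accesses in kernels that assume a packed layout
        grad_output = grad_output.contiguous()
        grad_input, grad_weight = patch_conv_cuda.backward(grad_output, input, weight)
        return grad_input, grad_weight

The alternative is to read the input's .stride() values inside the kernel and index with them instead of assuming a dense layout.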

Reporting back on the unfold method: it takes up too much memory when running in the full network (it exhausts my GPU). This makes sense, I guess, as it essentially creates those "column" matrices at each layer and then has to keep them in memory to perform backprop. When this is done on the backend with im2col/col2im, the column matrix does not need to be kept around for backprop. That column matrix is C*K^2 x OH*OW, which is a lot of memory compared to the C x IH x IW input image.
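
(To put rough numbers on that for one of the layers above, e.g. the 32-channel input to encoder0_1, with fp32 and a 3 x 3 kernel at padding 1:)

B, P = 3, 20480
C, H, W, K = 32, 4, 4, 3
input_elems  = B * P * C * H * W               # what the conv itself needs to keep
column_elems = B * P * (C * K * K) * (H * W)   # what F.unfold materializes (and autograd saves)
print(input_elems  * 4 / 2**20, 'MiB')         # 120.0 MiB
print(column_elems * 4 / 2**20, 'MiB')         # 1080.0 MiB, i.e. K^2 = 9x larger, per layer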

And, despite the nice benchmark with PyTorch v1.3, I’m seeing the “Rolled-into-batch” approach scale terribly with the number of patches in practice. I wish I had a good explanation as to why :confused:

Alas, I'm left with the version I wrote myself. It doesn't profile quite as well as the other methods in terms of speed, but it seems to scale better at the network level. I recognize that it probably shouldn't, though… It's just naive im2col + GEMM with a different striding.

Maybe @ngimel has some insights on the performance of conv with very small input image and very large batch size?

I'm not sure, but it looks like you want to share weights across different timesteps, or you want your convolutions to become recurrent convolutions. Check out how this is implemented here: https://github.com/TsukamotoShuchi/RCNN/blob/master/rcnnblock.py.