Make custom Conv2d layer efficient (w.r.t. speed and memory)

I was able to make my own Conv2d layer using ‘unfold’ and following this post. I need full access to the conv2d operations, as I will apply some custom functions at several points during the kernel operation. For this post, however, I am (for a start) only concerned with getting the most basic Conv2d layer up and running.

My resulting function looks as follows:

```python
import torch
import torch.nn.functional as F

# Vanilla Convolution
# Note: dilation and groups are accepted but not implemented (treated as 1).
def myconv2d(input, weight, bias=None, stride=(1,1), padding=(0,0), dilation=(1,1), groups=1):
    # pad image and get parameter sizes
    # F.pad lists the last dimension (width) first: (left, right, top, bottom)
    input = F.pad(input=input, pad=[padding[1], padding[1], padding[0], padding[0]], mode='constant', value=0)
    dh, dw = stride
    out_channels, in_channels, kh, kw = weight.shape
    batch_size = input.shape[0]

    # unfold input: [batch_size, in_channels, h_windows, w_windows, kh, kw]
    patches = input.unfold(2, kh, dh).unfold(3, kw, dw)
    h_windows = patches.shape[2]
    w_windows = patches.shape[3]
    # expand is only a view, but .contiguous() materializes the full 7-D tensor
    patches = patches.expand(out_channels, *patches.shape)
    patches = patches.permute(1, 3, 4, 0, 2, 5, 6)
    patches = patches.contiguous()
    # print(patches.shape)
    # > torch.Size([batch_size, h_windows, w_windows, out_channels, in_channels, kh, kw])

    # multiply by the filters and sum over kw, kh and in_channels
    patches = patches * weight
    patches = patches.sum(-1).sum(-1).sum(-1)

    # add bias
    if bias is not None:
        bias = bias.expand(batch_size, h_windows, w_windows, out_channels)
        patches = patches + bias
    patches = patches.permute(0, 3, 1, 2)
    # print(patches.shape)
    # > torch.Size([batch_size, out_channels, h_windows, w_windows])
    return patches
```

This function indeed produces the same results as F.conv2d (unless I’ve been missing something); however, it is (A) terribly slow and (B) consumes a lot of memory. Interestingly, the lines with patches * weight and patches.sum(-1) quickly cause memory errors for larger input sizes, and I wonder why they allocate new memory at all. Maybe the gradients?
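A back-of-the-envelope count suggests where the memory goes. In the version above, expand itself is only a view, but .contiguous() copies it into a real 7-D tensor of size batch x h_windows x w_windows x out_channels x in_channels x kh x kw, patches * weight then allocates another tensor of that size, and autograd will typically keep an operand of the multiplication alive for the backward pass. The helper below (a hypothetical name, assuming stride only, with no padding or dilation) compares that count with the im2col matrix a single unfold produces:

```python
def conv_intermediate_numel(batch, in_ch, out_ch, h, w, kh, kw, stride=1):
    """Element counts for an unpadded, undilated conv (hypothetical helper)."""
    h_out = (h - kh) // stride + 1
    w_out = (w - kw) // stride + 1
    # 7-D tensor materialized by .contiguous() (and again by patches * weight)
    seven_d = batch * h_out * w_out * out_ch * in_ch * kh * kw
    # im2col matrix from a single unfold: [batch, in_ch*kh*kw, h_out*w_out]
    im2col = batch * in_ch * kh * kw * h_out * w_out
    # the actual conv output: [batch, out_ch, h_out, w_out]
    output = batch * out_ch * h_out * w_out
    return seven_d, im2col, output

seven_d, im2col, output = conv_intermediate_numel(8, 64, 64, 64, 64, 3, 3)
print(f"7-D patches: {seven_d * 4 / 2**30:.2f} GiB in float32")
print(f"im2col:      {im2col * 4 / 2**20:.1f} MiB in float32")
```

For 8 inputs with 64 channels at 64 x 64, 64 output channels and 3 x 3 kernels, that is about 4.2 GiB versus about 68 MiB in float32; the ratio is exactly out_channels, because the 7-D tensor repeats the unfolded input once per filter.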

My question now is whether there are ways to improve the speed and memory use of my layer without resorting to full CUDA programming, while keeping all operations accessible. I have the feeling that a convolution layer should rely more on matrix multiplication than mine currently does, but I don’t really know where and how.

Also, while browsing this forum I came across “creating a torch.autograd.Function with a custom forward and backward method”. Might this be necessary to gain efficiency? After all, I am only using PyTorch functions, so maybe that would be overkill.
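Since a layer built only from differentiable PyTorch ops gets its backward pass derived by autograd, a custom torch.autograd.Function is not needed for correctness; it would mainly pay off if a cheaper backward could be written by hand. One way to check the automatic gradients is torch.autograd.gradcheck, which compares them with finite differences. The sketch below assumes a stride-1, unpadded unfold + matmul convolution (unfold_conv2d is a hypothetical name) and float64 inputs, as gradcheck recommends:

```python
import torch
import torch.nn.functional as F

def unfold_conv2d(x, weight, bias):
    # im2col + matmul convolution (stride 1, no padding) using only PyTorch ops
    out_ch, _, kh, kw = weight.shape
    h_out, w_out = x.shape[2] - kh + 1, x.shape[3] - kw + 1
    cols = F.unfold(x, kernel_size=(kh, kw))                    # [B, C*kh*kw, L]
    out = cols.transpose(1, 2) @ weight.reshape(out_ch, -1).t() + bias  # [B, L, out_ch]
    return out.transpose(1, 2).reshape(x.shape[0], out_ch, h_out, w_out)

torch.manual_seed(0)
x = torch.randn(2, 3, 6, 6, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(4, dtype=torch.double, requires_grad=True)

# raises if autograd's gradients disagree with finite differences
ok = torch.autograd.gradcheck(unfold_conv2d, (x, w, b))
print(ok)
```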

Any help is much appreciated!

Hi,

There are a few things you can do to speed up your code, but it is expected to be slower and much more memory-hungry than the conv2d module itself.
This is because conv2d can use other convolution implementations (based on FFT, for example), chosen depending on the input shape, to reduce memory use.
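For intuition about what an FFT-based implementation does, here is a minimal sketch of the textbook convolution-theorem approach using torch.fft, assuming stride 1, no padding, no dilation and groups=1. fft_conv2d is a hypothetical name, and this is not the algorithm cuDNN or PyTorch actually runs:

```python
import torch
import torch.nn.functional as F

def fft_conv2d(x, weight):
    """Valid (unpadded), stride-1 cross-correlation computed via FFT."""
    _, _, h, w = x.shape
    _, _, kh, kw = weight.shape
    fh, fw = h + kh - 1, w + kw - 1  # large enough that circular == linear convolution
    X = torch.fft.rfft2(x, s=(fh, fw))                            # [B, C, fh, fw//2+1]
    # F.conv2d computes a cross-correlation, so flip the kernel first
    K = torch.fft.rfft2(weight.flip(dims=(-2, -1)), s=(fh, fw))   # [O, C, fh, fw//2+1]
    # pointwise product of spectra, summed over input channels
    Y = torch.einsum("bchw,ochw->bohw", X, K)
    full = torch.fft.irfft2(Y, s=(fh, fw))                        # full linear convolution
    # keep only the positions where the kernel lies fully inside the input
    return full[..., kh - 1:h, kw - 1:w]

torch.manual_seed(0)
x = torch.randn(2, 3, 10, 9, dtype=torch.double)
w = torch.randn(4, 3, 3, 2, dtype=torch.double)
print(fft_conv2d(x, w).shape)
```

The cost of the spectral product barely depends on the kernel size, which is why FFT methods tend to pay off for large kernels, while small kernels such as 3 x 3 usually favour direct or im2col-style algorithms.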

I would advise using the functional version of unfold (torch.nn.functional.unfold) from here as well, so that you only need one call.
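As a quick check of what the one-call version returns, the sketch below (shapes chosen arbitrarily for illustration) compares torch.nn.functional.unfold with the chained Tensor.unfold from the first post after reordering. F.unfold yields [batch, in_channels*kh*kw, L], with the channel, kernel-row and kernel-column indices flattened in that order and L counting output positions row by row:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 7, 8)
kh, kw, dh, dw = 3, 2, 2, 1

# chained Tensor.unfold, as in the first post: [B, C, h_win, w_win, kh, kw]
patches = x.unfold(2, kh, dh).unfold(3, kw, dw)
B, C, h_win, w_win = patches.shape[:4]
# reorder to [B, C, kh, kw, h_win, w_win] and flatten into the im2col layout
chained = patches.permute(0, 1, 4, 5, 2, 3).reshape(B, C * kh * kw, h_win * w_win)

# a single call to the functional version
cols = F.unfold(x, kernel_size=(kh, kw), stride=(dh, dw))
print(cols.shape)  # torch.Size([2, 18, 21])
```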
You can rewrite:

    patches = patches * weight
    patches = patches.sum(-1).sum(-1).sum(-1)

to be more efficient: collapse the last 3 dimensions (of both patches and weight) and do a single matrix-matrix multiplication between patches and weight.transpose(-1, -2). That gives the result of these two lines with a single op.
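A minimal sketch of that rewrite, assuming patches is the 6-D unfolded input [batch, h_windows, w_windows, in_channels, kh, kw] without the out_channels copy (shapes are arbitrary):

```python
import torch

torch.manual_seed(0)
B, h_win, w_win, C, kh, kw, out_ch = 2, 4, 5, 3, 3, 3, 6
patches = torch.randn(B, h_win, w_win, C, kh, kw)   # unfolded input, no out_ch copy
weight = torch.randn(out_ch, C, kh, kw)

# original approach: broadcast to 7-D, multiply, then sum over kw, kh, C
slow = (patches.unsqueeze(3) * weight).sum(-1).sum(-1).sum(-1)   # [B, h_win, w_win, out_ch]

# collapse the last three dims of both tensors and do one matrix-matrix product
fast = patches.reshape(B, h_win, w_win, C * kh * kw) @ weight.reshape(out_ch, -1).transpose(-1, -2)

print((slow - fast).abs().max())
```

The matmul never materializes the 7-D product: the reduction over in_channels, kh and kw happens inside the matrix multiplication, so only the [B, h_win, w_win, out_ch] result is allocated.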

Thank you for your reply! Especially for the link and the mm info.
With this, the code got much more concise, so for anyone interested, here is the improved version.
I have yet to check how much more efficiently it runs, but I am rather hopeful:

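```python
import torch
import torch.nn.functional as F

# Vanilla Convolution (only groups=1 is supported)
def myconv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    # accept ints or (h, w) tuples, like F.conv2d
    stride, padding, dilation = [(v, v) if isinstance(v, int) else tuple(v)
                                 for v in (stride, padding, dilation)]
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, _, kh, kw = weight.shape
    # output size, same formula as nn.Conv2d
    out_h = (in_h + 2 * padding[0] - dilation[0] * (kh - 1) - 1) // stride[0] + 1
    out_w = (in_w + 2 * padding[1] - dilation[1] * (kw - 1) - 1) // stride[1] + 1

    # im2col: [batch_size, in_channels*kh*kw, out_h*out_w]
    inp_unf = F.unfold(input, kernel_size=(kh, kw), dilation=dilation, padding=padding, stride=stride)
    # flattened filters: [in_channels*kh*kw, out_channels]
    w_ = weight.reshape(out_channels, -1).t()

    # one batched matmul: [B, L, C*kh*kw] @ [C*kh*kw, out_channels] -> [B, L, out_channels]
    out_unf = inp_unf.transpose(1, 2).matmul(w_)
    if bias is not None:
        out_unf = out_unf + bias  # broadcasts over the last (out_channels) dim
    # back to [B, out_channels, L], then split L into (out_h, out_w)
    out = out_unf.transpose(1, 2).reshape(batch_size, out_channels, out_h, out_w)
    return out
```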

That should help a lot, especially for the backward pass 🙂
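One reason the backward pass benefits: autograd has to save an operand of patches * weight, which in the first version is the materialized 7-D tensor, while the unfold + matmul version saves roughly the im2col matrix. The sketch below counts saved elements with torch.autograd.graph.saved_tensors_hooks (available in recent PyTorch releases). naive_conv and matmul_conv are hypothetical stride-1, unpadded reimplementations of the two versions, and only the weight requires grad here:

```python
import torch
import torch.nn.functional as F

def saved_numel(fn, *args):
    """Count elements of the tensors autograd saves for backward while running fn."""
    total = 0
    def pack(t):
        nonlocal total
        total += t.numel()
        return t
    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        fn(*args)
    return total

def naive_conv(x, w):
    # expand + contiguous + multiply + sum, as in the first version
    O, _, kh, kw = w.shape
    p = x.unfold(2, kh, 1).unfold(3, kw, 1)                               # [B, C, h, w, kh, kw]
    p = p.expand(O, *p.shape).permute(1, 3, 4, 0, 2, 5, 6).contiguous()  # 7-D copy
    return (p * w).sum(-1).sum(-1).sum(-1).permute(0, 3, 1, 2)

def matmul_conv(x, w):
    # unfold + a single matmul, as in the improved version
    B, _, H, W = x.shape
    O, _, kh, kw = w.shape
    cols = F.unfold(x, kernel_size=(kh, kw))                              # [B, C*kh*kw, L]
    out = cols.transpose(1, 2) @ w.reshape(O, -1).transpose(-1, -2)
    return out.transpose(1, 2).reshape(B, O, H - kh + 1, W - kw + 1)

torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)
w = torch.randn(16, 8, 3, 3, requires_grad=True)
print(saved_numel(naive_conv, x, w), saved_numel(matmul_conv, x, w))
```

The first count should come out roughly out_channels times larger than the second, and the naive backward additionally builds 7-D temporaries when it turns the summed gradient back into a gradient for weight.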