How to implement a convolutional layer

Hello all,
For my research, I need to implement a convolution-like layer, i.e. something that slides over some input (assume 1D for simplicity), performs some operation on each window, and generates an output feature map. While this is structurally similar to regular convolution, the difference is the operation being performed: it's not regular convolution. I looked through the PyTorch code on GitHub, but it seems to rely on some sort of convolution primitive…
Is there a straightforward way to implement this?

Thanks


You could use unfold as described here to create the patches, which would be used in the convolution.
Instead of a multiplication and summation, you could apply your custom operation to each patch and reshape the output to the desired shape.
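
As a minimal sketch of the idea (the input shape and the max reduction are just illustrative placeholders for your data and your custom operation):

import torch

x = torch.randn(2, 3, 10)          # (N, C, L) toy input
patches = x.unfold(2, 4, 1)        # (N, C, windows, 4): sliding windows of length 4
out = patches.max(dim=-1).values   # custom per-patch op; max is just a placeholder
print(out.shape)                   # torch.Size([2, 3, 7])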


Thanks @ptrblck, that definitely seems to be what I'm looking for. However, I'm having a hard time understanding exactly how it works. For simplicity, assume my data is 1D of the form (N, C, L), where N is the batch size (100, for example), C is the number of channels (1 in this case), and L is the length of the series (say 10). I want to break this into windows, each of length 5. In my understanding, I would do this as:

data.unfold(2,5,1) # 2 because I'm unfolding the L dimension, 5 for the window size, and 1 for the step

However, the output I’m getting has a shape of

[100, 1, 6, 5]

I have no idea how to interpret this shape. On a related note, assuming I wanted to convolve this against a filter also of length 5, how would I go about it? If the filter itself were to have trainable weights (just like in traditional convolution), I would have to declare the filter as a tensor, right?

Sorry for all the questions, but I have difficulty understanding some of this stuff… Thanks again for the help. I really appreciate it.

Your current code snippet will create patches with a length of 5 samples and a stride of 1.
Since dim2 has 10 values, you will end up with (10 - 5) / 1 + 1 = 6 patches.
A simplified explanation:

data = torch.arange(10).view(1, 1, 10).float()
patches = data.unfold(2, 5, 1)
print(patches)
> tensor([[[[0., 1., 2., 3., 4.],
          [1., 2., 3., 4., 5.],
          [2., 3., 4., 5., 6.],
          [3., 4., 5., 6., 7.],
          [4., 5., 6., 7., 8.],
          [5., 6., 7., 8., 9.]]]])

You could define it as a tensor with requires_grad=True or directly as nn.Parameter(torch.randn(...)) and pass it to your optimizer.
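
For instance, a small sketch with a single trainable length-5 filter applied to the patches from your snippet (the name weight is mine, not an API requirement):

import torch
import torch.nn as nn

weight = nn.Parameter(torch.randn(5))   # trainable filter of length 5
x = torch.randn(100, 1, 10)             # (N, C, L) as in your example
patches = x.unfold(2, 5, 1)             # (100, 1, 6, 5)
out = (patches * weight).sum(dim=-1)    # multiply-and-sum per patch -> (100, 1, 6)
out.mean().backward()                   # gradients flow back into weight
print(weight.grad.shape)                # torch.Size([5])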

Hello @ptrblck,
Many thanks, this has cleared it up beautifully for me. I really appreciate it.

Thanks again!


Hello @ptrblck,
I feel really weird asking you this… but could I trouble you to provide an implementation of a simple 1D CNN using the unfold() method as you've described? I seem to have difficulty grasping how to handle multiple trainable filters…
I’d really appreciate it if you could do this for me…
Sorry for the trouble again.
Regards

Sure!
Here you can find a manual 2D conv implementation.
The 1D case should be straightforward by removing a spatial dimension:

import torch
import torch.nn as nn

batch_size = 2
channels = 5
h, w = 12, 12
image = torch.randn(batch_size, channels, h, w) # input image

kh, kw = 3, 3 # kernel size
dh, dw = 3, 3 # stride

# Create conv
conv = nn.Conv2d(5, 7, (kh, kw), stride=(dh, dw), bias=False)
filt = conv.weight

# Manual approach
patches = image.unfold(2, kh, dh).unfold(3, kw, dw)
print(patches.shape) # batch_size, channels, h_windows, w_windows, kh, kw

patches = patches.contiguous().view(batch_size, channels, -1, kh, kw)
print(patches.shape) # batch_size, channels, windows, kh, kw

nb_windows = patches.size(2)

# Now we have to shift the windows into the batch dimension.
# Maybe there is another way without .permute, but this should work
patches = patches.permute(0, 2, 1, 3, 4)
print(patches.shape) # batch_size, nb_windows, channels, kh, kw

# Calculate the conv operation manually
res = (patches.unsqueeze(2) * filt.unsqueeze(0).unsqueeze(1)).sum([3, 4, 5])
print(res.shape) # batch_size, output_pixels, out_channels
res = res.permute(0, 2, 1) # batch_size, out_channels, output_pixels
# assuming h = w
h = w = int(res.size(2)**0.5)
res = res.view(batch_size, -1, h, w)

# Module approach
out = conv(image)


print('max abs error ', (out - res).abs().max())
> max abs error  tensor(3.5763e-07, grad_fn=<MaxBackward1>)
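
For reference, here is my sketch of the 1D analogue (an adaptation of the snippet above, checked against nn.Conv1d, not code from the original post):

import torch
import torch.nn as nn

batch_size, channels, length = 2, 5, 12
signal = torch.randn(batch_size, channels, length)

k, d = 3, 3  # kernel size and stride
conv = nn.Conv1d(channels, 7, k, stride=d, bias=False)
filt = conv.weight                       # (out_channels, channels, k)

patches = signal.unfold(2, k, d)         # (batch, channels, windows, k)
patches = patches.permute(0, 2, 1, 3)    # (batch, windows, channels, k)
res = (patches.unsqueeze(2) * filt.unsqueeze(0).unsqueeze(1)).sum([3, 4])
res = res.permute(0, 2, 1)               # (batch, out_channels, windows)

print('max abs error ', (res - conv(signal)).abs().max())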

Hello @ptrblck,
Thank you so much for this! I was able to modify the code appropriately to suit my 1D case and it seems to work quite nicely. Many thanks for your help. Really appreciate it.

Best Regards

@ptrblck could you please explain how Conv2d is implemented internally in PyTorch? In TensorFlow, for example, it is implemented internally as a matrix multiplication.
Thanks

The convolution calls are dispatched in aten/src/ATen/native/Convolution.cpp. The native implementation uses the im2col transformation followed by a matrix multiplication, while the call can also be dispatched to e.g. cudnn and MKL, which can use different algorithms.
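
To illustrate the im2col idea (a sketch of the algorithm, not PyTorch's actual internal code), you can reproduce a convolution with F.unfold and a single matrix multiplication:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
conv = nn.Conv2d(3, 4, 3, bias=False)

cols = F.unfold(x, kernel_size=3)        # im2col: (1, 3*3*3, 36)
w = conv.weight.view(4, -1)              # flatten filters to (4, 27)
out = (w @ cols).view(1, 4, 6, 6)        # matmul, then reshape to feature map

print((out - conv(x)).abs().max())       # should be ~1e-6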
