Using conv2d with one channel results in wrong dimensions

Hi,

I want to convolve some images with some known kernels. I thought I could use torch.nn.functional.conv2d for this. So suppose I have some images with shape [2048, 1, 12, 12] (that is, a batch of 2048 images with 1 channel of size 12x12 pixels) and a kernel like kernel = torch.tensor([[-1, 1],[-1, 1]]), i.e. a 2x2 kernel. How can I convolve every image in my batch with this kernel?

I thought this could be done with torch.nn.functional.conv2d(images, kernel.expand(2048, 1, 2, 2)), but this results in a final output of size [2048, 2048, 11, 11]. If I read the docs correctly, this is because the first dimension of the weight tensor is the number of output channels, so every image is convolved with every filter.

This is not what I want. What I want is for every image to be convolved with the same kernel. Also, I want the output size to match the input size, as if I were calling scipy.signal.convolve2d(image, kernel, mode='same', boundary='symm') for each image in the batch (this is what I do now, but it is slow because it involves many copies from the GPU to the CPU).
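For reference, this is roughly the loop I am doing now (a sketch; images is my [2048, 1, 12, 12] tensor on the GPU):

import numpy as np
from scipy.signal import convolve2d

kernel = np.array([[-1, 1], [-1, 1]])
# move the whole batch to the CPU and convolve image by image -- slow
out = np.stack([
    convolve2d(img[0], kernel, mode='same', boundary='symm')
    for img in images.cpu().numpy()
])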

Can I do this operation with PyTorch’s functional.conv2d so as to avoid copying data from the GPU to the CPU? If so, what parameters do I need to pass to match my expected result?

Thanks in advance.

I think you just need to add padding. Given that your kernel size is 2x2 (an even kernel size), I think you can add manual padding on only one side of each dimension using torch.nn.ZeroPad2d.

If you just use padding=1, that will increase the output size to 13x13:

>>> import torch
>>> import torch.nn.functional as F
>>> a = torch.randn(2048, 1, 12, 12)
>>> a.shape
torch.Size([2048, 1, 12, 12])
>>> kernel = torch.randn(2048, 1, 2, 2)
>>> b = F.conv2d(a, kernel, padding=1)
>>> b.shape
torch.Size([2048, 2048, 13, 13])

But if you manually add padding, your output will be the same shape as input:

>>> import torch.nn as nn
>>> pad = nn.ZeroPad2d((1, 0, 1, 0))
>>> ap = pad(a)
>>> ap.shape
torch.Size([2048, 1, 13, 13])
>>> b = F.conv2d(ap, kernel)
>>> b.shape
torch.Size([2048, 2048, 12, 12])

Hi, thanks for your answer. I read up on how scipy.signal.convolve2d works with mode='same'. It does indeed work with padding, but the result is ‘centered’ with respect to the ‘full’ convolution, so I guess in my case that means using the ReflectionPad2d layer and following the same approach you suggest.
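Something like this, I suppose (untested):

>>> pad = nn.ReflectionPad2d((1, 0, 1, 0))
>>> ap = pad(a)
>>> ap.shape
torch.Size([2048, 1, 13, 13])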

The other question still remains: the final output is still [2048, 2048, 12, 12], because every image is convolved with every mask rather than each image with its corresponding mask, which would make the final output [2048, 1, 12, 12]. Do you also have a solution for that?

Thanks!

I think I might have a solution: if I do torch.nn.functional.conv2d(pad(data).view(1, 2048, 13, 13), kernel.expand(2048, 1, 2, 2), groups=2048), I get a result of shape [1, 2048, 12, 12], where most of the data is the same as convolving with convolve2d, except for the first row, and there are rows where every value is NaN.
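Spelled out, the grouped version looks roughly like this (a sketch; the .contiguous() call is my own addition, in case the stride-0 expanded weight is what produces the NaNs -- I have not confirmed that):

>>> data = torch.randn(2048, 1, 12, 12)
>>> kernel = torch.tensor([[-1., 1.], [-1., 1.]])
>>> pad = nn.ZeroPad2d((1, 0, 1, 0))
>>> # one group per image: each of the 2048 "channels" gets its own copy of the kernel
>>> weight = kernel.expand(2048, 1, 2, 2).contiguous()
>>> out = F.conv2d(pad(data).view(1, 2048, 13, 13), weight, groups=2048)
>>> out.shape
torch.Size([1, 2048, 12, 12])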

I’ll just look into that and keep you updated.

Thanks!

Yes, that’s right. I did not notice that the output size should be [2048, 1, 12, 12], i.e. specifically 1 output channel. The kernel (weight) shape expected by conv2d is [output_channels, input_channels, kernel_height, kernel_width].

So determining the kernel size:

  • Desired output channels: 1
  • Input channels: 1
  • Filter size: 2x2

Therefore, the kernel size for this should just be [1, 1, 2, 2]:

>>> a = torch.randn(2048, 1, 12, 12)
>>> pad = nn.ZeroPad2d((1, 0, 1, 0))
>>> ap = pad(a)
>>> ap.shape
torch.Size([2048, 1, 13, 13])
>>> kernel = torch.randn(1, 1, 2, 2)
>>> kernel.shape
torch.Size([1, 1, 2, 2])
>>> b = F.conv2d(ap, kernel)
>>> b.shape
torch.Size([2048, 1, 12, 12])

Hi, thank you very much! I was able to match the result of scipy.signal.convolve2d(img, mask, mode='same', boundary='symm') with the following PyTorch code:

import torch
import torch.nn.functional as F

imgs = torch.rand(2048, 1, 12, 12, dtype=torch.float32, device='cuda')
# F.conv2d does a cross-correlation, so be sure to flip the mask to do a true convolution
mask = torch.flip(torch.tensor([[-1, 1], [-1, 1]], dtype=torch.float32), [0, 1]).expand(1, 1, 2, 2).to('cuda')
pad = torch.nn.ReplicationPad2d((1, 0, 1, 0))
out = F.conv2d(pad(imgs), mask)
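For completeness, a quick way to check the match on a single image could look like this (a sketch; out is the result of the conv2d call above, and mask_np is the unflipped mask as a NumPy array):

import numpy as np
from scipy.signal import convolve2d

mask_np = np.array([[-1, 1], [-1, 1]], dtype=np.float32)
ref = convolve2d(imgs[0, 0].cpu().numpy(), mask_np, mode='same', boundary='symm')
print(np.allclose(out[0, 0].cpu().numpy(), ref))  # expect True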

Again, thank you very much!
