Hi, I am trying to replicate 2D convolution using the FFT. To test my understanding and my FFT convolution, I am applying a 5x5x3 Sobel filter to a 256x256x3 image.
Surprisingly, I do see some edges in the final convolved image with my FFT convolution, and I believe it's right. But I get an image full of noise when I use torch's functional conv2d (torch.nn.functional.conv2d).
Code:
The sobel filter:
kernel = np.array([[[1, 2, 0, -2, -1], [4, 8, 0, -8, -4], [6, 12, 0, -12, -6], [4, 8, 0, -8, -4], [1, 2, 0, -2, -1]],
                   [[1, 2, 0, -2, -1], [4, 8, 0, -8, -4], [6, 12, 0, -12, -6], [4, 8, 0, -8, -4], [1, 2, 0, -2, -1]],
                   [[1, 2, 0, -2, -1], [4, 8, 0, -8, -4], [6, 12, 0, -12, -6], [4, 8, 0, -8, -4], [1, 2, 0, -2, -1]]])
Functional Conv2D usage:
kernel = np.expand_dims(kernel, axis=0)
weight = nn.Parameter(torch.from_numpy(kernel.astype(np.int8)), requires_grad=False)
print(weight.shape)
# weight = torch.unsqueeze(weight, 0)
new_img = np.expand_dims(img, axis=0)
new_img = nn.Parameter(torch.from_numpy(new_img.astype(np.int8)), requires_grad=False)
print(new_img.shape)
# new_img = torch.unsqueeze(new_img, 0)
classic_2d_conv = F.conv2d(new_img, weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
classic_2d_conv = classic_2d_conv.cpu().detach().numpy()
classic_2d_conv = classic_2d_conv.reshape((256, 256, 1))
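For comparison, here is a minimal sketch of the same call with float32 tensors in (N, C, H, W) layout. It rests on assumptions that are not in the post: `img` is replaced by a random uint8 stand-in in (H, W, C) order, and the data is cast to float32 rather than int8. A signed 8-bit integer only holds values from -128 to 127, so if the real image holds values up to 255, the int8 cast would wrap the brighter pixels around to negative numbers.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Assumption: random uint8 stand-in for the real image, (H, W, C) layout.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)

# The 5x5 filter from the post is an outer product, repeated for 3 channels.
k2d = np.outer([1, 4, 6, 4, 1], [1, 2, 0, -2, -1])
kernel = np.stack([k2d, k2d, k2d])                       # (3, 5, 5)

# Cast to float32 (not int8) and add the leading batch / out-channel axis.
weight = torch.from_numpy(kernel.astype(np.float32)).unsqueeze(0)   # (1, 3, 5, 5)
x = torch.from_numpy(img.astype(np.float32))
x = x.permute(2, 0, 1).unsqueeze(0)                      # (1, 3, 256, 256)

# padding=2 keeps the 256x256 spatial size for a 5x5 kernel.
out = F.conv2d(x, weight, bias=None, stride=1, padding=2)  # (1, 1, 256, 256)
out_np = out.squeeze().numpy()                           # (256, 256)

# What an int8 cast does to a bright pixel value.
wrapped = np.array([200], dtype=np.uint8).astype(np.int8)[0]
```

In this sketch `wrapped` ends up negative, which is the kind of distortion an int8 cast would introduce before the convolution even runs.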
- Edit: The kernel and the image have their dimensions in the right order; I just did not paste that snippet here. I use (1, 3, 256, 256) for the image and (1, 3, 5, 5) for the kernel.
- I intend to perform a "same" convolution, hence the padding of 2.
- Using 1 output channel and 3 input channels is intentional.
- In the FFT convolution, I convolve each image channel with the corresponding filter channel, then sum the 3 outputs into 1, which gives the same behaviour as described in the point above.
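The per-channel FFT approach described above can be sketched as follows. This is an illustration only, not the code from the post: `fft_conv2d_same` is a hypothetical helper, and it assumes a (C, H, W) image, a (C, kh, kw) kernel, and zero padding at the borders.

```python
import numpy as np

def fft_conv2d_same(img, kernel):
    """Hypothetical helper: 'same'-size 2D convolution of a (C, H, W) image
    with a (C, kh, kw) kernel, summed over channels into one (H, W) map."""
    c, h, w = img.shape
    _, kh, kw = kernel.shape
    # Zero-pad to the full linear-convolution size to avoid circular wrap-around.
    fh, fw = h + kh - 1, w + kw - 1
    out = np.zeros((fh, fw), dtype=np.float64)
    for ch in range(c):
        f_img = np.fft.rfft2(img[ch].astype(np.float64), s=(fh, fw))
        f_ker = np.fft.rfft2(kernel[ch].astype(np.float64), s=(fh, fw))
        # Pointwise product in the frequency domain = convolution in space.
        out += np.fft.irfft2(f_img * f_ker, s=(fh, fw))
    # Crop the central H x W region to get the "same" output size.
    top, left = (kh - 1) // 2, (kw - 1) // 2
    return out[top:top + h, left:left + w]
```

One point to keep in mind when comparing the two results: the FFT product computes a true convolution, in which the kernel is flipped, whereas `F.conv2d` computes a cross-correlation and does not flip the kernel. For a horizontally antisymmetric kernel like the one here, the two outputs would therefore be expected to differ in sign even when both are otherwise correct.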
The output images:
I just cannot figure out why the functional API is producing a gibberish image. Please point out if I am using the functional API incorrectly or if I am missing something obvious. Thanks!
(Please let me know if you need more code snippets.)