Filtering an image in PyTorch

Hi
Does anyone know how to use a fixed (non-learned) filter as part of training in PyTorch? Specifically, I would like a filter that preserves high-frequency content. Thanks!

You just need to use the functional conv2d and pass your own filter as the weight:
https://pytorch.org/docs/stable/nn.functional.html?highlight=conv2d#torch.nn.functional.conv2d

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

Remember that PyTorch's conv2d computes cross-correlation, i.e. it does not flip the kernel.
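To make the cross-correlation point concrete, here is a small sketch (the kernel and impulse image are arbitrary illustrative values). Filtering an impulse with F.conv2d reproduces the kernel flipped in both directions, while flipping the kernel first with torch.flip gives a true convolution. For symmetric kernels, such as most high-pass filters, the two results are identical.

```python
import torch
import torch.nn.functional as F

# An asymmetric 3x3 kernel, so flipping it makes a visible difference
k = torch.arange(1., 10.).view(1, 1, 3, 3)

# Impulse image: a single 1 in the middle of a 5x5 grid, shape (batch, channels, H, W)
img = torch.zeros(1, 1, 5, 5)
img[0, 0, 2, 2] = 1.0

xcorr = F.conv2d(img, k, padding=1)                          # what F.conv2d computes
conv = F.conv2d(img, torch.flip(k, dims=[2, 3]), padding=1)  # true convolution

print(xcorr[0, 0, 1:4, 1:4])  # the kernel flipped vertically and horizontally
print(conv[0, 0, 1:4, 1:4])   # the kernel itself
```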

Thanks! And it won’t learn its parameters?
What about back-propagating through it?

Gradients will of course flow through the convolution during backprop, so the layers before it still get trained. But since the kernel is a fixed tensor, the filter itself won't be learned.
The functional is just a function that applies the convolution. An nn.Module such as nn.Conv2d, by contrast, is a class that holds learnable parameters; those are passed to the optimizer and hence learned.

Note: remember to define the filter as a plain tensor, NOT as an nn.Parameter. Otherwise it will be learned too.
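A quick way to check this, assuming a hypothetical 3x3 Laplacian-style high-pass kernel stored as a plain tensor: after backward(), the kernel has no gradient, while the input does, so gradients still reach whatever comes before the filter.

```python
import torch
import torch.nn.functional as F

# Hypothetical 3x3 high-pass (Laplacian-style) kernel, stored as a plain tensor
# (requires_grad=False), not as an nn.Parameter
kernel = torch.tensor([[0., -1., 0.],
                       [-1., 4., -1.],
                       [0., -1., 0.]]).view(1, 1, 3, 3)

x = torch.randn(2, 1, 16, 16, requires_grad=True)  # stand-in for a layer's input
out = F.conv2d(x, kernel, padding=1)
out.sum().backward()

print(kernel.grad)    # None: the fixed kernel receives no gradient
print(x.grad.shape)   # gradients still flow back to the input
```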


Take a look at the kornia.filters module:
https://kornia.readthedocs.io/en/latest/filters.html


Dear @JuanFMontesinos
I want to apply this high-pass filter to my images:
[image: the high-pass filter kernel]

My model is defined as below:

import torch.nn as nn

class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(2))
        # ... (rest of the model omitted)

Could you please help me do that with a fixed (non-learned) filter?

Can I apply the filter outside the model, for example in the DataLoader?
Thank you, I found the answer:

======
[Apply Filter as a part of preprocessing]

I tried this way and it’s working:

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from skimage import data

# loads gray-level camera image.
image = data.camera()
# reshaping image because it needs to have (batches, channels, height, width)
image = image.reshape(1, 1, image.shape[0], image.shape[1])
# Converting to Tensor
t_image = torch.as_tensor(image.astype(np.float32))
# Creating a 4x4 mean (low-pass) filter. The weight needs shape (out_channels, in_channels, filter height, filter width)
t_filter = torch.as_tensor(np.full((1, 1, 4, 4), 1.0 / 16.0, dtype=np.float32))
# Using F.conv2d to apply the filter
f_image = F.conv2d(t_image, t_filter)

plt.imshow(f_image.numpy().squeeze())
plt.show()
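The same idea can run inside the data pipeline. Below is a minimal sketch of a callable transform; the class name HighPassFilter and the 3x3 Laplacian-style kernel are illustrative assumptions (substitute your own kernel). It expects a float tensor of shape (C, H, W), such as the output of torchvision's transforms.ToTensor(), and uses groups=C so each channel is filtered independently.

```python
import torch
import torch.nn.functional as F

class HighPassFilter:
    """Hypothetical transform: apply a fixed 3x3 high-pass kernel to a float (C, H, W) tensor."""

    def __init__(self):
        # Assumed example kernel: its weights sum to zero, so flat regions map to 0
        kernel = torch.tensor([[-1., -1., -1.],
                               [-1.,  8., -1.],
                               [-1., -1., -1.]])
        self.kernel = kernel.view(1, 1, 3, 3)

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        c = img.shape[0]
        # One copy of the kernel per channel: weight shape (C, 1, 3, 3), used with groups=C
        weight = self.kernel.to(dtype=img.dtype, device=img.device).repeat(c, 1, 1, 1)
        # Reflect-pad by 1 pixel so the output keeps the input's height and width
        x = F.pad(img.unsqueeze(0), (1, 1, 1, 1), mode="reflect")
        return F.conv2d(x, weight, groups=c).squeeze(0)
```

It could then be chained after ToTensor() in a transforms.Compose, so the filter is applied once per sample and is never trained.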

Hi,
A layer is usually composed of a functional (a function) plus weights. Here you just need to define your weights manually and call the functional. A sketch of it would be:

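import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):

    def __init__(self, kernel):
        super(CNN, self).__init__()
        # kernel: your fixed 2D filter, shape (kH, kW), assumed odd and square.
        # Repeat it once per RGB channel -> weight shape (3, 1, kH, kW)
        weight = kernel.float().expand(3, 1, *kernel.shape).clone()
        # A buffer moves with .to(device) and is saved in the state_dict,
        # but it is not a Parameter, so the optimizer never updates it
        self.register_buffer('custom_conv_weight', weight, persistent=True)
        self.padding = kernel.shape[-1] // 2  # e.g. 3 for a 7x7 kernel
        self.post_conv = nn.Sequential(
            nn.BatchNorm2d(3),  # the fixed filter outputs 3 channels, not 96
            nn.ReLU(),
            nn.MaxPool2d(2))

    def forward(self, x):
        # groups=3 applies the same filter to each input channel independently
        x = F.conv2d(x, self.custom_conv_weight, stride=1, padding=self.padding, groups=3)
        x = self.post_conv(x)
        # ....
        return x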

Note that your kernel has a single channel while your input has 3 channels, so you need to repeat the kernel 3 times to run the same filter on all the image channels.
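Both ways of repeating a single-channel kernel for a 3-channel input are sketched below (the kernel values are a hypothetical example). The shape of the weight decides what you get: (3, 1, kH, kW) with groups=3 filters each channel separately, while (1, 3, kH, kW) filters each channel and sums the results into one output channel.

```python
import torch
import torch.nn.functional as F

# Hypothetical single-channel 3x3 high-pass kernel, shape (kH, kW)
kernel = torch.tensor([[-1., -1., -1.],
                       [-1.,  8., -1.],
                       [-1., -1., -1.]])
x = torch.rand(4, 3, 32, 32)  # batch of 4 RGB images

# Option 1: same filter on each channel separately -> 3 output channels
w_depthwise = kernel.expand(3, 1, 3, 3).contiguous()  # (out=3, in/groups=1, kH, kW)
y1 = F.conv2d(x, w_depthwise, padding=1, groups=3)

# Option 2: filter each channel, then sum them -> 1 output channel
w_summed = kernel.expand(1, 3, 3, 3).contiguous()     # (out=1, in=3, kH, kW)
y2 = F.conv2d(x, w_summed, padding=1)
```

Option 1 keeps the channel count, which matters if a BatchNorm2d or another layer after the filter expects a specific number of channels.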
