Is it possible to add a trainable filter after an autoencoder?

So I’m building a denoiser with an autoencoder. The idea is that, after the autoencoder and before computing the loss, I apply an empirical Wiener filter to a texture map of the image and add the result back to the autoencoder output (restoring ‘lost detail’). I’ve implemented this filter in PyTorch.

My first attempt added the filter to the end of my autoencoder’s forward function. I can train this network, and it backpropagates through the filter during training. However, when I print my network, the filter isn’t listed, and torchsummary doesn’t include it when counting parameters.

This makes me think that I’m only training the autoencoder, and that the filter applies the same fixed filtering every time without learning anything.

Is what I’m trying to do possible?

Below is my Autoencoder:


```python
import torch
import torch.nn as nn
from wiener_3d import wiener_3d  # my filter function (not shown)


class AutoEncoder(nn.Module):
    """Simple autoencoder implementation."""
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder: convolutions + 2x downsampling
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        # Decoder: convolutions + 2x upsampling, fed by skip connections
        self.block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        # Output head: back to a single channel
        self.block6 = nn.Sequential(
            nn.Conv2d(97, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x):
        # Encoder (block2 is reused, so these four steps share weights)
        pool1 = self.block1(x)
        pool2 = self.block2(pool1)
        pool3 = self.block2(pool2)
        pool4 = self.block2(pool3)
        pool5 = self.block2(pool4)
        # Decoder with skip connections (block5 is reused three times)
        upsample5 = self.block3(pool5)
        concat5 = torch.cat((upsample5, pool4), 1)
        upsample4 = self.block4(concat5)
        concat4 = torch.cat((upsample4, pool3), 1)
        upsample3 = self.block5(concat4)
        concat3 = torch.cat((upsample3, pool2), 1)
        upsample2 = self.block5(concat3)
        concat2 = torch.cat((upsample2, pool1), 1)
        upsample1 = self.block5(concat2)
        concat1 = torch.cat((upsample1, x), 1)
        output = self.block6(concat1)

        # Texture map: the detail removed from the input by the autoencoder
        t_map = x - output

        # Filter each image in the batch separately (could be handled inside
        # wiener_3d instead), building a new tensor rather than writing into t_map
        filtered = []
        for sample in t_map:                         # sample: (1, H, W)
            tensor = torch.squeeze(sample)           # squeeze for the Wiener input format
            tensor = wiener_3d(tensor, 0.05, 10)     # fixed std and block size
            filtered.append(torch.unsqueeze(tensor, 0))
        t_map = torch.stack(filtered)

        filtered_output = output + t_map
        return filtered_output
```

The for loop at the end applies the filter to each image in the batch separately. I realise this isn’t parallelised, so if anyone has ideas for that, I’d appreciate them. I can post the ‘wiener_3d()’ filter function if that helps; I just want to keep the post short.
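On parallelising the loop: if the filter can be expressed with batched tensor ops, the per-image loop disappears. As an illustration only, here is a local adaptive Wiener filter (the local mean/variance formula that `scipy.signal.wiener` is based on, with different border handling) written with `F.avg_pool2d`, so it processes a whole `(N, C, H, W)` batch in one call. This is a different filter from the block-based `wiener_3d`, whose code isn’t shown; `local_wiener` and its arguments are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def local_wiener(x, noise_var, kernel_size=5):
    """Adaptive local Wiener filter applied to a whole (N, C, H, W) batch at once.

    Hypothetical helper, not the poster's wiener_3d. Use an odd kernel_size so
    the output keeps the input's H and W.
    """
    pad = kernel_size // 2
    # Local mean and local mean of squares in a kernel_size x kernel_size window;
    # count_include_pad=False makes border windows average real pixels only.
    mean = F.avg_pool2d(x, kernel_size, stride=1, padding=pad, count_include_pad=False)
    sq_mean = F.avg_pool2d(x * x, kernel_size, stride=1, padding=pad, count_include_pad=False)
    var = (sq_mean - mean * mean).clamp(min=0)
    # Wiener gain: 0 where the local variance is at or below the noise level
    # (output falls back to the local mean), close to 1 where signal dominates.
    gain = (var - noise_var).clamp(min=0) / (var + 1e-8)
    return mean + gain * (x - mean)


# Usage with a trainable noise estimate: gradients reach noise_var.
noise_var = nn.Parameter(torch.tensor(0.05))
t_map = torch.randn(4, 1, 32, 32)  # e.g. x - output for a batch of 4
filtered = local_wiener(t_map, noise_var)
filtered.sum().backward()
print(filtered.shape, noise_var.grad)
```

Because every operation is batched, the result for each image is the same as filtering it on its own, just without the Python loop.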

I’ve tried to define a custom layer class with the filter inside it, but I got lost very quickly.

Any help would be greatly appreciated!

Based on your previous post, it seems your Wiener filter implementation doesn’t contain any trainable parameters.
Which parameters would you like to train, and how?
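For context on why the filter doesn’t appear in `print(model)` or torchsummary: only submodules and `nn.Parameter`s assigned as attributes of an `nn.Module` are registered. A plain function called inside `forward` still takes part in autograd, so gradients flow *through* it, but it owns nothing that can be trained. The sketch below contrasts the two; `scale_filter` and `ScaleFilter` are hypothetical stand-ins, not the actual Wiener filter.

```python
import torch
import torch.nn as nn


def scale_filter(x):
    # Plain function: differentiable, but it has no state of its own.
    return 0.5 * x


class WithFunction(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x):
        return scale_filter(self.conv(x))


class ScaleFilter(nn.Module):
    def __init__(self, init=0.5):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it with the module.
        self.scale = nn.Parameter(torch.tensor(init))

    def forward(self, x):
        return self.scale * x


class WithModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.filter = ScaleFilter()  # registered submodule: shows up in print(model)

    def forward(self, x):
        return self.filter(self.conv(x))


def count_params(model):
    return sum(p.numel() for p in model.parameters())


print(count_params(WithFunction()))  # conv weight (9) + bias (1) = 10
print(count_params(WithModule()))    # 10 + the filter's scale = 11
```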

Hi Peter,

I’ve since made the following changes to my network:

```python
import torch
import torch.nn as nn
from wiener_3d import wiener_3d


class WienerFilter(nn.Module):
    def __init__(self, param_a=0.05, param_b=10):
        super(WienerFilter, self).__init__()
        # Noise estimate: registered as a trainable parameter
        # (nn.Parameter already sets requires_grad=True)
        self.param_a = nn.Parameter(torch.tensor(param_a))
        # Block size: a plain int, so it stays fixed
        self.param_b = param_b

    def forward(self, x):
        # Filter each image in the batch and stack the results into a new
        # tensor instead of overwriting the input in place
        filtered = []
        for sample in x:                             # sample: (1, H, W)
            tensor = torch.squeeze(sample)
            tensor = wiener_3d(tensor, self.param_a, self.param_b)
            filtered.append(torch.unsqueeze(tensor, 0))
        return torch.stack(filtered)


class AutoEncoder(nn.Module):
    """Simple autoencoder implementation."""
    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder: convolutions + 2x downsampling
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        # Decoder: convolutions + 2x upsampling, fed by skip connections
        self.block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        # Output head: back to a single channel
        self.block6 = nn.Sequential(
            nn.Conv2d(97, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.LeakyReLU(0.1)
        )

        # Registered submodule, so param_a shows up in print(model) and model.parameters()
        self.wiener_filter = WienerFilter()

    def forward(self, x):
        # Encoder (block2 is reused, so these four steps share weights)
        pool1 = self.block1(x)
        pool2 = self.block2(pool1)
        pool3 = self.block2(pool2)
        pool4 = self.block2(pool3)
        pool5 = self.block2(pool4)
        # Decoder with skip connections (block5 is reused three times)
        upsample5 = self.block3(pool5)
        concat5 = torch.cat((upsample5, pool4), 1)
        upsample4 = self.block4(concat5)
        concat4 = torch.cat((upsample4, pool3), 1)
        upsample3 = self.block5(concat4)
        concat3 = torch.cat((upsample3, pool2), 1)
        upsample2 = self.block5(concat3)
        concat2 = torch.cat((upsample2, pool1), 1)
        upsample1 = self.block5(concat2)
        concat1 = torch.cat((upsample1, x), 1)
        output = self.block6(concat1)

        # Add the Wiener-filtered texture map back to the autoencoder output
        t_map = x - output
        filtered_output = output + self.wiener_filter(t_map)
        return filtered_output
```

I’ve made a Wiener filter module and call it as the last layer of my network. I set param_a, the noise estimate, as an nn.Parameter. Does this look correct to you?
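A self-contained, testable version of the same pattern is sketched below. Setting `requires_grad = True` on an `nn.Parameter` is unnecessary, since parameters require gradients by default, and building the output with `torch.stack` instead of writing into the input avoids in-place edits on a tensor autograd may need, and works for any batch size. `toy_filter` is a hypothetical differentiable stand-in for `wiener_3d` (whose code isn’t shown), and the `filter_fn` argument is added here only so the sketch runs on its own.

```python
import torch
import torch.nn as nn


def toy_filter(img, noise, block):
    # Hypothetical stand-in for wiener_3d(img, noise, block); any differentiable op works.
    return img / (1.0 + noise)


class WienerFilter(nn.Module):
    def __init__(self, param_a=0.05, param_b=10, filter_fn=toy_filter):
        super().__init__()
        self.param_a = nn.Parameter(torch.tensor(param_a))  # trainable noise estimate
        self.param_b = param_b      # plain int: fixed block size, not trained
        self.filter_fn = filter_fn  # plain function: not registered, has no parameters

    def forward(self, x):
        # Iterating over x yields (1, H, W) samples; filter each one and stack
        # the results into a new (N, 1, H, W) tensor.
        outs = [self.filter_fn(sample.squeeze(0), self.param_a, self.param_b).unsqueeze(0)
                for sample in x]
        return torch.stack(outs)


f = WienerFilter()
print([name for name, _ in f.named_parameters()])  # only 'param_a' is trainable
```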

Your responses on posts across the forum have helped me a lot, thank you.

The code looks generally good. Could you print the .grad attribute of your parameter inside the WienerFilter after a backward pass and check for valid gradients?
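One way to run that check over every parameter at once is sketched below: after `loss.backward()`, report each parameter’s gradient norm, where `None` means no gradient reached it. `TinyDenoiser` and `grad_report` are hypothetical names; with the real model, `grad_report(model)` should list `wiener_filter.param_a` alongside the conv weights.

```python
import torch
import torch.nn as nn


class TinyDenoiser(nn.Module):
    # Hypothetical stand-in: a conv layer plus a filter with one trainable parameter.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.param_a = nn.Parameter(torch.tensor(0.05))

    def forward(self, x):
        out = self.conv(x)
        return out + (x - out) / (1.0 + self.param_a)


def grad_report(model):
    # Map each parameter name to its gradient norm (None if no gradient reached it).
    return {name: (None if p.grad is None else p.grad.norm().item())
            for name, p in model.named_parameters()}


model = TinyDenoiser()
x, target = torch.randn(2, 1, 8, 8), torch.randn(2, 1, 8, 8)
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
print(grad_report(model))
```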

Hi Peter, sorry about the delay. After an epoch, the grad of ‘param_a’, my noise variance value, looks like this:


```python
>>> model.wiener_filter.param_a.grad
tensor(0.0045)
```

The initialisation sets the parameter to 0.05 and it now has a non-zero gradient, so it seems to be doing something! Thanks a lot.

Yeah, the gradient was calculated, so it looks good so far.
You could additionally check whether the actual value gets updated, but based on your code it should.
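A minimal way to check that the value itself moves: snapshot the parameter with `detach().clone()` before `optimizer.step()` and compare afterwards. The sketch assumes plain SGD; `TinyFilter` is a hypothetical stand-in for the Wiener module, and the target is chosen so the parameter has a reason to change.

```python
import torch
import torch.nn as nn


class TinyFilter(nn.Module):
    # Hypothetical stand-in for the WienerFilter with its trainable param_a.
    def __init__(self, param_a=0.05):
        super().__init__()
        self.param_a = nn.Parameter(torch.tensor(param_a))

    def forward(self, x):
        return x / (1.0 + self.param_a)


torch.manual_seed(0)
model = TinyFilter()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(2, 1, 8, 8)
target = 0.5 * x  # a target the parameter can move towards

before = model.param_a.detach().clone()  # snapshot, decoupled from the live parameter
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), target)
loss.backward()
optimizer.step()
after = model.param_a.detach().clone()

print(before.item(), after.item(), not torch.equal(before, after))
```

If the parameter were missing from `model.parameters()` (for example, a plain tensor attribute instead of an `nn.Parameter`), the optimizer would never update it and `before` and `after` would be equal.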