Wavelet pooling for CNNs

I am currently trying to use wavelet pooling in a CNN to perform a simple classification task.
I performed a 3-level decomposition and always continued processing on the low-frequency sub-band (LL).
While training the model, I observed that the training loss decreases quickly, but the gap between training and validation loss is huge, around 20-30%. Also, when I visualise the feature maps after pooling, I can see details only after the first pooling operation; after the subsequent wavelet pooling steps I don't see any details at all.

Below is the code for both the wavelet pooling and the model.

import torch
import torch.nn as nn
import numpy as np

def get_wav(in_channels, pool=True):
    """Build fixed (non-trainable) depthwise Haar wavelet layers.

    Each of the four 2x2 Haar kernels (LL, LH, HL, HH) is wrapped in a
    depthwise layer (``groups=in_channels``, kernel size 2, stride 2, no
    padding, no bias).

    With ``pool=True`` the layers are ``nn.Conv2d``. They decompose each
    channel independently and halve the spatial size (odd sizes are floored).
    With ``pool=False`` they are ``nn.ConvTranspose2d`` and double the
    spatial size.

    Args:
        in_channels: Number of input channels. It is also the number of
            output channels and of groups (one filter per channel).
        pool: ``True`` for decomposition (``Conv2d``); ``False`` for
            reconstruction (``ConvTranspose2d``).

    Returns:
        Tuple ``(LL, LH, HL, HH)`` of layers whose weights are frozen
        (``requires_grad=False``).
    """
    # 1-D orthonormal Haar filters, both shaped (1, 2):
    # low-pass [1, 1] / sqrt(2) and high-pass [-1, 1] / sqrt(2).
    harr_wav_L = 1 / np.sqrt(2) * np.ones((1, 2))
    harr_wav_H = 1 / np.sqrt(2) * np.ones((1, 2))
    harr_wav_H[0, 0] = -1 * harr_wav_H[0, 0]

    # Separable 2-D kernels: outer product (column x row) -> shape (2, 2).
    # The LL kernel is 0.5 everywhere, so a constant input comes out doubled.
    harr_wav_LL = np.transpose(harr_wav_L) * harr_wav_L
    harr_wav_LH = np.transpose(harr_wav_L) * harr_wav_H
    harr_wav_HL = np.transpose(harr_wav_H) * harr_wav_L
    harr_wav_HH = np.transpose(harr_wav_H) * harr_wav_H

    net = nn.Conv2d if pool else nn.ConvTranspose2d

    def _make_layer(kernel):
        """Wrap one 2x2 numpy kernel in a frozen depthwise stride-2 layer."""
        layer = net(in_channels, in_channels,
                    kernel_size=2, stride=2, padding=0, bias=False,
                    groups=in_channels)
        # Depthwise weight is (in_channels, 1, 2, 2) for both Conv2d and
        # ConvTranspose2d. Use repeat (a real copy per channel), not expand.
        # An expanded tensor aliases one memory location across channels,
        # so in-place writes (e.g. load_state_dict's copy_) fail on it.
        weight = torch.from_numpy(kernel).float().reshape(1, 1, 2, 2)
        with torch.no_grad():
            layer.weight.copy_(weight.repeat(in_channels, 1, 1, 1))
        layer.weight.requires_grad = False
        return layer

    # All four sub-bands get a layer; callers unpack exactly four values.
    return (_make_layer(harr_wav_LL), _make_layer(harr_wav_LH),
            _make_layer(harr_wav_HL), _make_layer(harr_wav_HH))

class WavePool(nn.Module):
    """Wavelet pooling that keeps only the LL (low-low) Haar sub-band.

    ``get_wav`` builds four fixed depthwise stride-2 filters. All four are
    registered as sub-modules, but ``forward`` applies only ``LL``.

    The LL kernel is 0.5 in every tap. Each output value is therefore the
    sum of a 2x2 window divided by 2, i.e. twice the window mean.
    """

    def __init__(self, in_channels):
        super().__init__()
        low_low, low_high, high_low, high_high = get_wav(in_channels)
        self.LL = low_low
        self.LH = low_high
        self.HL = high_low
        self.HH = high_high

    def forward(self, x):
        """Return the LL sub-band of ``x`` (spatial size halved)."""
        low_band = self.LL(x)
        return low_band
import torch
import torch.nn as nn
import torch.nn.functional as F
from blocks import ConvBlock, LinearAttentionBlock, ProjectorBlock
from initialize import *
from wavepool import *

class Model(nn.Module):
    """VGG-style image classifier that down-samples with Haar wavelet pooling.

    Three ``WavePool`` stages stand in for max-pooling. Each stage keeps only
    the LL (low-low) sub-band, so it halves the spatial size and passes on
    only the low-frequency content. Every 3x3 convolution is preceded by
    1-pixel reflection padding, so the convolutions keep the spatial size.

    Args:
        init: Name of the weight initialiser from ``initialize``. One of
            ``'kaimingNormal'``, ``'kaimingUniform'``, ``'xavierNormal'`` or
            ``'xavierUniform'``.
        num_classes: Number of output logits. Defaults to 7, the value that
            was previously hard-coded.

    Raises:
        NotImplementedError: If ``init`` is not one of the supported names.
    """

    def __init__(self, init='xavierUniform', num_classes=7):
        super().__init__()

        self.pad = nn.ReflectionPad2d(1)
        self.relu = nn.ReLU(inplace=True)

        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 0)
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
        self.pool1 = WavePool(64)

        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 0)
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
        self.pool2 = WavePool(128)

        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 0)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
        self.pool3 = WavePool(256)

        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 0)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.classify = nn.Linear(in_features=512, out_features=num_classes, bias=True)

        # Snapshot the frozen parameters (the fixed Haar filters inside the
        # WavePool layers) so they can be restored after initialisation.
        # NOTE(review): the weights_init_* helpers live in `initialize` (not
        # shown). If they re-initialise every nn.Conv2d, they overwrite the
        # Haar kernels with random values and wavelet pooling stops being a
        # wavelet transform. Restoring is a no-op if they leave frozen
        # weights untouched.
        frozen = {
            name: param.detach().clone()
            for name, param in self.named_parameters()
            if not param.requires_grad
        }

        # initialize
        if init == 'kaimingNormal':
            weights_init_kaimingNormal(self)
        elif init == 'kaimingUniform':
            weights_init_kaimingUniform(self)
        elif init == 'xavierNormal':
            weights_init_xavierNormal(self)
        elif init == 'xavierUniform':
            print("xavier uniform")
            weights_init_xavierUniform(self)
        else:
            raise NotImplementedError("Invalid type of initialization!")

        for name, param in self.named_parameters():
            if name in frozen:
                param.data = frozen[name]

    def forward(self, x):
        """Return class logits for an image batch.

        Args:
            x: Input batch. ``conv0`` expects 3 channels, so the shape is
                presumably (N, 3, H, W). Each of the three pooling stages
                halves H and W, with odd sizes floored. TODO confirm the
                minimum input size used in training.

        Returns:
            Logits of shape (N, num_classes).
        """
        out = self.conv0(x)

        # Stage 1: two 3x3 convs, then wavelet pooling (LL sub-band only).
        out = self.relu(self.conv1_1(self.pad(out)))
        out = self.relu(self.conv1_2(self.pad(out)))
        LL = self.pool1(out)

        # Stage 2.
        out = self.relu(self.conv2_1(self.pad(LL)))
        out = self.relu(self.conv2_2(self.pad(out)))
        LL = self.pool2(out)

        # Stage 3.
        out = self.relu(self.conv3_1(self.pad(LL)))
        out = self.relu(self.conv3_2(self.pad(out)))
        out = self.relu(self.conv3_3(self.pad(out)))
        out = self.relu(self.conv3_4(self.pad(out)))
        LL = self.pool3(out)

        # Head: one more conv, global average pool, linear classifier.
        x = self.relu(self.conv4_1(self.pad(LL)))

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classify(x)

        return x

Images

After first wavelet pooling
act_conv_block1AM11DIS

After second pooling
act_conv_block2AM11DIS

After third pooling
act_conv_block3AM11DIS

I am not sure what exactly the issue is. Are the filters I am creating incorrect? Please help: I have never worked with wavelets before. I am just experimenting, because I have seen many articles claiming that wavelets have many advantages in digital signal processing applications.

Any suggestions are welcome.

Hi, I’m far from an expert, but a simple thing to try is to inspect the output tensors just before and just after the second pooling. Call sum() on them so you can quickly spot whether they are all zeros. You can do that by using this shortened forward function instead:


def forward(self, x):
    """Debugging variant of ``Model.forward`` that stops after the second pooling.

    Returns two tensors so both can be inspected, e.g. with ``out.sum()`` /
    ``LL.sum()`` to check whether they are all zeros:

    - ``out``: the feature map just before the second wavelet pooling.
    - ``LL``: the pooled LL sub-band.
    """
    out = self.conv0(x)
    out = self.relu(self.conv1_1(self.pad(out)))
    out = self.relu(self.conv1_2(self.pad(out)))
    LL = self.pool1(out)
    out = self.relu(self.conv2_1(self.pad(LL)))
    out = self.relu(self.conv2_2(self.pad(out)))
    LL = self.pool2(out)
    return out, LL

Then analyse `out` and `LL`.
Hope it helps.