I am currently trying to use wavelet pooling in a CNN in order to perform a simple classification task.
I performed a 3-level decomposition and, at each level, carried out further processing only on the low-frequency sub-band (LL).
While training the model, I observed that the training loss decreases quickly, but the gap between the training and validation loss is huge, around 20-30%. Also, when I visualise the feature maps after pooling, I can see details only after the first pooling operation; after the later wavelet pooling stages I do not see any details.
Below is the code for both models and wavelet pooling.
import torch
import torch.nn as nn
import numpy as np
def get_wav(in_channels, pool=True):
    """Build the four 2x2 Haar wavelet filters as fixed depthwise conv layers.

    Each returned layer has ``in_channels`` input and output channels,
    ``groups=in_channels`` (one filter per channel), ``kernel_size=2``,
    ``stride=2``, no padding and no bias. The weights are set to the Haar
    sub-band kernel and excluded from gradient computation.

    Args:
        in_channels: Number of channels of the tensors the layers will see.
        pool: If true, build ``nn.Conv2d`` layers (analysis / downsampling);
            otherwise build ``nn.ConvTranspose2d`` layers (upsampling).

    Returns:
        A tuple ``(LL, LH, HL, HH)`` of layers, one per Haar sub-band.
    """
    # 1-D Haar low-pass and high-pass filters, both scaled by 1/sqrt(2).
    harr_wav_L = 1 / np.sqrt(2) * np.ones((1, 2))
    harr_wav_H = 1 / np.sqrt(2) * np.ones((1, 2))
    harr_wav_H[0, 0] = -1 * harr_wav_H[0, 0]

    # 2-D kernels are outer products (column vector * row vector broadcast).
    kernels = (
        np.transpose(harr_wav_L) * harr_wav_L,  # LL
        np.transpose(harr_wav_L) * harr_wav_H,  # LH
        np.transpose(harr_wav_H) * harr_wav_L,  # HL
        np.transpose(harr_wav_H) * harr_wav_H,  # HH
    )

    net = nn.Conv2d if pool else nn.ConvTranspose2d

    layers = []
    for kernel in kernels:
        layer = net(in_channels, in_channels,
                    kernel_size=2, stride=2, padding=0, bias=False,
                    groups=in_channels)
        layer.weight.requires_grad = False
        # (2, 2) -> (1, 1, 2, 2) -> (in_channels, 1, 2, 2).
        weight = torch.from_numpy(kernel).float().unsqueeze(0).unsqueeze(0)
        # expand() yields a stride-0 view whose elements share memory; clone
        # it so the layer owns real storage that is safe to write in place.
        layer.weight.data = weight.expand(in_channels, -1, -1, -1).clone()
        layers.append(layer)

    LL, LH, HL, HH = layers
    return LL, LH, HL, HH
class WavePool(nn.Module):
    """Pooling layer that returns only the LL output of the wavelet filters.

    All four layers returned by ``get_wav(in_channels)`` are stored as
    sub-modules (``LL``, ``LH``, ``HL``, ``HH``), but ``forward`` applies
    only ``LL``.
    """

    def __init__(self, in_channels):
        super().__init__()
        low_low, low_high, high_low, high_high = get_wav(in_channels)
        # Register in the same order as the sub-band tuple.
        self.LL = low_low
        self.LH = low_high
        self.HL = high_low
        self.HH = high_high

    def forward(self, x):
        """Apply the LL layer to ``x`` and return the result."""
        pooled = self.LL(x)
        return pooled
import torch
import torch.nn as nn
import torch.nn.functional as F
from blocks import ConvBlock, LinearAttentionBlock, ProjectorBlock
from initialize import *
from wavepool import *
class Model(nn.Module):
    """VGG-style classifier that downsamples with ``WavePool`` layers.

    The network takes a 3-channel input, runs three convolution stages
    (64, 128 and 256 channels), each followed by a ``WavePool``, then a
    512-channel convolution, global average pooling and a linear layer
    producing 7 outputs. Every 3x3 convolution is preceded by a 1-pixel
    reflection pad.

    Args:
        init: Name of the weight initialisation to apply. One of
            ``'kaimingNormal'``, ``'kaimingUniform'``, ``'xavierNormal'``
            or ``'xavierUniform'``.

    Raises:
        NotImplementedError: If ``init`` is not one of the names above.
    """

    def __init__(self, init='xavierUniform'):
        # Zero-argument super(): the previous call named a class
        # (AttnVGG_before) that is not this one.
        super().__init__()
        self.pad = nn.ReflectionPad2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.conv0 = nn.Conv2d(3, 3, 1, 1, 0)
        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 0)
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 0)
        self.pool1 = WavePool(64)
        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 0)
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 0)
        self.pool2 = WavePool(128)
        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 0)
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 0)
        self.conv3_4 = nn.Conv2d(256, 256, 3, 1, 0)
        self.pool3 = WavePool(256)
        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 0)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.classify = nn.Linear(in_features=512, out_features=7, bias=True)
        # initialize
        # NOTE(review): the weights_init_* helpers are defined outside this
        # file. If they re-initialise every nn.Conv2d in the module tree they
        # would also overwrite the fixed filters held by the WavePool layers
        # - verify against the initialize module.
        if init == 'kaimingNormal':
            weights_init_kaimingNormal(self)
        elif init == 'kaimingUniform':
            weights_init_kaimingUniform(self)
        elif init == 'xavierNormal':
            weights_init_xavierNormal(self)
        elif init == 'xavierUniform':
            print("xavier uniform")
            weights_init_xavierUniform(self)
        else:
            raise NotImplementedError("Invalid type of initialization!")

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``classify``."""
        out = self.conv0(x)

        # Stage 1: two 64-channel convolutions, then wavelet pooling.
        out = self.relu(self.conv1_1(self.pad(out)))
        out = self.relu(self.conv1_2(self.pad(out)))
        LL = self.pool1(out)

        # Stage 2: two 128-channel convolutions, then wavelet pooling.
        out = self.relu(self.conv2_1(self.pad(LL)))
        out = self.relu(self.conv2_2(self.pad(out)))
        LL = self.pool2(out)

        # Stage 3: four 256-channel convolutions, then wavelet pooling.
        out = self.relu(self.conv3_1(self.pad(LL)))
        out = self.relu(self.conv3_2(self.pad(out)))
        out = self.relu(self.conv3_3(self.pad(out)))
        out = self.relu(self.conv3_4(self.pad(out)))
        LL = self.pool3(out)

        # Head: 512-channel convolution, global average pool, classifier.
        x = self.relu(self.conv4_1(self.pad(LL)))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classify(x)
        return x
Images
After first wavelet pooling
After second pooling
After third pooling
I am not clear what the exact issue is. Is it that the filters I am creating are incorrect? Please help me with this, as I have never tried wavelets before; I am just experimenting because I have seen a lot of articles stating that they have many advantages in digital processing applications.
Please suggest.