Hi
I am trying to implement a wavelet-based pooling layer in one of my projects. For that, I am using ptwt, the wavelet toolbox written for torch.
The pooling layer is implemented as a torch.nn.Module subclass where I define the forward function as expected. I can run a forward pass and a backward pass, but the gradients I see do not make sense to me. Do note that the wavelet pooling layer has no learnable parameters.
Here is the code for wavelet pooling -
import torch
import torch.nn as nn
import pywt
import ptwt


class WaveletPooling(nn.Module):
    def __init__(self, wavelet):
        super(WaveletPooling, self).__init__()
        self.wavelet = wavelet

    def forward(self, x):
        bs = x.size()[0]
        FORWARD_OUTPUT_ = []
        # loop over the batch, as batched inputs are not supported
        for k in range(bs):
            # 2nd-order DWT
            coefficients = ptwt.wavedec2(torch.squeeze(x[k, :, :, :]),
                                         pywt.Wavelet(self.wavelet),
                                         level=2, mode="constant")
            # reconstruct from the approximation and 2nd-level detail
            # coefficients only, dropping the finest-level details
            forward_output_ = ptwt.waverec2([coefficients[0], coefficients[1]],
                                            pywt.Wavelet(self.wavelet))
            FORWARD_OUTPUT_.append(torch.squeeze(forward_output_, dim=1))
        FORWARD_OUTPUT_ = torch.stack(FORWARD_OUTPUT_)
        return FORWARD_OUTPUT_
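Dropping the finest-level detail coefficients before reconstruction is what produces the 2x downsampling. As a quick standalone shape check (a sketch; the exact output shape may vary slightly between ptwt versions):

x = torch.randn(3, 8, 8)  # channels treated as a batch, as in forward() above
coeffs = ptwt.wavedec2(x, pywt.Wavelet('haar'), level=2, mode='constant')
rec = ptwt.waverec2([coeffs[0], coeffs[1]], pywt.Wavelet('haar'))
print(rec.shape)  # half the input resolution, e.g. (3, 4, 4)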
Here is my dummy model, which just does wavelet pooling -
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.pool = WaveletPooling('haar')

    def forward(self, x):
        x = self.pool(x)
        return x
Here is my training driver code -
input_ = torch.ones(1, 3, 4, 4)
input_.requires_grad = True
print('Model input and output dim ... ')
print('Input shape ---' + str(input_.shape))
m = Model()
m.register_full_backward_hook(hook_fn)  # hook_fn is sketched below
output_ = m(input_.float())
print('Output shape ---' + str(output_.shape))
print('----------------------------------')
print('Model input and output ...')
print(input_)
print(output_)
print('----------------------------------')
(output_.mean()).backward()
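For completeness, hook_fn (defined before the register call above) is just a print; a minimal version along these lines is enough to see the gradients -

def hook_fn(module, grad_input, grad_output):
    # grad_output: gradient w.r.t. the module's output (flowing in from above)
    # grad_input: gradient w.r.t. the module's input (flowing back out)
    print('grad_output:', grad_output)
    print('grad_input:', grad_input)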
On analyzing the gradients from the backward hook, I see the following -
The gradient flowing into the pooling layer from above (grad_output, of shape 1x3x2x2) is the following -
tensor([[[[0.0833, 0.0833],
          [0.0833, 0.0833]],

         [[0.0833, 0.0833],
          [0.0833, 0.0833]],

         [[0.0833, 0.0833],
          [0.0833, 0.0833]]]])
This I think makes sense; it is basically the gradient of the mean w.r.t. each of its elements, which is 1/12 = 0.0833 (the normalizing factor of the mean, since the output has 3x2x2 = 12 entries).
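This part is easy to verify in isolation:

y = torch.ones(1, 3, 2, 2, requires_grad=True)
y.mean().backward()
print(y.grad)  # every entry is 1/12 ~= 0.0833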
The thing which I don't quite understand is the gradient the pooling layer passes back to its input (grad_input, of shape 1x3x4x4), where every entry is 0.0417, i.e. 1/24 -
tensor([[[[0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417]],

         [[0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417]],

         [[0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417]]]])
Is this happening because autograd detects a dependency between every entry of the pooling layer's input and every entry of its output?
Thanks!
Update 1 - For context, I am referring to this paper from ICLR 2018: https://openreview.net/pdf?id=rkhlb8lCZ