Gradients through pooling layers

Hi

I am trying to implement a wavelet-based pooling layer for one of my projects. For that, I am using ptwt, the PyTorch wavelet toolbox.

The pooling layer is implemented as a torch.nn.Module subclass where I define the forward function as usual. I can run a forward and a backward pass, but the gradients I see do not make sense to me. Note that the wavelet pooling layer has no learnable parameters.

Here is the code for the wavelet pooling layer -

import torch
import torch.nn as nn
import ptwt
import pywt

class WaveletPooling(nn.Module):
    def __init__(self, wavelet):
        super(WaveletPooling, self).__init__()
        self.wavelet = wavelet

    def forward(self, x):
        bs = x.size()[0]
        FORWARD_OUTPUT_ = []

        # loop over the batch dimension, as batching is not supported
        for k in range(bs):
            # 2nd-order DWT of one sample (channels x height x width)
            coefficients = ptwt.wavedec2(torch.squeeze(x[k, :, :, :]), pywt.Wavelet(self.wavelet),
                                         level=2, mode="constant")

            # reconstruct from the level-2 approximation and level-2 details only,
            # which halves the spatial resolution
            forward_output_ = ptwt.waverec2([coefficients[0], coefficients[1]], pywt.Wavelet(self.wavelet))
            FORWARD_OUTPUT_.append(torch.squeeze(forward_output_, dim=1))

        FORWARD_OUTPUT_ = torch.stack(FORWARD_OUTPUT_)
        return FORWARD_OUTPUT_

Here is my dummy model, which just applies the wavelet pooling -

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.pool = WaveletPooling('haar')
        
    def forward(self, x):
        x = self.pool(x)
        return x

Here is my driver code -

input_ = torch.ones(1,3,4,4)
input_.requires_grad = True
print('Model input and output dim ... ')
print('Input shape ---' + str(input_.shape))

m = Model()
# m.register_full_backward_hook(hook_fn)

output_ = m(input_.float())
print('Output shape ---' + str(output_.shape))
print('----------------------------------')
print('Model input and output ...')
print(input_)
print(output_)
print("----------------------------------")

(output_.mean()).backward()
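
hook_fn is not shown above; for the analysis below I uncomment the registration and use a minimal print hook along these lines -

def hook_fn(module, grad_input, grad_output):
    # grad_output: gradient of the loss w.r.t. the pooling layer's output (shape 1x3x2x2 here)
    # grad_input:  gradient of the loss w.r.t. the pooling layer's input  (shape 1x3x4x4 here)
    print('grad_output:', grad_output)
    print('grad_input:', grad_input)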

On analyzing the gradients from the backward hook, I see the following.

The incoming gradient at the pooling layer (grad_output, shape 1x3x2x2) is -

tensor([[[[0.0833, 0.0833],
          [0.0833, 0.0833]],

         [[0.0833, 0.0833],
          [0.0833, 0.0833]],

         [[0.0833, 0.0833],
          [0.0833, 0.0833]]]])

This makes sense to me: it is the gradient of the mean with respect to each element of the pooling output, which is 1/12 (the output has 1x3x2x2 = 12 elements).
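
As a quick standalone check of that 1/12 factor (mean only, no pooling):

check = torch.ones(1, 3, 2, 2, requires_grad=True)
check.mean().backward()
print(check.grad)   # every entry is 1/12 ≈ 0.0833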

The thing I don’t quite understand is why the gradient at the pooling layer's input (grad_input, shape 1x3x4x4) is -

tensor([[[[0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417]],

         [[0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417]],

         [[0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417],
          [0.0417, 0.0417, 0.0417, 0.0417]]]])

Is this happening because autograd detects a dependency between every entry of the pooling layer's input and its output?

Thanks!

Update 1 - I am referring to this ICLR 2018 paper - https://openreview.net/pdf?id=rkhlb8lCZ

You could take a look at GitHub - Fraunhofer-SCAI/wavelet_pooling: Adaptive wavelet pooling for CNN in PyTorch, AISTATS 2021. It’s based on an older version of ptwt, but it should be possible to make it work with a few tweaks.

I think your gradients could be identical, because you are pooling an input of just ones. Perhaps you could try a random input. I’d expect the picture to change.
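
For example, just swapping the driver input for something like -

input_ = torch.randn(1, 3, 4, 4, requires_grad=True)   # random values instead of all ones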

I personally would not recommend wavelet pooling for anything other than a research project. It works, but it’s based on strided convolutions internally.

I would recommend always comparing to plain strided convolutions as well. See e.g. [1412.6806] Striving for Simplicity: The All Convolutional Net.
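
A minimal strided-convolution baseline for that comparison could be as simple as this (just a sketch; channel counts chosen to match the 1x3x4x4 example above) -

conv_pool = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=2, stride=2)
print(conv_pool(torch.randn(1, 3, 4, 4)).shape)   # torch.Size([1, 3, 2, 2])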