Problems when implementing pooling layer

I’m trying to implement stochastic pooling with PyTorch. Somehow my code doesn’t compute the gradients correctly and the model won’t learn when propagating backward. The forward pass gives the correct output, but the gradients seem to vanish somewhere.

import torch as tc
import torch.nn as nn
import torch.nn.functional as tcf
from torch.distributions.one_hot_categorical import OneHotCategorical as ohc

class RandomPool2d(nn.Module):
    def __init__(self, kernel_size=2):
        super().__init__()
        # accept an int or a (kh, kw) pair
        self.ksize = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        self.p = self.ksize[0] * self.ksize[1]  # self.p is the area of the pooling window

    def forward(self, inp: tc.Tensor):
        if self.training:
            bsz, chn, h, w = inp.size()
            # extract non-overlapping windows: (bsz, chn * p, L)
            inp = tcf.unfold(inp, self.ksize, stride=self.ksize)
            # one window per row: (bsz * chn * L, p)
            inp = inp.view(bsz, chn, self.p, -1).transpose(2, 3).flatten(0, 2)
            # mask = ohc(inp.softmax(1)).sample().to(dtype=tc.bool)  # another way to create a one-hot representation of the indices drawn
            mask = tc.zeros_like(inp, dtype=tc.bool)
            mask.scatter_(1, tc.multinomial(inp.softmax(1), 1), True)  # one-hot representation of the index drawn per window
            inp = inp.masked_select(mask)  # use the indices to select from the original tensor
            return inp.view(bsz, chn, h // self.ksize[0], w // self.ksize[1])
        else:
            return tcf.max_pool2d(inp, self.ksize)

This layer is used exactly like the other 2d pooling layers, except that the stride is always equal to the kernel size.
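For what it’s worth, here is a minimal, self-contained sanity check (my own sketch, not part of the original post) showing that gradients do flow through the `unfold` → `multinomial` mask → `masked_select` path. The sampling itself (`multinomial` on the softmax) is non-differentiable, but it is only used to pick indices; the selected values keep their grad history:

```python
import torch as tc
import torch.nn as nn
import torch.nn.functional as tcf

class RandomPool2d(nn.Module):
    def __init__(self, kernel_size=2):
        super().__init__()
        self.ksize = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        self.p = self.ksize[0] * self.ksize[1]  # window area

    def forward(self, inp):
        if self.training:
            bsz, chn, h, w = inp.size()
            cols = tcf.unfold(inp, self.ksize, stride=self.ksize)
            cols = cols.view(bsz, chn, self.p, -1).transpose(2, 3).flatten(0, 2)
            mask = tc.zeros_like(cols, dtype=tc.bool)
            mask.scatter_(1, tc.multinomial(cols.softmax(1), 1), True)
            out = cols.masked_select(mask)  # differentiable selection
            return out.view(bsz, chn, h // self.ksize[0], w // self.ksize[1])
        return tcf.max_pool2d(inp, self.ksize)

x = tc.randn(2, 3, 8, 8, requires_grad=True)
pool = RandomPool2d(2)
pool.train()
y = pool(x)
y.sum().backward()
# exactly one element per 2x2 window should receive a gradient of 1
print(y.shape, x.grad.abs().sum().item() > 0)
```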
I am quite new to PyTorch and am unable to figure out why the gradients are not computed correctly :frowning:
Can anyone help me solve this problem?

Hi,

What do you mean by “gradients are not computed correctly”? What did you check to see that? What did you expect?

Thanks for your reply. The behavior was that my model wasn’t learning anything: the loss wasn’t changing much and the accuracy (measured per epoch) wasn’t changing at all. That’s why I thought the gradients weren’t computed correctly.
However, I have already solved it; it was due to a problem in another part of my model. In fact, it was because I used too many stochastic pooling layers, so the model was learning very slowly.
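As a rough illustration of why stacking several such layers slows learning (my own sketch, not from the thread): like max pooling, each stochastic 2×2 pool routes the gradient to only one of the four inputs per window, so with n stacked layers only about (1/4)**n of the input positions receive any gradient on a given step. The same one-winner-per-window backward can be demonstrated with `max_pool2d`:

```python
import torch as tc
import torch.nn.functional as tcf

x = tc.randn(1, 1, 16, 16, requires_grad=True)
y = x
for _ in range(2):          # two stacked 2x2 pooling layers
    y = tcf.max_pool2d(y, 2)
y.sum().backward()

# fraction of input positions that received any gradient
frac = (x.grad != 0).float().mean().item()
print(frac)  # (1/4)**2 = 0.0625: one winner per 4x4 region
```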

Thank you so much :smiley: