I’m trying to implement Stochastic pooling with pytorch. I’m implementing Stochastic max pooling and somehow my code doesn’t compute the gradients correctly and won’t work when propagating backward. The code can give the correct output but the gradients seem to be removed somewhere.
from torch.distributions.one_hot_categorical import OneHotCategorical as ohc
class RandomPool2d(nn.Module):
def __init__(self, kernal_size=2):
super().__init__()
self.ksize = tuple(kernal_size) if type(kernal_size) == int else kernal_size
self.p = self.ksize[0] * self.ksize[1]#self.p is the area of kernel size
def forward(self, inp: tc.Tensor):
if self.training:
bsz, chn, h, w = inp.size()
inp: tc.Tensor = tcf.unfold(inp, self.ksize, stride=self.ksize)
inp = inp.view(bsz, chn, self.p, -1).transpose_(2, 3).flatten(0, 2)
# mask=ohc(inp.softmax(1)).sample().to(dtype=tc.uint8) #this is just another way to create a one-hot represation of the indexes drawn
mask: tc.Tensor = tc.zeros_like(inp, dtype=tc.uint8)
mask.scatter_(1, tc.multinomial(inp.softmax(1), 1), 1) #generate a one-hot representation of the index drawn
inp = inp.masked_select(mask)# use the indexs to choose from orginial tensor
inp = inp.view(bsz, chn, -1, 1).transpose_(2, 3)
return inp.view(bsz, chn, h // self.ksize[0], w // self.ksize[1])
else:
return tcf.max_pool2d(inp, self.ksize)
The usage of this code is exactly like other 2d pooling layers except the stride are always equal to kernel size.
I am quite new with pytorch and is unable to find out why is the gradients not computed correctly
Can anyone help me solve this problem?