Variable not backpropagating

Hi guys, I’m implementing a custom pooling function that has a trainable tensor. So far I’m wrapping the tensor in a Variable with requires_grad=True. However, the variable does not change as the model trains: it stays stuck at its initial value of 3. What am I doing wrong?

Here is the code:

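```python
import torch
import torch.nn as nn


class GeneralizedMeanPooling(nn.Module):
    def __init__(self):
        super(GeneralizedMeanPooling, self).__init__()
        # Trainable exponent p, initialised to 3 and wrapped in a Variable
        self.p = torch.autograd.Variable(torch.cuda.FloatTensor([3]), requires_grad=True)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.view((B, C, -1))                                    # (B, C, H*W)
        x = torch.sum(x ** self.p.expand_as(x), dim=2) / (H * W)  # mean of x^p per channel
        x = x ** (1 / self.p)                                     # (mean of x^p)^(1/p)
        return x
```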



Variables are deprecated (since PyTorch 0.4 they are merged into Tensor), so you don’t need them; you can create the tensor with requires_grad=True directly, e.g. torch.tensor([3.0], device="cuda", requires_grad=True).
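A minimal sketch of creating a tracked tensor without Variable, using the torch.tensor factory function (the random input x is purely illustrative, and the device falls back to CPU when no GPU is present). Note that p here is a free-standing tensor: autograd fills p.grad, but nothing updates it unless you pass it to an optimizer yourself, e.g. torch.optim.SGD([p], lr=0.1).

```python
import torch

# Use the GPU when present, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Factory function: a leaf tensor tracked by autograd, no Variable wrapper needed.
p = torch.tensor([3.0], device=device, requires_grad=True)

x = torch.rand(4, device=device) + 0.1  # strictly positive illustrative inputs
y = (x ** p).sum()
y.backward()

# p.grad now holds d/dp sum(x^p) = sum(x^p * log(x))
print(p.grad)
```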
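That alone won’t make p train, though. In an nn.Module, for something to be recognized as a parameter (and thus be returned by mod.parameters(), which is usually what you pass to the optimizer), it needs to be of type nn.Parameter. A plain tensor attribute is never handed to the optimizer, so it stays at its initial value. You should therefore do: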

```python
self.p = torch.nn.Parameter(torch.cuda.FloatTensor([3]))
```
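To make the difference concrete, here is a small sketch comparing a plain tensor attribute with an nn.Parameter. The two toy modules and the squared loss are hypothetical, for illustration only. Only the Parameter shows up in parameters(), so only it is stepped by an optimizer built from model.parameters().

```python
import torch
import torch.nn as nn


class WithPlainTensor(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: autograd tracks it, but nn.Module does not register it.
        self.p = torch.tensor([3.0], requires_grad=True)


class WithParameter(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter attributes are registered: they appear in parameters()
        # and state_dict(), and are moved by .to(device) / .cuda().
        self.p = nn.Parameter(torch.tensor([3.0]))


plain, param = WithPlainTensor(), WithParameter()
print(len(list(plain.parameters())), len(list(param.parameters())))  # 0 1

# The optimizer only sees what parameters() returns, so only param.p gets stepped.
opt = torch.optim.SGD(param.parameters(), lr=0.1)
loss = (param.p ** 2).sum()  # toy loss: d(loss)/dp = 2p = 6
opt.zero_grad()
loss.backward()
opt.step()
print(param.p.item())  # about 2.4, i.e. 3 - 0.1 * 6
```

Building the Parameter from a CPU tensor (rather than torch.cuda.FloatTensor) is a design choice here: model.to(device) then moves p together with the rest of the model.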

Thanks for the help! I just changed to Parameter and now it is updating! However, for some reason now it always goes to NaN :confused:


I can see it going to NaN if p goes toward zero, since 1 / p then blows up.
I am not sure about all the derivatives of the pow operation. Is it defined for every value of p? For every value of 1/p?
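For reference, here are the derivatives involved, written for the pooling in the code above with N = H·W pixels per channel (standard calculus, not taken from the PyTorch source; the last term holds g fixed):

$$
f = g^{1/p}, \qquad g = \frac{1}{N}\sum_{i=1}^{N} x_i^{p}
$$

$$
\frac{\partial x_i^{p}}{\partial x_i} = p\,x_i^{p-1}, \qquad
\frac{\partial x_i^{p}}{\partial p} = x_i^{p}\ln x_i, \qquad
\frac{\partial g^{1/p}}{\partial p} = -\frac{g^{1/p}\,\ln g}{p^{2}}
$$

So the gradient with respect to p needs ln x_i, which is undefined for x_i < 0 and gives 0 · (−∞) at x_i = 0, while the 1/p and 1/p² factors blow up as p → 0.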

It should be:

I even added a small epsilon to p in the second pow: x = x ** (1 / (self.p + 1e-8))

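I see a few logs in there that could be problematic as well, since log is -inf at 0 and NaN for negative inputs.
You can also enable anomaly detection mode (torch.autograd.set_detect_anomaly(True), or the torch.autograd.detect_anomaly() context manager) to find exactly which op first produced the NaN!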
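A minimal sketch of anomaly detection on a hypothetical toy graph (not the pooling code): sqrt has an infinite derivative at 0, and multiplying by 0 makes the incoming gradient 0, so SqrtBackward computes 0 / 0 = NaN. With anomaly mode on, backward raises a RuntimeError naming that function instead of silently propagating the NaN.

```python
import torch


def find_nan_source():
    # Toy graph: d(sqrt(x))/dx is infinite at x = 0, and "* 0" makes the
    # incoming gradient 0, so the sqrt backward evaluates 0 / 0 = NaN.
    x = torch.tensor([0.0, 4.0], requires_grad=True)
    with torch.autograd.detect_anomaly():
        y = (torch.sqrt(x) * 0).sum()
        try:
            y.backward()
        except RuntimeError as err:
            # The message names the backward function that produced the NaN.
            return str(err)
    return None


msg = find_nan_source()
print(msg)
```

Anomaly mode adds noticeable overhead, so it is best switched on only while debugging.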

```
  File "/media/data/Kaggle/Humpback_Whale_Identification/", line 57, in forward
    x = torch.sum(x ** self.p, dim=2) / (H * W)

RuntimeError: Function 'PowBackward1' returned nan values in its 1th output.
```

Useful thing! It seems the problem is in the first pow. Any ideas on how to fix it?

You can find the definition of the backward for the pow operator here.
And the function for the backward with respect to p is here. As you can see, it takes the log of the input: the gradient of x ** p with respect to p is x ** p * log(x). So if x contains any 0, that term becomes 0 * (-inf), which is NaN. You can check for zeros with: print(x.eq(0).sum()).
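The check and the failure mode can be reproduced in a few lines. This sketch uses hypothetical feature values (the 0.0 stands in for an exact zero, such as a ReLU emits for negative pre-activations; the thread does not say where the zeros come from) and computes the gradient expression by hand rather than calling backward, since the exact handling of x = 0 inside pow's backward may differ between PyTorch versions.

```python
import torch

# Hypothetical pooled features; the 0.0 stands in for an exact zero.
x = torch.tensor([0.0, 0.5, 2.0])
p = torch.tensor(3.0)

# Count exact zeros in the input.
print(x.eq(0).sum())  # tensor(1)

# Gradient of x ** p with respect to p, written out by hand: x^p * log(x).
# At x = 0 this is 0 * (-inf), which is NaN in IEEE floating point.
grad_p = x ** p * torch.log(x)
print(grad_p)  # first entry is nan, the others are finite
```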


Indeed, it worked! I added a small epsilon to x and it went fine! Thanks a lot! :grinning:
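The thread does not show the final code, so here is a minimal sketch of the fixed module under a few assumptions: p is an nn.Parameter, and x is clamped to at least a small eps before the pow. Clamping is a swapped-in variant of the poster’s fix (adding an epsilon to x works too for non-negative inputs); clamping also guards against negative inputs, whose log is NaN. The default values p=3.0 and eps=1e-6 are illustrative choices.

```python
import torch
import torch.nn as nn


class GeneralizedMeanPooling(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        # Registered, trainable exponent (moves with model.to(device)).
        self.p = nn.Parameter(torch.tensor([p]))
        self.eps = eps

    def forward(self, x):
        B, C, H, W = x.shape
        # Keep x strictly positive so log(x) in the backward w.r.t. p stays finite.
        x = x.clamp(min=self.eps).reshape(B, C, -1)  # (B, C, H*W)
        x = (x ** self.p).mean(dim=2)                # mean of x^p per channel
        return x ** (1.0 / self.p)                   # (B, C)


torch.manual_seed(0)
pool = GeneralizedMeanPooling()
feats = torch.relu(torch.randn(2, 4, 5, 5))  # ReLU output: contains exact zeros
out = pool(feats)
out.sum().backward()
print(out.shape, pool.p.grad)
```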