Problem with torch.pow()

Hello,

My CNN model contains an MLP with 4 fully connected layers, the last of which has a single output node. I call this output alpha; it has 2 dimensions, e.g. a sample output of alpha is [[0.1023]]. I make sure its value never goes below 0.1 (or above 2) using alpha = torch.clamp(alpha, min=0.1, max=2).

I now need to transform another PyTorch tensor x using alpha. Both are of type torch.float. The tensor x has the usual 4 dimensions: batch x channel x height x width.
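A minimal sketch of this setup, for reference. The layer widths (64, 32, 16, 8, 1), the feature size and the tensor sizes are assumptions for illustration, not the actual model. It shows that alpha comes out with shape [1, 1], that alpha[0][0] is a 0-dim tensor still attached to the autograd graph, and that with a batch of B samples a per-sample alpha would have to be reshaped to [B, 1, 1, 1] to broadcast against x.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical MLP head: 4 fully connected layers, the last one with a single node.
mlp = nn.Sequential(
    nn.Linear(64, 32), nn.ReLU(),
    nn.Linear(32, 16), nn.ReLU(),
    nn.Linear(16, 8), nn.ReLU(),
    nn.Linear(8, 1),
)

features = torch.randn(1, 64)                       # assumed batch size of 1
alpha = torch.clamp(mlp(features), min=0.1, max=2)
print(alpha.shape)          # torch.Size([1, 1])
print(alpha[0][0].dim())    # 0: a scalar tensor that is still part of the graph

x = torch.rand(1, 3, 8, 8)  # batch x channel x height x width
y = x ** alpha[0][0]        # same shape as x; x ** alpha broadcasts the same way

# With B samples, alpha has shape [B, 1]; reshape it so that each sample's
# exponent broadcasts over that sample's channel, height and width entries.
B = 4
alpha_b = torch.clamp(mlp(torch.randn(B, 64)), min=0.1, max=2)
x_b = torch.rand(B, 3, 8, 8)
y_b = x_b ** alpha_b.view(B, 1, 1, 1)
print(y_b.shape)            # torch.Size([4, 3, 8, 8])
```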

So I have tried the following:

  • x**alpha[0][0]: this produces NaN in both alpha and x right after the first iteration. I tried several other variants (x**alpha, x.pow(alpha[0][0]) and x.pow(alpha)), but all had the same problem. How can I solve this?
  • x*alpha[0][0]: this runs fine; I did not encounter any problem even after 100,000 iterations. So why is raising to a power problematic while simple multiplication is fine?
  • Interestingly, with x.pow(alpha.item()) the NaN problem goes away, but is this the correct way to do it? Does it cause problems in backpropagation?
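The difference between the tensor-exponent variants and x.pow(alpha.item()) can be checked directly. In this sketch a single nn.Linear layer is a hypothetical stand-in for the MLP that produces alpha, and x is random data scaled down to mimic low-light input with one exact zero. alpha.item() returns a plain Python float, so the result is no longer connected to alpha: autograd never computes d(x^alpha)/d(alpha) = x^alpha * ln(x), and the MLP that produces alpha receives no gradient through this term.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

head = nn.Linear(8, 1)   # hypothetical stand-in for the MLP that produces alpha
alpha = torch.clamp(head(torch.randn(1, 8)), min=0.1, max=2)

x = torch.rand(2, 3, 4, 4) * 1e-2   # assumed low-light input in [0, 0.01)
x[0, 0, 0, 0] = 0.0                  # at least one exact zero, as in the real data

y_tensor = x ** alpha[0][0]      # exponent is a tensor: gradient can reach alpha
y_float = x.pow(alpha.item())    # exponent is a Python float: alpha is cut off

print(y_tensor.requires_grad)             # True
print(y_float.requires_grad)              # False: this term cannot train the alpha MLP
print(torch.allclose(y_tensor, y_float))  # True: the forward values agree
```

So x.pow(alpha.item()) is only appropriate if alpha is not meant to be learned through this expression.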

I use a very small learning rate of 1e-4 with the Adam optimiser. I am using PyTorch 1.3.1.

Thank you very much

I do not know if this helps, but I found a way to circumvent the issue, though I still don't understand why the original issue came up.

Replacing

alpha = torch.clamp(alpha, min=0.1, max=2)

with

m = torch.nn.Threshold(0.1, 0.1)
alpha = m(alpha)

now allows me to use x**alpha[0][0]. I think there is a bug with clamp.
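For comparison, the two expressions can be evaluated side by side on a few sample alpha values (chosen arbitrarily, avoiding the exact boundaries 0.1 and 2). In the forward pass they agree wherever alpha <= 2; the only difference is that the Threshold version drops the max=2 cap. Their gradients with respect to alpha also agree inside that range.

```python
import torch
import torch.nn as nn

values = [-1.0, 0.05, 0.5, 1.5, 3.0]   # arbitrary sample alphas; 3.0 is above max=2

a1 = torch.tensor(values, requires_grad=True)
a2 = torch.tensor(values, requires_grad=True)

out_clamp = torch.clamp(a1, min=0.1, max=2)
out_thresh = nn.Threshold(0.1, 0.1)(a2)   # values <= 0.1 are replaced by 0.1

out_clamp.sum().backward()
out_thresh.sum().backward()

print(out_clamp.detach())    # tensor([0.1000, 0.1000, 0.5000, 1.5000, 2.0000])
print(out_thresh.detach())   # tensor([0.1000, 0.1000, 0.5000, 1.5000, 3.0000])
print(a1.grad)               # tensor([0., 0., 1., 1., 0.])
print(a2.grad)               # tensor([0., 0., 1., 1., 1.])
```

In this comparison the two layers differ only above alpha = 2, which suggests the choice between them does not change the alpha values fed into x**alpha[0][0] in the clamped range.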


Hi Mohit,

I am relatively confident that there is no problem with clamp.

What’s the value range of x?
If x has values <= 0, you might expect non-integral powers and their derivatives to be problematic there.
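To see why non-integral powers are problematic at x = 0 while multiplication is not, one can evaluate the partial derivatives autograd needs. For y = x^alpha they are dy/dx = alpha * x^(alpha - 1) and dy/dalpha = x^alpha * ln(x); for y = x * alpha they are simply alpha and x. The sketch below evaluates these textbook formulas with an assumed alpha = 0.5 (any value in (0, 1) behaves the same at x = 0). It is not necessarily how autograd implements pow internally; newer PyTorch releases may special-case x = 0 in the exponent gradient, so the exact autograd output can depend on the version.

```python
import torch

x = torch.tensor([0.0, 1e-4, 1e-2])   # includes an exact zero, like the low-light data
alpha = torch.tensor(0.5)             # assumed exponent in (0, 1)

# y = x ** alpha
dy_dx_pow = alpha * x.pow(alpha - 1)          # 0 ** (-0.5) -> inf
dy_dalpha_pow = x.pow(alpha) * torch.log(x)   # 0 * log(0) = 0 * (-inf) -> nan

# y = x * alpha
dy_dx_mul = alpha * torch.ones_like(x)        # just alpha: always finite
dy_dalpha_mul = x                             # just x: always finite

print(dy_dx_pow)       # first entry is inf
print(dy_dalpha_pow)   # first entry is nan
```

A single inf or nan in a gradient is enough to turn the affected parameters into NaN after one Adam step.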

Best regards

Thomas

Thank you for your reply.

The values of x lie between 0 and 1. I checked with torch.min and the minimum is 0. But it is low-light data, so the average value of x is about 1e-4 or 1e-3 and no greater than 1e-2; it only rarely reaches 1e-1.

But whatever the cause, why does using clamp give NaN while nn.Threshold does not, when the job of both is to replace any value of alpha below 0.1 with 0.1?
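Since x contains exact zeros, one common workaround (not proposed in this thread) is to keep the base strictly positive before taking the power, e.g. by clamping x to a small eps. The value eps = 1e-6 is an assumption and should be small relative to the typical pixel values (around 1e-4 here); x and alpha are synthetic stand-ins. The sketch compares the gradient of x with and without the clamp.

```python
import torch

torch.manual_seed(0)

eps = 1e-6                           # assumed offset, well below the typical x (~1e-4)
x = torch.rand(2, 3, 4, 4) * 1e-2    # synthetic low-light-like input in [0, 0.01)
x[0, 0, 0, 0] = 0.0                  # an exact zero, like the minimum found with torch.min
x.requires_grad_(True)
alpha = torch.tensor([[0.3]], requires_grad=True)   # assumed exponent in (0, 1)

# Without a clamp: dy/dx = alpha * x ** (alpha - 1) is inf where x == 0.
(x ** alpha[0][0]).sum().backward()
raw_x_grad = x.grad.clone()
print(torch.isinf(raw_x_grad[0, 0, 0, 0]).item())   # True

x.grad = None
alpha.grad = None

# With the base clamped to eps, neither gradient contains inf or nan.
(torch.clamp(x, min=eps) ** alpha[0][0]).sum().backward()
print(torch.isfinite(x.grad).all().item(), torch.isfinite(alpha.grad).all().item())  # True True
```

Clamping leaves every value above eps unchanged; adding eps instead ((x + eps) ** alpha) shifts all pixels slightly but keeps a nonzero gradient at the zero pixels. Which is preferable depends on the data.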