torch.where() not working: returns nan values when used with exp()

With PyTorch 1.13.x, I've been trying to implement some activation functions (Mish, ELU, etc.) from scratch as custom activation modules.

However, the loss becomes nan after about 17 epochs of training.

  • dataset: the official MNIST dataset from each framework
  • model architecture: a simple dense network (25 layers with 500 neurons each; rough sketch below)
  • lr: 1e-3 (I'd rather not change this)
  • batch_size: 128
  • optimizer: Adam

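Roughly, the setup looks like the sketch below. The layer count and width are as listed above, but everything else (input flattening, 10-class output head, the placeholder ReLU) is just assumed for illustration; in my real run the activation is the custom Mish module shown further down.

import torch as t
import torch.nn as nn

# Rough sketch of the setup described above; `act` is a placeholder for the
# custom Mish module from the next code block.
def build_model(act=nn.ReLU):
  layers = [nn.Flatten()]
  in_features = 28 * 28              # MNIST images
  for _ in range(25):                # 25 dense layers, 500 neurons each
    layers += [nn.Linear(in_features, 500), act()]
    in_features = 500
  layers.append(nn.Linear(500, 10))
  return nn.Sequential(*layers)

model = build_model()
optimizer = t.optim.Adam(model.parameters(), lr=1e-3)   # Adam with lr = 1e-3
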
Torch code:

import torch as t
import torch.nn as nn

class Mish_Implementation(nn.Module):
  def __init__(self):
    super(Mish_Implementation, self).__init__()
    self.__name__ = 'Mish'

  def forward(self, x):
    # thresholds intended to avoid exp() overflow: return 0 below -7 and x above 30
    return t.where(x < -7, 0, t.where(x > 30, x, x * t.tanh(t.log(1 + t.exp(x)))))

I used:

t.autograd.set_detect_anomaly(True)

and got this error message: Function 'ExpBackward0' returned nan values in its 0th output.

I suspect the exp() function is overflowing, but that's exactly why I used torch.where() in the first place: to keep exp() from contributing values that are too large.
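
Here is a minimal standalone snippet (not my actual training code, the input value 100 is just an arbitrary large example) that seems to reproduce what I mean:

import torch as t

x = t.tensor([100.0], requires_grad=True)   # large enough that exp(x) overflows to inf
# The condition selects the "safe" branch (just x), but the other branch is
# still evaluated, and its backward ends up multiplying a zero gradient by inf.
y = t.where(x > 30, x, x * t.tanh(t.log(1 + t.exp(x))))
y.backward()
print(x.grad)   # tensor([nan])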

I want to add some trainable parameters to this activation later, so getting this to work matters to me.

Any advice is really appreciated. Thanks in advance.

You might be running into this behavior, so could you check whether your use case suffers from the same logic creating invalid gradients?

You are right!
I needed to mask the input to exp() according to the condition, since every intermediate calculation still matters during backpropagation.
So I edited my code as follows:

def forward(self, x):
    # zero out the input to exp() wherever x > 30, so exp() can no longer
    # overflow and poison the backward pass
    mask = (x > 30)
    return mask * x + ~mask * x * t.tanh(t.log(1 + t.exp(~mask * x)))

and it’s working.
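
For reference, a variant built on the built-in softplus should sidestep the overflow as well, since F.softplus falls back to the linear function once the input exceeds its threshold argument. This is just a sketch I haven't swapped into my training run, and the class name is made up:

import torch as t
import torch.nn as nn
import torch.nn.functional as F

class MishViaSoftplus(nn.Module):
  # softplus(x) = log(1 + exp(x)), but F.softplus switches to the identity
  # above `threshold` (default 20), so exp() never overflows in forward or backward.
  def forward(self, x):
    return x * t.tanh(F.softplus(x))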

But still, from the link you mentioned, I don't get the part about the gradient of x being the sum of the two input gradients. Does it mean that where() computes the gradients of both input tensors regardless of the condition?
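
To check my understanding I tried a tiny experiment where x feeds both branches (a toy example with made-up numbers, nothing to do with my model), but I'm not sure I'm reading it correctly:

import torch as t

x = t.tensor([2.0], requires_grad=True)
y = t.where(x > 0, 3 * x, 1 / (x - 2))   # the unselected branch is 1/0 = inf at x = 2
y.backward()
print(x.grad)   # tensor([nan]): the selected branch contributes 3, but the
                # unselected branch's backward still runs and gives 0 * inf = nan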

Thank you for the help.