Currently, if x is too big (like 100), the backward pass produces NaN, while the gradient should be 0.
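A minimal reproduction: with the naive expression, e**100 already overflows to inf in float32, so the forward pass is fine (exp(-inf) == 0) but the chain rule multiplies 0 by -inf in backward:

import torch

x = torch.tensor([100.0], requires_grad=True)
y = (-x.exp()).exp()  # forward is still fine: exp(-inf) == 0
y.backward()
print(x.grad)         # tensor([nan]): exp(-inf) * -exp(100) = 0 * -inf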
I defined a custom function to fix this:
from torch.autograd import Function

class Expmexp(Function):
    """
    y = e**(-e**x), implemented so the backward pass does not produce NaN
    when x is big.
    """

    @staticmethod
    def forward(ctx, input):
        # Clamp so e**x stays finite in float32 (e**80 ~ 5.5e34); for
        # x >= 80 both the output and the gradient are already 0.
        x = input.clamp(max=80)
        ctx.save_for_backward(x)
        return (-x.exp()).exp_()

    @staticmethod
    def backward(ctx, grad_output):
        """
        f'(x) = -e**x * e**(-e**x)
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            input, = ctx.saved_tensors
            ex = input.exp().neg_()  # -e**x
            # -e**x * e**(-e**x), scaled by the incoming gradient
            grad_input = ex.mul_(ex.exp()).mul_(grad_output)
        return grad_input
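With the clamp in place, the same test case comes back with a finite zero gradient instead of NaN (custom autograd Functions are invoked via .apply):

x = torch.tensor([100.0], requires_grad=True)
y = Expmexp.apply(x)
y.backward()
print(x.grad)  # a finite zero instead of NaN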