How to calculate gradient of y=(-x.exp()).exp()

Currently, if x is too big (e.g. 100), the backward pass produces a NaN gradient, while the correct gradient is 0.

I defined a custom function to fix this:

class Expmexp(Function):
    """Compute y = e**(-e**x) with a NaN-free backward for large x.

    For x around 100, e**x overflows to inf, so the naive gradient
    -e**x * e**(-e**x) evaluates as inf * 0 = NaN. Clamping x keeps e**x
    finite (e**80 ~ 5.5e34), so e**(-e**x) underflows cleanly to 0 and the
    gradient becomes 0, which is the correct limit.
    """

    @staticmethod
    def forward(ctx, input):
        # Clamp so e**x stays finite; beyond x=80 the output is 0 anyway.
        x = input.clamp(max=80)
        # Save the *clamped* input — backward must differentiate the same
        # clamped expression, otherwise inf * 0 = NaN comes back.
        ctx.save_for_backward(x)
        return (-x.exp()).exp_()

    @staticmethod
    def backward(ctx, grad_output):
        # f'(x) = -e**x * e**(-e**x). With x clamped, -e**x is finite and
        # e**(-e**x) underflows to 0, so the product is 0 instead of NaN.
        grad_input = None
        if ctx.needs_input_grad[0]:
            x, = ctx.saved_tensors
            ex = x.exp().neg_()          # -e**x (new tensor; saved x untouched)
            grad_input = ex.mul_(ex.exp()).mul_(grad_output)
        return grad_input