How to calculate the gradient of `y = (-x.exp()).exp()` without getting NaN

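Currently, if x is large (e.g. 100), the gradient is NaN, although it should be 0. In float32, `x.exp()` overflows to `inf`, so `y` becomes 0 and autograd's chain rule ends up multiplying `0 * inf`, which is NaN.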

I defined a custom autograd `Function` to work around this:

class Expmexp(Function):
    """
    y = e**(-e**x), with a backward pass that stays finite when x is big.

    Plain autograd computes dy/dx = -e**x * e**(-e**x) as a product.  Once
    e**x overflows to inf, that product is inf * 0 = NaN, although the true
    gradient underflows to 0.  The backward below evaluates the same
    derivative in log space, -e**(x - e**x).  When e**x overflows this
    becomes e**(-inf) = 0 instead of NaN, independent of the float dtype.
    """

    @staticmethod
    def forward(ctx, input):
        # The clamp does not change the output: e**(-e**80) already underflows
        # to 0 in every float dtype.  Its job is to keep +inf out of the saved
        # tensor, because backward's x - e**x would be inf - inf = NaN there.
        x = input.clamp(max=80)
        ctx.save_for_backward(x)
        return (-x.exp()).exp_()

    @staticmethod
    def backward(ctx, grad_output):
        """
        f'(x) = -e**x * e**(-e**x) = -e**(x - e**x)

        Returns the gradient w.r.t. the input, or None if it is not needed.
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            x, = ctx.saved_tensors
            # Log-space form: if x.exp() overflows to inf, x - inf = -inf and
            # exp(-inf) = 0, whereas the product form gives -inf * 0 = NaN
            # (e.g. in float16, where exp(80) already overflows).
            grad_input = (x - x.exp()).exp_().neg_().mul_(grad_output)
        return grad_input