At some point during computation, my model needs to compute the logarithm of softplus of a parameter.
Currently, I implement this via:
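The code itself is missing from the post; presumably it is something along these lines (a sketch, with `x` standing in for the parameter):

```python
import torch
import torch.nn.functional as F

def log_softplus_naive(x):
    # log(softplus(x)) = log(log(1 + exp(x))), computed directly
    return torch.log(F.softplus(x))
```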
Due to the log() call, I fear that there might be numerical-stability issues.
Is there a numerically more stable way of computing log-softplus?
I do not want to change my parameterization.
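The "suggested version" that the answer below discusses is also not shown; reconstructed as a sketch from the breakpoint (x > -5) it mentions, using the asymptote log(softplus(x)) → x for very negative x:

```python
import torch
import torch.nn.functional as F

def log_softplus_where(x):
    # exact formula where it is safe; the asymptote
    # log(softplus(x)) -> x (as x -> -inf) below the breakpoint
    return torch.where(x > -5.0, torch.log(F.softplus(x)), x)
```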
I think your concern about loss of precision is a partial red herring.
Because we’re working with floating-point numbers, we don’t lose
precision when softplus(x) becomes small (say, on the order of
10**-7). We only start to lose precision when softplus(x) becomes
very small, denormalizes, and then underflows to zero
(at which point the subsequent log() returns -inf). This starts to
happen around 10**-38 (which corresponds to an x of about -90.0).
Whether this matters depends on your specific use case. But, unless
your argument to log-softplus can reasonably become as negative as,
say, -90, your “naive” implementation will be fine.
On the other hand, if your argument to log-softplus can reasonably
become that negative, then your suggested version does make
sense, except that you should use a “breakpoint” (x > -5) that is
significantly more negative than -5 (perhaps something like -40).
As it stands, your breakpoint of -5 does more harm than good,
because in this range x is an imperfect approximation to log-softplus.
The following illustrative script uses a double-precision log-softplus
computation as a surrogate for the exact result:
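The script is not included in the thread; a sketch of what such a comparison might look like (the breakpoints -5 and -40 are taken from the discussion above; the grid and helper name are my own scaffolding):

```python
import torch
import torch.nn.functional as F

def log_softplus_bp(x, bp):
    # piecewise log-softplus: exact formula above the breakpoint bp,
    # the asymptote log(softplus(x)) ~ x below it
    return torch.where(x > bp, torch.log(F.softplus(x)), x)

# double-precision computation as a surrogate for the exact result
x32 = torch.linspace(-50.0, 10.0, 601)
exact = torch.log(F.softplus(x32.double()))

errs = {}
for bp in (-5.0, -40.0):
    errs[bp] = (log_softplus_bp(x32, bp).double() - exact).abs().max().item()
    print(f"breakpoint {bp}: max abs error {errs[bp]:.3e}")
```

With a breakpoint of -5, the maximum error is dominated by the mismatch between x and log(softplus(x)) just below the breakpoint; with -40 it drops to ordinary single-precision round-off.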
I think you are right: for precise computation of log-softplus, a much lower breakpoint is sufficient.
What I did not mention (my bad) is that I want to apply gradient backprop through this computation.
This changes things a lot: if the breakpoint is set too low, the error in the gradient can be seven orders of magnitude bigger than the error in the log-softplus output.
The following code is an extension of yours.
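That code is also not shown; a sketch of how such a gradient comparison might look (the helper name and the grid are my own choices; the grid stays above the float32 denormal region, around -87, so the unselected `torch.where` branch cannot produce inf/NaN gradients):

```python
import torch
import torch.nn.functional as F

def log_softplus_bp(x, bp):
    # piecewise log-softplus with breakpoint bp, as discussed above
    return torch.where(x > bp, torch.log(F.softplus(x)), x)

def max_grad_error(bp):
    # hypothetical helper: maximum absolute error of the backprop
    # gradient of the float32 piecewise version, measured against a
    # double-precision log-softplus gradient as a surrogate for exact
    x32 = torch.linspace(-85.0, 10.0, 951, requires_grad=True)
    log_softplus_bp(x32, bp).sum().backward()
    x64 = x32.detach().double().requires_grad_()
    torch.log(F.softplus(x64)).sum().backward()
    return (x32.grad.double() - x64.grad).abs().max().item()

for bp in (-5.0, -40.0):
    print(f"breakpoint {bp}: max abs gradient error {max_grad_error(bp):.3e}")
```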