At some point during computation, my model needs to compute the logarithm of softplus of a parameter.
Currently, I implement this via:
torch.nn.functional.softplus(theta).log()
Because of the log() call, I fear there might be numerical-stability issues.
Is there a numerically more stable way of computing log-softplus?
I do not want to change my parameterization.
Thank you!
I don’t think this can be simplified. As for “stability”, you’re limited by the representation bounds of float32; it’s not as though you’ll get biased errors or anything (beyond the inherent imprecision of log()).
Ah, sorry, you’re right. I somehow assumed that lim(log(x)) = 0 and that the problem was in that area. You may just want to vectorize your approach with torch.where.
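A piecewise version along those lines might look like the sketch below; the default breakpoint of -5 and the clamp trick are my additions, not code from the thread:

```python
import torch

def log_softplus(x: torch.Tensor, breakpoint: float = -5.0) -> torch.Tensor:
    # For large negative x, softplus(x) = log1p(exp(x)) ≈ exp(x),
    # so log(softplus(x)) ≈ x and we can return x directly.
    # Clamping the argument of the naive branch keeps its unselected
    # entries finite, so they cannot poison gradients through torch.where.
    naive = torch.nn.functional.softplus(x.clamp(min=breakpoint)).log()
    return torch.where(x > breakpoint, naive, x)
```

Note that torch.where evaluates both branches eagerly, which is why the naive branch is clamped rather than left to underflow.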
I think your concern about loss of precision is a partial red herring.
Because we’re working with floating-point numbers, we don’t lose
precision when softplus(x) becomes small (say, on the order of
10**-7). We only start to lose precision when softplus(x) becomes very small and starts to denormalize and then underflows to zero
(at which point the subsequent log() will return -inf). This starts to
happen around 10**-38 (which corresponds to an x of about -90.0).
Whether this matters depends on your specific use case. But, unless
your argument to log-softplus can reasonably become as negative as,
say, -90, your “naive” implementation will be fine.
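A quick probe shows roughly where this kicks in; the test points below are my choices, and the exact cutoffs can vary slightly by backend and denormal handling (float32 on CPU here):

```python
import torch

# Probe where the naive float32 log-softplus starts to fail: around
# x ≈ -90, softplus(x) is already denormal; well below that, it
# underflows to zero and the subsequent log() returns -inf.
for x in (-80.0, -90.0, -120.0):
    t = torch.tensor(x, dtype=torch.float32)
    sp = torch.nn.functional.softplus(t)
    print(f"x = {x:7.1f}  softplus = {sp.item():.3e}  log-softplus = {sp.log().item()}")
```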
On the other hand, if your argument to log-softplus can reasonably
become that negative, then your suggested version does make
sense, except that you should use a “breakpoint” (x > -5) that is
significantly more negative than -5 (perhaps something like -40).
As it stands, your breakpoint of -5 does more harm than good,
because in this range x is an imperfect approximation to log-softplus.
The following illustrative script uses a double-precision log-softplus
computation as a surrogate for the exact result:
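The script itself isn’t reproduced in the quote; a sketch along the lines described might look like this (the sample points and the two candidate breakpoints are my choices):

```python
import torch

xs = torch.tensor([-2.0, -5.0, -10.0, -40.0, -80.0], dtype=torch.float64)

# Double-precision computation as a surrogate for the exact result.
exact = torch.nn.functional.softplus(xs).log()

# Naive single-precision computation.
naive = torch.nn.functional.softplus(xs.float()).log().double()

# Piecewise single-precision computation for two candidate breakpoints.
for bp in (-5.0, -40.0):
    x32 = xs.float()
    piecewise = torch.where(x32 > bp,
                            torch.nn.functional.softplus(x32).log(),
                            x32).double()
    err = (piecewise - exact).abs().max().item()
    print(f"breakpoint {bp:6.1f}: max |error| = {err:.3e}")

print(f"naive float32:     max |error| = {(naive - exact).abs().max().item():.3e}")
```

With a breakpoint of -5, the worst error comes from using x as the approximation right at the breakpoint; pushing the breakpoint down to -40 makes the linear-branch error negligible.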
I think you are right: for precise computation of log-softplus, a much lower breakpoint is sufficient.
What I did not mention (my bad) is that I need to backpropagate gradients through this computation.
This changes things a lot: if the breakpoint is set too low, the error in the gradient can be 7 orders of magnitude larger than the error in the log-softplus value itself.
The following code is an extension of yours.
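That extension isn’t reproduced in the quote either; a sketch of such a gradient comparison might look like the following (the piecewise helper, breakpoint, and sample point are assumptions, not the thread’s actual code):

```python
import torch

def piecewise_log_softplus(x: torch.Tensor, breakpoint: float) -> torch.Tensor:
    # Clamp the naive branch so its unselected entries stay finite
    # and cannot poison gradients through torch.where.
    safe = torch.nn.functional.softplus(x.clamp(min=breakpoint)).log()
    return torch.where(x > breakpoint, safe, x)

def grad_at(fn, x0: float, dtype: torch.dtype) -> float:
    x = torch.tensor([x0], dtype=dtype, requires_grad=True)
    fn(x).sum().backward()
    return x.grad.item()

x0 = -95.0

# Double-precision reference: d/dx log-softplus(x) ≈ 1 for very negative x.
ref = grad_at(lambda t: torch.nn.functional.softplus(t).log(), x0, torch.float64)

# Naive float32: the backward pass divides by the tiny float32 softplus
# value and blows up, even though the forward value is still close to x0.
naive = grad_at(lambda t: torch.nn.functional.softplus(t).log(), x0, torch.float32)

# Piecewise float32 with a breakpoint above x0: the linear branch
# gives a gradient of exactly 1.
piecewise = grad_at(lambda t: piecewise_log_softplus(t, -40.0), x0, torch.float32)

print(f"reference: {ref}  naive: {naive}  piecewise: {piecewise}")
```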