# Numerically stable log-softplus

At some point during computation, my model needs to compute the logarithm of softplus of a parameter.
Currently, I implement this via:

``````
torch.nn.functional.softplus(theta).log()
``````

Because of the log() call, I fear that there might be numerical-stability issues.
Is there a numerically more stable way of computing log-softplus?
I do not want to change my parameterization.
Thank you!

I don't think this can be simplified. As for "stability", you're limited by the float32 representation bounds; it is not as if you'll get biased errors or anything (apart from the inherent log() imprecision).

Quick update:
for now, I am using the approximation in this pseudo-code:

``````
def log_softplus(x):
    # naive computation above the breakpoint,
    # identity approximation log_softplus(x) ≈ x below it
    if x > -5:
        return torch.nn.functional.softplus(x).log()
    else:
        return x
``````

which is numerically stable for very negative x. If done naively (softplus first, then log), softplus would return a very small number with low precision.

Ah, sorry, you're right; I somehow assumed that lim(log(x)) = 0 and that the problem was in that area. You may just want to vectorize your approach with torch.where.
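
For example, something along these lines (just a sketch of the vectorized form):

``````
import torch

def log_softplus(x, brk=-5.0):
    # elementwise: naive softplus().log() above the breakpoint,
    # the identity approximation x below it
    return torch.where(x > brk, torch.nn.functional.softplus(x).log(), x)
``````

(torch.where evaluates both branches, but only keeps the selected elements.)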

Hi Steve!

I think your concern about loss of precision is a partial red herring.

Because weâ€™re working with floating-point numbers, we donâ€™t lose
precision when `softplus (x)` becomes small (say, on the order of
10**-7). We only start to lose precision when `softplus (x)` becomes
very small and starts to denormalize and then underflows to zero
(at which point the subsequent `log()` will return `-inf`). This starts to
happen around 10**-38 (which corresponds to an `x` of about -90.0).
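
For reference, you can see these limits directly (a quick illustrative check):

``````
import torch

print (torch.finfo (torch.float32).tiny)   # smallest normal float32, about 1.18e-38
print (torch.tensor (-90.0).exp())         # about 8.2e-40 -- already denormal
print (torch.tensor (-110.0).exp())        # underflows to 0., so log() gives -inf
``````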

Whether this matters depends on your specific use case. But, unless
your argument to log-softplus can reasonably become as negative as,
say, -90, your "naive" implementation will be fine.

On the other hand, if your argument to log-softplus can reasonably
become that negative, then your suggested version does make
sense, except that you should use a "breakpoint" (`x > -5`) that is
significantly more negative than `-5` (perhaps something like -40).
As it stands, your breakpoint of `-5` does more harm than good,
because in this range `x` is an imperfect approximation to log-softplus.
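
To see why, consider the value right at the breakpoint (a quick check in double precision):

``````
import math

x = -5.0
exact = math.log (math.log1p (math.exp (x)))   # log-softplus (-5.0), about -5.0034
print (exact, abs (exact - x))                 # identity approximation is off by about 3.4e-3
``````

That ~3.4e-3 error dwarfs the roughly 1e-6 float32 round-off of the naive computation in this range (cf. diffA in the script below).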

The following illustrative script uses a double-precision log-softplus
computation as a surrogate for the exact result:

``````
import torch
print (torch.__version__)

def log_softplusA (x):
    # "naive" version: softplus followed by log
    return torch.nn.functional.softplus (x).log()

def log_softplusB (x, brk = -5.0):
    # breakpoint version: naive computation above brk, identity approximation below
    return torch.where (x > brk, torch.nn.functional.softplus (x).log(), x)

x = torch.arange (-100, 101, 1).double() / 10.0

lsp = log_softplusA (x)   # double-precision "ground truth"

diffA = log_softplusA (x.float()) - lsp
diffB = log_softplusB (x.float()) - lsp

print ('diffA.abs().max() =', diffA.abs().max())
print ('diffB.abs().max() =', diffB.abs().max())
print ('diffB.abs().argmax() =', diffB.abs().argmax())
print ('x[diffB.abs().argmax()] =', x[diffB.abs().argmax()])

print ('log_softplusA (torch.tensor ([-90.0])) =', log_softplusA (torch.tensor ([-90.0])))
print ('log_softplusA (torch.tensor ([-100.0])) =', log_softplusA (torch.tensor ([-100.0])))
print ('log_softplusA (torch.tensor ([-110.0])) =', log_softplusA (torch.tensor ([-110.0])))
print ('log_softplusA (torch.tensor ([-110.0], dtype = torch.double)) =', log_softplusA (torch.tensor ([-110.0], dtype = torch.double)))
``````

And here is its result:

``````
1.7.1
diffA.abs().max() = tensor(8.1795e-07, dtype=torch.float64)
diffB.abs().max() = tensor(0.0034, dtype=torch.float64)
diffB.abs().argmax() = tensor(50)
x[diffB.abs().argmax()] = tensor(-5., dtype=torch.float64)
log_softplusA (torch.tensor ([-90.0])) = tensor([-90.])
log_softplusA (torch.tensor ([-100.0])) = tensor([-99.9831])
log_softplusA (torch.tensor ([-110.0])) = tensor([-inf])
log_softplusA (torch.tensor ([-110.0], dtype = torch.double)) = tensor([-110.], dtype=torch.float64)
``````

Best.

K. Frank

Well, thank you for your insightful comment!

I think you are right: for precise computation of log-softplus, a much lower breakpoint is sufficient.

What I did not mention (my bad) is that I want to backpropagate gradients through this computation.
This changes things a lot: if the breakpoint is set too low, the error in the gradient can be seven orders of magnitude bigger than the error in the log-softplus output.
The following code is an extension of yours.

``````
import torch
print (torch.__version__)

def log_softplusA (x):
    # "naive" version: softplus followed by log
    return torch.nn.functional.softplus (x).log()

def log_softplusB (x, brk = -15.0):
    # breakpoint version: naive computation above brk, identity approximation below
    return torch.where (x > brk, torch.nn.functional.softplus (x).log(), x)

x = torch.arange (-200, 100, 1).double() / 10.0
x.requires_grad = True
lsp = log_softplusA (x)   # double-precision "ground truth"
lsp.sum().backward()
lspg = x.grad             # double-precision "ground truth" gradient

xA = torch.arange (-200, 100, 1) / 10.0
xA.requires_grad = True
lspA = log_softplusA (xA)
diffA = lspA - lsp
lspA.sum().backward()
lspgA = xA.grad
diffgA = lspgA - lspg

xB = torch.arange (-200, 100, 1) / 10.0
xB.requires_grad = True
lspB = log_softplusB (xB)
diffB = lspB - lsp
lspB.sum().backward()
lspgB = xB.grad
diffgB = lspgB - lspg

print ('diffA.abs().max() =', diffA.abs().max())
print ('diffB.abs().max() =', diffB.abs().max())
print ('diffB.abs().argmax() =', diffB.abs().argmax())
print ('x[diffB.abs().argmax()] =', x[diffB.abs().argmax()])

print ('diffgA.abs().max() =', diffgA.abs().max())
print ('diffgB.abs().max() =', diffgB.abs().max())
print ('diffgB.abs().argmax() =', diffgB.abs().argmax())
print ('x[diffgB.abs().argmax()] =', x[diffgB.abs().argmax()])
print ('lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()] =', lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()])
``````

yields

``````
1.8.1
diffB.abs().argmax() = tensor(61)
diffgA.abs().max() = tensor(1.0000, dtype=torch.float64)
diffgB.abs().max() = tensor(0.1339, dtype=torch.float64)
diffgB.abs().argmax() = tensor(53)
lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()] = tensor(0.8661) tensor(1.0000, dtype=torch.float64)
``````

Varying the breakpoint and re-running the code, I empirically found that a value of -8.0 performs best.
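
Roughly, the sweep looks like this (an illustrative sketch, assuming the torch.where-based implementation from above; not my exact code):

``````
import torch

def log_softplusB (x, brk):
    # breakpoint version, vectorized with torch.where
    return torch.where (x > brk, torch.nn.functional.softplus (x).log(), x)

# double-precision ground truth for the value and the gradient
# (brk low enough that the naive branch is always taken)
xd = torch.arange (-200, 100, 1).double() / 10.0
xd.requires_grad = True
log_softplusB (xd, brk = -1000.0).sum().backward()
lsp_true = log_softplusB (xd.detach(), brk = -1000.0)

for brk in (-40.0, -15.0, -8.0, -5.0):
    xf = (torch.arange (-200, 100, 1) / 10.0).requires_grad_()
    out = log_softplusB (xf, brk)
    out.sum().backward()
    err_val = (out - lsp_true).abs().max().item()
    err_grad = (xf.grad - xd.grad).abs().max().item()
    print (brk, err_val, err_grad)
``````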

Hi Steve!

This may be related to some weirdness in pytorch's built-in
`torch.nn.functional.softplus()`. Here is the relevant thread:

Best.

K. Frank

Hi Steve!

I haven't worked through this in detail, but a first look suggests that
the `softplus()` fix in the recent nightlies addresses your issue.

Running the script you posted on today's nightly, 1.9.0.dev20210504,
seems to show that the issue is gone:

``````
1.9.0.dev20210504
diffB.abs().argmax() = tensor(61)