Numerically stable log-softplus

At some point during computation, my model needs to compute the logarithm of the softplus of a parameter.
Currently, I implement this via:

torch.nn.functional.softplus(theta).log()

Because of the log() call, I fear there might be numerical-stability issues.
Is there a numerically more stable way of computing log-softplus?
I do not want to change my parameterization.
Thank you!

I don’t think this can be simplified. As for “stability”, you’re limited by float32 representation bounds; it’s not as though you’ll get biased errors or anything (beyond the inherent imprecision of log()).

quick update:
for now, I am using the approximation shown in this pseudo-code:

def log_softplus(x):
  # pseudo-code (elementwise): use the exact expression for moderate x,
  # and the asymptotic approximation log(softplus(x)) ~= x for very negative x
  if x > -5:
    return torch.nn.functional.softplus(x).log()
  else:
    return x

which is numerically stable for very negative x. If done naively (softplus followed by log), softplus would return a very small number, with a corresponding loss of precision.
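
For concreteness, a tiny float32 demonstration of the failure mode the breakpoint is meant to avoid might look like this:

import torch

t = torch.tensor([-5.0, -30.0, -110.0])
# for very negative x, float32 softplus(x) ~= exp(x) eventually underflows to zero ...
print(torch.nn.functional.softplus(t))        # the last entry underflows to 0. in float32
print(torch.nn.functional.softplus(t).log())  # ... so the subsequent log() returns -inf there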

Ah, sorry, you’re right. I somehow assumed that lim(log(x)) = 0 and that there was no problem in that area. You may just want to vectorize your approach with torch.where.
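
For instance, a vectorized version might look like this (just a sketch, keeping the breakpoint from your pseudo-code):

import torch

def log_softplus(x, brk=-5.0):
    # elementwise: exact expression above the breakpoint,
    # asymptotic approximation log(softplus(x)) ~= x below it
    return torch.where(x > brk, torch.nn.functional.softplus(x).log(), x)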

Hi Steve!

I think your concern about loss of precision is a partial red herring.

Because we’re working with floating-point numbers, we don’t lose
precision when softplus (x) becomes small (say, on the order of
10**-7). We only start to lose precision when softplus (x) becomes
very small and starts to denormalize and then underflows to zero
(at which point the subsequent log() will return -inf). This starts to
happen around 10**-38 (which corresponds to an x of about -90.0).
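
(These thresholds are easy to check, for example:)

import torch

print (torch.finfo (torch.float32).tiny)   # ~1.18e-38, the smallest normal float32
print (torch.tensor (-90.0).exp())         # ~8.2e-40 -- softplus (x) ~= exp (x) here, already denormal
print (torch.tensor (-110.0).exp())        # underflows to 0., so a subsequent log() would give -inf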

Whether this matters depends on your specific use case. But, unless
your argument to log-softplus can reasonably become as negative as,
say, -90, your “naive” implementation will be fine.

On the other hand, if your argument to log-softplus can reasonably
become that negative, then your suggested version does make
sense, except that you should use a “breakpoint” (x > -5) that is
significantly more negative than -5 (perhaps something like -40).
As it stands, your breakpoint of -5 does more harm than good,
because in this range x is an imperfect approximation to log-softplus.

The following illustrative script uses a double-precision log-softplus
computation as a surrogate for the exact result:

import torch
print (torch.__version__)

def log_softplusA (x):
    return torch.nn.functional.softplus (x).log()

def log_softplusB (x, brk = -5.0):
    return torch.where (x > brk, torch.nn.functional.softplus (x).log(), x)

x = torch.arange (-100, 101, 1).double() / 10.0

lsp = log_softplusA (x)   # double-precision "ground truth"

diffA = log_softplusA (x.float()) - lsp   # float32 naive version vs. ground truth
diffB = log_softplusB (x.float()) - lsp   # float32 breakpoint version vs. ground truth

print ('diffA.abs().max() =', diffA.abs().max())
print ('diffB.abs().max() =', diffB.abs().max())
print ('diffB.abs().argmax() =', diffB.abs().argmax())
print ('x[diffB.abs().argmax()] =', x[diffB.abs().argmax()])

print ('log_softplusA (torch.tensor ([-90.0])) =', log_softplusA (torch.tensor ([-90.0])))
print ('log_softplusA (torch.tensor ([-100.0])) =', log_softplusA (torch.tensor ([-100.0])))
print ('log_softplusA (torch.tensor ([-110.0])) =', log_softplusA (torch.tensor ([-110.0])))
print ('log_softplusA (torch.tensor ([-110.0], dtype = torch.double)) =', log_softplusA (torch.tensor ([-110.0], dtype = torch.double)))

And here is its result:

1.7.1
diffA.abs().max() = tensor(8.1795e-07, dtype=torch.float64)
diffB.abs().max() = tensor(0.0034, dtype=torch.float64)
diffB.abs().argmax() = tensor(50)
x[diffB.abs().argmax()] = tensor(-5., dtype=torch.float64)
log_softplusA (torch.tensor ([-90.0])) = tensor([-90.])
log_softplusA (torch.tensor ([-100.0])) = tensor([-99.9831])
log_softplusA (torch.tensor ([-110.0])) = tensor([-inf])
log_softplusA (torch.tensor ([-110.0], dtype = torch.double)) = tensor([-110.], dtype=torch.float64)

Best.

K. Frank

Well, thank you for your insightful comment!

I think you are right: for precise computation of log-softplus, a much lower (more negative) breakpoint is sufficient.

What I did not mention (my bad) is that I need to backpropagate gradients through this computation.
This changes things a lot: if the breakpoint is set too low, the error in the gradient can be several orders of magnitude bigger than the error in the log-softplus output (below, up to about 0.13 in the gradient versus roughly 8e-07 in the output itself).
The following code is an extension of yours.

import torch
print (torch.__version__)

def log_softplusA (x):
    return torch.nn.functional.softplus (x).log()

def log_softplusB (x, brk = -15.0):
    return torch.where (x > brk, torch.nn.functional.softplus (x).log(), x)

x = torch.arange (-200, 100, 1).double() / 10.0
x.requires_grad = True
lsp = log_softplusA (x)   # double-precision "ground truth"
lsp.sum().backward()
lspg = x.grad             # corresponding "ground truth" gradient

xA = torch.arange (-200, 100, 1) / 10.0   # float32 input for the naive version (A)
xA.requires_grad = True
lspA = log_softplusA (xA)
diffA = lspA - lsp
lspA.sum().backward()
lspgA = xA.grad
diffgA = lspgA - lspg

xB = torch.arange (-200, 100, 1) / 10.0   # float32 input for the breakpoint version (B)
xB.requires_grad = True
lspB = log_softplusB (xB)
diffB = lspB - lsp
lspB.sum().backward()

lspgB = xB.grad
diffgB = lspgB - lspg


print ('diffA.abs().max() =', diffA.abs().max())
print ('diffB.abs().max() =', diffB.abs().max())
print ('diffB.abs().argmax() =', diffB.abs().argmax())
print ('x[diffB.abs().argmax()] =', x[diffB.abs().argmax()])

print ('diffgA.abs().max() =', diffgA.abs().max())
print ('diffgB.abs().max() =', diffgB.abs().max())
print ('diffgB.abs().argmax() =', diffgB.abs().argmax())
print ('x[diffgB.abs().argmax()] =', x[diffgB.abs().argmax()])
print ('lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()] =', lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()])

yields

1.8.1
diffA.abs().max() = tensor(8.4096e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)
diffB.abs().max() = tensor(8.4096e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)
diffB.abs().argmax() = tensor(61)
x[diffB.abs().argmax()] = tensor(-13.9000, dtype=torch.float64, grad_fn=<SelectBackward>)
diffgA.abs().max() = tensor(1.0000, dtype=torch.float64)
diffgB.abs().max() = tensor(0.1339, dtype=torch.float64)
diffgB.abs().argmax() = tensor(53)
x[diffgB.abs().argmax()] = tensor(-14.7000, dtype=torch.float64, grad_fn=<SelectBackward>)
lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()] = tensor(0.8661) tensor(1.0000, dtype=torch.float64)

After varying the breakpoint and re-running the code, I empirically found that a value of -8.0 performs best.
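
For reference, a sweep along these lines (just a rough sketch reusing the definitions above; the exact numbers will depend on the PyTorch version) is one way to compare candidate breakpoints:

import torch

def log_softplusA (x):
    return torch.nn.functional.softplus (x).log()

def log_softplusB (x, brk):
    return torch.where (x > brk, torch.nn.functional.softplus (x).log(), x)

# double-precision "ground truth" values and gradients
x = torch.arange (-200, 100, 1).double() / 10.0
x.requires_grad = True
lsp = log_softplusA (x)
lsp.sum().backward()
lspg = x.grad

# float32 comparison for a few candidate breakpoints
for brk in [-40.0, -15.0, -10.0, -8.0, -5.0]:
    xB = torch.arange (-200, 100, 1) / 10.0
    xB.requires_grad = True
    lspB = log_softplusB (xB, brk)
    lspB.sum().backward()
    err_out = (lspB - lsp).abs().max().item()
    err_grad = (xB.grad - lspg).abs().max().item()
    print ('brk =', brk, ' max output error =', err_out, ' max gradient error =', err_grad)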

Hi Steve!

This may be related to some weirdness in pytorch’s built-in
torch.nn.functional.softplus(). Here is the relevant thread:

Best.

K. Frank

Hi Steve!

I haven’t worked through this in detail, but a first look suggests that
the softplus() fix in the recent nightlies addresses your issue.

Running the script you posted on today’s nightly, 1.9.0.dev20210504,
seems to show that the issue is gone:

1.9.0.dev20210504
diffA.abs().max() = tensor(8.4096e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)
diffB.abs().max() = tensor(8.4096e-07, dtype=torch.float64, grad_fn=<MaxBackward1>)
diffB.abs().argmax() = tensor(61)
x[diffB.abs().argmax()] = tensor(-13.9000, dtype=torch.float64, grad_fn=<SelectBackward>)
diffgA.abs().max() = tensor(1.6522e-07, dtype=torch.float64)
diffgB.abs().max() = tensor(1.6522e-07, dtype=torch.float64)
diffgB.abs().argmax() = tensor(162)
x[diffgB.abs().argmax()] = tensor(-3.8000, dtype=torch.float64, grad_fn=<SelectBackward>)
lspgB[diffgB.abs().argmax()], lspg[diffgB.abs().argmax()] = tensor(0.9890) tensor(0.9890, dtype=torch.float64)

Best.

K. Frank
