Gradients for torch.nn.functional.softplus() are numerically unstable

Hello Forum!

It appears that backward() for softplus() has a numerical issue for
large negative arguments (one that does not appear to be caused
directly by a related forward() issue).

A straightforward (but not completely naive) implementation of
softplus() – one with no explicit backward() – seems to work
fine with autograd.
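
For reference, here is the identity the test implementation relies on – and
the fact that sigmoid() is the derivative of softplus() – checked in double
precision. (This is just an illustrative sketch; the tensors and values here
are my own and are not part of the test script below.)

import torch

x = torch.tensor ([-50.0, -10.0, 0.0, 10.0], dtype = torch.double)

# stable "log-sum-exp" form:  softplus (x) = max (x, 0) + log1p (exp (-|x|))
stable = torch.where (x > 0, x + torch.exp (-x).log1p(), torch.exp (x).log1p())

# naive form loses precision for large negative x, because 1 + exp (x) rounds to 1
naive = (1 + torch.exp (x)).log()

print (stable)
print (naive)
print (torch.sigmoid (x))   # d/dx softplus (x) = exp (x) / (1 + exp (x)) = sigmoid (x)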

Here is an illustrative test script:

import torch
print (torch.__version__)

def test_softplus (x):   # use log-sum-exp trick and log1p
    return torch.where (x > 0, x + torch.exp (-x).log1p(), torch.exp (x).log1p())

def log1p_softplus (x):  # log1p version works well enough
    return torch.exp (x).log1p()

xpt = torch.arange (-500, 101, dtype = torch.float) / 10

relerr = ((torch.nn.functional.softplus (xpt) - test_softplus (xpt)) / test_softplus (xpt)).abs().max()

# forward looks okay
print ('forward looks okay:                   relerr =', relerr)

# look at gradients

sig = torch.sigmoid (xpt)   # sigmoid is the derivative of softplus

xtt = torch.arange (-500, 101, dtype = torch.float) / 10

xpt.requires_grad = True
torch.nn.functional.softplus (xpt).sum().backward()
xtt.requires_grad = True
test_softplus (xtt).sum().backward()

relerrpt = ((xpt.grad - sig) / sig).abs().max()
relerrtt = ((xtt.grad - sig) / sig).abs().max()

# built-in softplus-backward looks numerically unstable
print ('built-in backward looks unstable:     relerrpt =', relerrpt)
print ('backward for hand-rolled looks okay:  relerrtt =', relerrtt)

print ('sig[:5] =', sig[:5])
print ('xpt.grad[:5] =', xpt.grad[:5])
print ('xtt.grad[:5] =', xtt.grad[:5])

And here is its output:

1.7.1
forward looks okay:                   relerr = tensor(1.0752e-07)
built-in backward looks unstable:     relerrpt = tensor(1.)
backward for hand-rolled looks okay:  relerrtt = tensor(1.8833e-07)
sig[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])
xpt.grad[:5] = tensor([0., 0., 0., 0., 0.])
xtt.grad[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])

A possibly-related issue is discussed in this thread:

Best.

K. Frank


Great analysis @KFrank!

@ptrblck and I were chatting about this, and it appears to be solved in the nightlies. @ptrblck found the probable cause: the fix “Compare input against beta * threshold in softplus backwards” by jbschlosser (Pull Request #56484 · pytorch/pytorch · GitHub), which addresses “`Softplus` forward and backward discrepancy” (Issue #55587 · pytorch/pytorch · GitHub).
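
Here is a rough sketch of how the gradient can collapse to zero in single
precision when the derivative is reconstructed from the saved softplus
output rather than from the input. (This is my reading of the linked issue;
the exact internals may differ.)

import torch

x = torch.tensor (-50.0)
y = torch.nn.functional.softplus (x)   # ~1.93e-22, still representable in float32

# reconstructing the derivative from the saved output y as (z - 1) / z with
# z = exp (y):  exp (1.93e-22) rounds to exactly 1.0 in float32, so the
# gradient collapses to 0
z = torch.exp (y)
print ((z - 1) / z)        # tensor(0.)

# computing it from the input instead, sigmoid (x) stays accurate
print (torch.sigmoid (x))  # tensor(1.9287e-22)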

Best regards

Thomas


Hi Thomas!

I can confirm that this is fixed in today’s nightly, 1.9.0.dev20210504
(but not yet in my April nightly, 1.9.0.dev20210416):

1.9.0.dev20210504
forward looks okay:                   relerr = tensor(1.0752e-07)
built-in backward looks unstable:     relerrpt = tensor(1.5066e-07)
backward for hand-rolled looks okay:  relerrtt = tensor(1.8833e-07)
sig[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])
xpt.grad[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])
xtt.grad[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])

(The “looks unstable” is just hard-wired text; the result is correct.)

Thanks.

K. Frank
