Hello Forum!
It appears that backward() for softplus() has a numerical issue for large negative arguments (one that does not seem to be caused directly by any corresponding forward() issue).
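For example, at x = -50.0 the exact derivative, sigmoid (-50.0) = 1 / (1 + exp (50.0)) ≈ 1.93e-22, is comfortably representable in single precision, yet, as shown in the output below, the built-in backward() returns exactly zero there.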
A straightforward (but not completely naive) implementation of softplus() – one with no explicit backward() – seems to work fine with autograd.
Here is an illustrative test script:
import torch
print (torch.__version__)
def test_softplus (x):   # use log-sum-exp trick and log1p
    return torch.where (x > 0, x + torch.exp (-x).log1p(), torch.exp (x).log1p())
def log1p_softplus (x):   # log1p version works well enough
    return torch.exp (x).log1p()
xpt = torch.arange (-500, 101, dtype = torch.float) / 10
relerr = ((torch.nn.functional.softplus (xpt) - test_softplus (xpt)) / test_softplus (xpt)).abs().max()
# forward looks okay
print ('forward looks okay: relerr =', relerr)
# look at gradients
sig = torch.sigmoid (xpt) # sigmoid is the derivative of softplus
xtt = torch.arange (-500, 101, dtype = torch.float) / 10
xpt.requires_grad = True
torch.nn.functional.softplus (xpt).sum().backward()
xtt.requires_grad = True
test_softplus (xtt).sum().backward()
relerrpt = ((xpt.grad - sig) / sig).abs().max()
relerrtt = ((xtt.grad - sig) / sig).abs().max()
# built-in softplus-backward looks numerically unstable
print ('built-in backward looks unstable: relerrpt =', relerrpt)
print ('backward for hand-rolled looks okay: relerrtt =', relerrtt)
print ('sig[:5] =', sig[:5])
print ('xpt.grad[:5] =', xpt.grad[:5])
print ('xtt.grad[:5] =', xtt.grad[:5])
And here is its output:
1.7.1
forward looks okay: relerr = tensor(1.0752e-07)
built-in backward looks unstable: relerrpt = tensor(1.)
backward for hand-rolled looks okay: relerrtt = tensor(1.8833e-07)
sig[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])
xpt.grad[:5] = tensor([0., 0., 0., 0., 0.])
xtt.grad[:5] = tensor([1.9287e-22, 2.1316e-22, 2.3558e-22, 2.6035e-22, 2.8774e-22])
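For what it's worth, here is a possible workaround sketch (not an official fix, and the name StableSoftplus is just made up for illustration): wrap the hand-rolled forward in a custom autograd.Function with an explicit backward that uses sigmoid(), the analytic derivative of softplus():
import torch
class StableSoftplus (torch.autograd.Function):
    @staticmethod
    def forward (ctx, x):
        ctx.save_for_backward (x)   # keep x for the backward pass
        return torch.where (x > 0, x + torch.exp (-x).log1p(), torch.exp (x).log1p())
    @staticmethod
    def backward (ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * torch.sigmoid (x)   # sigmoid is the derivative of softplus
xw = torch.arange (-500, 101, dtype = torch.float) / 10
xw.requires_grad = True
StableSoftplus.apply (xw).sum().backward()
print ('xw.grad[:5] =', xw.grad[:5])   # matches sigmoid (xw)[:5] – no underflow to zero
Letting autograd differentiate the hand-rolled version (as in the script above) works just as well; the explicit backward merely avoids the extra where() / log1p() nodes in the backward graph.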
A possibly-related issue is discussed in this thread:
Best.
K. Frank