Imperfect gradient for torch.sigmoid()

Hi All!

The gradient computation for torch.sigmoid() is not as good as it
could or should be.

Specifically, it underflows to zero sooner than it should.

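This can be seen by direct inspection of the gradient values. Also,
the gradient should be symmetric around zero (mathematically,
sigmoid' (x) = sigmoid (x) * sigmoid (-x), which is unchanged when x
is replaced by -x), and this condition is violated. Notably, the
gradient underflows early for positive values of x.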
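Here is a likely explanation. It assumes, without my having traced it in the source, that torch.sigmoid()'s backward pass is computed from the forward output y as y * (1 - y). Once y rounds to exactly 1.0 in float32 (beyond roughly x = 17), 1 - y is exactly zero. The algebraically equivalent form sigmoid (x) * sigmoid (-x) avoids that subtraction and is symmetric by construction. This sketch just evaluates the two forms by hand. It does not call torch.sigmoid()'s own backward.

```python
import torch

# Points on both sides of zero, in float32 (the default dtype)
x = torch.tensor([-20.0, -10.0, 10.0, 20.0])
y = torch.sigmoid(x)

# Form A (assumed to be what the backward pass does): y * (1 - y).
# For x = 20, y rounds to exactly 1.0, so 1 - y is exactly 0.
grad_from_output = y * (1.0 - y)

# Form B: sigmoid(x) * sigmoid(-x), algebraically identical to form A,
# but with no subtraction from 1.0, so nothing cancels to zero.
grad_symmetric = torch.sigmoid(x) * torch.sigmoid(-x)

print("y * (1 - y):            ", grad_from_output)
print("sigmoid(x) * sigmoid(-x):", grad_symmetric)
```

If the assumption is right, the same cancellation in 1 - y would also account for the small asymmetry at x = +/-10 in the output below.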

Also, the standard formula, sigmoid (x) = 1 / (1 + exp (-x)),
performs better, although still imperfectly, in that its gradient doesn’t
underflow as soon as that of torch.sigmoid(). (With the standard
formula, the gradient underflows early for negative values of x.)
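The early underflow of the standard formula at negative x can be reproduced by hand. Assume autograd differentiates 1 / u by way of the squared output s = 1 / u. That is an assumption about its internals, not something verified here. The chain rule then gives a gradient of exp (-x) * s**2. In float32, s**2 becomes subnormal once s falls below about 1e-19 and rounds to zero once s falls below about 3e-23. That would account both for the imprecise value at x = -50 and for the zeros from x = -60 on down.

```python
import torch

# Standard formula in float32 at x = -60 (gradient underflows in the
# output below) and x = -50 (gradient is nonzero but imprecise)
x32 = torch.tensor([-60.0, -50.0])
s32 = 1.0 / (1.0 + torch.exp(-x32))

# Assumed chain-rule form: d/dx [1 / (1 + exp(-x))] = exp(-x) * s**2
manual32 = torch.exp(-x32) * (s32 * s32)

# Same computation in float64, where s**2 is still an ordinary number
x64 = x32.double()
s64 = 1.0 / (1.0 + torch.exp(-x64))
manual64 = torch.exp(-x64) * (s64 * s64)

print("s**2 (float32):       ", s32 * s32)
print("manual grad (float32):", manual32)
print("manual grad (float64):", manual64)
```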

This is illustrated by the following script and its output:

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

# B: the standard formula, sigmoid (x) = 1 / (1 + exp (-x))
def sigmoidB (x):
    return 1.0 / (1.0 + torch.exp (-x))

# test points symmetric around zero: -80, -70, ..., 80
t = torch.arange (-80, 85, 10).float()
t.requires_grad = True

# A: gradient of the built-in torch.sigmoid()
sigA = torch.sigmoid (t)
sigA.sum().backward()
grdA = t.grad

# reset the gradient, then B: gradient of the standard formula
t.grad = None
sigB = sigmoidB (t)
sigB.sum().backward()
grdB = t.grad

print ('t:', t)
print ('sigA:', sigA)
print ('sigB:', sigB)
print ('grdA:', grdA)
print ('grdB:', grdB)
# relative asymmetry: zero wherever grad (x) == grad (-x)
print ('relative asymmetry in gradient:')
print ('(grdA - grdA.flip (0)) / torch.max (grdA, grdA.flip (0)):', (grdA - grdA.flip (0)) / torch.max (grdA, grdA.flip (0)))
print ('(grdB - grdB.flip (0)) / torch.max (grdB, grdB.flip (0)):', (grdB - grdB.flip (0)) / torch.max (grdB, grdB.flip (0)))
```
```
1.10.0
t: tensor([-80., -70., -60., -50., -40., -30., -20., -10.,   0.,  10.,  20.,  30.,
         40.,  50.,  60.,  70.,  80.], requires_grad=True)
sigA: tensor([1.8049e-35, 3.9754e-31, 8.7565e-27, 1.9287e-22, 4.2484e-18, 9.3576e-14,
        2.0612e-09, 4.5398e-05, 5.0000e-01, 9.9995e-01, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
       grad_fn=<SigmoidBackward0>)
sigB: tensor([1.8049e-35, 3.9754e-31, 8.7565e-27, 1.9287e-22, 4.2484e-18, 9.3576e-14,
        2.0612e-09, 4.5398e-05, 5.0000e-01, 9.9995e-01, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
       grad_fn=<MulBackward0>)
grdA: tensor([1.8049e-35, 3.9754e-31, 8.7565e-27, 1.9287e-22, 4.2484e-18, 9.3576e-14,
        2.0612e-09, 4.5396e-05, 2.5000e-01, 4.5417e-05, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00])
grdB: tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9616e-22, 4.2484e-18, 9.3576e-14,
        2.0612e-09, 4.5396e-05, 2.5000e-01, 4.5396e-05, 2.0612e-09, 9.3576e-14,
        4.2484e-18, 1.9287e-22, 8.7565e-27, 3.9754e-31, 1.8049e-35])
relative asymmetry in gradient:
(grdA - grdA.flip (0)) / torch.max (grdA, grdA.flip (0)): tensor([ 1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,
         1.0000e+00,  1.0000e+00, -4.5947e-04,  0.0000e+00,  4.5947e-04,
        -1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00,
        -1.0000e+00, -1.0000e+00])
(grdB - grdB.flip (0)) / torch.max (grdB, grdB.flip (0)): tensor([-1.0000e+00, -1.0000e+00, -1.0000e+00,  1.6765e-02,  0.0000e+00,
         7.2414e-08,  0.0000e+00,  1.6028e-07,  0.0000e+00, -1.6028e-07,
         0.0000e+00, -7.2414e-08,  0.0000e+00, -1.6765e-02,  1.0000e+00,
         1.0000e+00,  1.0000e+00])
```

(The current nightly build, version 1.11.0.dev20220123, yields the
same result, as does performing the computation on the GPU.
Performing the computation in double precision yields an equivalent
result.)
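Under the same assumption about y * (1 - y), double precision only moves the problem further out. float64 rounds sigmoid (x) to exactly 1.0 only once exp (-x) drops below half the float64 machine epsilon. The gradient therefore still underflows early, just at a larger x. This is a quick check of where each dtype first rounds sigmoid (x) to 1.0:

```python
import torch

# Smallest integer x >= 0 at which torch.sigmoid(x) rounds to exactly 1.0,
# i.e., where a gradient computed as y * (1 - y) would become exactly zero
thresholds = {}
for dtype in (torch.float32, torch.float64):
    x = torch.arange(0, 60, dtype=dtype)
    y = torch.sigmoid(x)
    thresholds[dtype] = x[y == 1.0][0].item()

print(thresholds)
```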

See Yaroslav’s “Custom Sigmoid” thread for some of the motivation
behind this post.
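For comparison, here is a minimal sketch of one way to get a symmetric gradient. It is not necessarily what that thread proposes. It is a custom torch.autograd.Function whose backward pass uses sigmoid (x) * sigmoid (-x) instead of anything of the form 1 - y. SymmetricSigmoid is a hypothetical name of my own, not an existing PyTorch function, and the approach costs two extra sigmoid evaluations in the backward pass. On the same t as in the script above, its gradient should be nonzero over the whole range and symmetric to within rounding.

```python
import torch


class SymmetricSigmoid(torch.autograd.Function):
    """Hypothetical sigmoid whose gradient is sigmoid(x) * sigmoid(-x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) = sigmoid(x) * sigmoid(-x);
        # the second form never computes 1 - y, so it avoids the cancellation
        # that turns 1 - y into exactly zero once y rounds to 1.0
        return grad_output * (torch.sigmoid(x) * torch.sigmoid(-x))


# same test points as the script in the post
t = torch.arange(-80, 85, 10).float()
t.requires_grad = True
SymmetricSigmoid.apply(t).sum().backward()
grdC = t.grad

print("grdC:", grdC)
print("relative asymmetry:", (grdC - grdC.flip(0)) / torch.max(grdC, grdC.flip(0)))
```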

Best.

K. Frank


CC @albanD for visibility.