Why is this simple linear regression with only two variables so hard to converge during gradient descent?

In short, I was working on some problems whose most degenerate forms can be linear. I was therefore able to reduce the non-converging cases to a very small linear regression problem that converges unreasonably slowly with gradient descent.

I was under the impression that while gradient descent is not the most efficient way to solve a linear problem, it should nonetheless converge quickly and be a practical approach (so that non-linearities can be seamlessly added later). Among other things, linear regression is considered a standard introductory problem for gradient descent, and many NNs are piecewise linear. Now, instead, I am starting to question the nature of my reality.

The problem is to minimize ||Ax-B||^2 (that is, to solve Ax=B) as follows.
The loss starts at 100 and is expected to go to 0. Instead it converges far too slowly for gradient descent to be a practical solver.

import torch as t

A = t.tensor([
    [-2.4969e+02, -4.1511e+00],
    [-4.1511e+00, -2.0755e-01]])

B = t.tensor([-0., 10.])

#trivially solvable by lstsq
x_solved = t.linalg.lstsq(A,B)
print(x_solved)
#solution=tensor([  1.2000, -72.1824])
print("check if Ax=B", A@x_solved.solution-B)

def forward(x_):
    return (A@x_-B).pow(2).sum()

#sanity check with the lstsq solution
print("loss computed with the lstsq solution",forward(x_solved.solution))

x = t.zeros(2,requires_grad=True)
#learning_rate = 1e-7 #converging to 99.20282745361328 at T=1000000
#learning_rate = 1e-6 #converging to 92.60104370117188 at T=1000000
learning_rate = 1e-5 #converging to 46.44608688354492 at T=1000000
#learning_rate = 1.603e-5 # converging to 29.044937133789062 at T=1000000
#learning_rate = 1.604e-5 # diverging
#learning_rate = 1.605e-5 # inf
#learning_rate = 1.61e-5 # NaN
for T in range(1000001):
    loss = forward(x)
    if T % 100 == 0:
        print(T, loss.item(),end='\r')
    loss.backward()
    with t.no_grad():
        x -= learning_rate * x.grad
        x.grad = None
print('converging to',loss.item(),f'at T={T} with lr={learning_rate}')

I have already gone to extra lengths to find a good learning rate - for normal “tuning” one would only try values such as 1e-5 or 2e-6, rather than pinning down multiple digits just below the point of divergence.
I have also tried unrolling the expression and ultimately computing the derivatives symbolically, which suggested that the pytorch grad is correct - and it would be hard to imagine pytorch still having a bug that manifests in such a simple case anyway. On the other hand, it really baffles me if gradient descent mathematically does have such a weakness. I have not been exhaustive, but none of the optimizers from torch.optim have worked for me either.
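
For reference, the closed-form gradient is d/dx ||Ax-B||^2 = 2 A^T (Ax - B); a minimal sketch of the check against autograd (reusing A, B, and t from the script above):

x0 = t.tensor([0.3, -1.0], requires_grad=True)  # arbitrary test point
loss = (A @ x0 - B).pow(2).sum()
loss.backward()
manual = 2 * A.T @ (A @ x0.detach() - B)        # closed-form gradient
print(t.allclose(x0.grad, manual))              # True: autograd matches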

Does anyone know what I have encountered?

My understanding is that the problem is that you are using a fixed learning rate, which gives a linear convergence rate. Using backtracking line search could improve things, but, as you noticed, vanilla gradient descent is not the best way to solve such a problem… A quasi-Newton method, such as L-BFGS, can do a much better job. I modified your code to try it, and it converges in fewer than 100 iterations:

import torch as t

A = t.tensor([
    [-2.4969e+02, -4.1511e+00],
    [-4.1511e+00, -2.0755e-01]])

B = t.tensor([-0., 10.])

#trivially solvable by lstsq
x_solved = t.linalg.lstsq(A,B)
print(x_solved)
#solution=tensor([  1.2000, -72.1824])
print("check if Ax=B", A@x_solved.solution-B)

def forward(x_):
    return (A@x_-B).pow(2).sum()

#sanity check with the lstsq solution
print("loss computed with the lstsq solution",forward(x_solved.solution))

x = t.nn.Parameter(t.zeros(2))

learning_rate = 1

optimizer = t.optim.LBFGS([x], lr=learning_rate)
def closure():
    # LBFGS may re-evaluate the loss (and gradient) several times per step,
    # so the closure itself must zero the gradients and run backward
    optimizer.zero_grad()
    loss = forward(x)
    loss.backward()
    return loss

for T in range(51):
    loss = optimizer.step(closure)  # one full LBFGS step (with inner iterations)
    print(T, loss.item(), end='\r')

print('converging to', loss.item(), f'at T={T} with lr={learning_rate}')

with t.no_grad():
    assert t.allclose(x,x_solved.solution, atol=1e-4)

Hi L!

Your experience is not surprising, given the following:

Your matrix A is modestly ill-conditioned, with a condition number of
about 2000. This doesn’t prevent plain-vanilla “stochastic” gradient
descent (SGD) from working, but does slow it down a lot.

Your matrix A corresponds to a loss surface that is a skinny bowl whose
contour lines are ellipses whose major axes lie approximately in the y
(your x[1]) direction (and whose minor axes lie, correspondingly, in the
x (your x[0]) direction). Think of this as a skinny valley running up and
down in approximately the y direction.

(These directions are given by the eigenvectors of A and are computed
by the script given below.)

As you move back and forth across this skinny valley in the x direction,
you are going up and down the steep sides of the valley, so the gradients
with respect to the x direction are large. These large gradients limit
your learning rate. But when you move in the y direction, you are, so
to speak, moving along the gently-sloping floor of the valley, so your
y gradients are small. For a learning rate limited by the x gradients,
the much smaller y gradients push you only very slowly towards the
value of y that minimizes your loss function.

You could explore this yourself – with some effort – by plotting the
contour lines of your loss function, together with the path taken by
your x-y point as your gradient-descent iteration progresses.
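
For instance, here is a minimal plotting sketch along those lines (reusing A and B from your script, and assuming matplotlib is available):

import torch as t
import matplotlib.pyplot as plt

A = t.tensor([
    [-2.4969e+02, -4.1511e+00],
    [-4.1511e+00, -2.0755e-01]])
B = t.tensor([-0., 10.])

# contour grid around the true minimizer (1.2, -72.18)
g0, g1 = t.linspace(-1, 3, 200), t.linspace(-100, 20, 200)
X0, X1 = t.meshgrid(g0, g1, indexing='ij')
Z = (A[0,0]*X0 + A[0,1]*X1 - B[0])**2 + (A[1,0]*X0 + A[1,1]*X1 - B[1])**2

# gradient-descent path, recorded every 100 steps
x = t.zeros(2, requires_grad=True)
path = [x.detach().clone()]
for T in range(100000):
    loss = (A @ x - B).pow(2).sum()
    loss.backward()
    with t.no_grad():
        x -= 1.6e-5 * x.grad
        x.grad = None
    if T % 100 == 0:
        path.append(x.detach().clone())
path = t.stack(path)

plt.contour(X0, X1, (Z + 1e-6).log(), levels=40)  # log-loss, so the
                                                  # elongated contours show up
plt.plot(path[:, 0], path[:, 1], 'r.-', markersize=2)
plt.plot(1.2, -72.1824, 'k*')                     # the lstsq solution
plt.xlabel('x[0]'); plt.ylabel('x[1]')
plt.show()
# Even after 100,000 steps the red path has crawled only a few percent of
# the way down the valley floor toward the starred minimizer.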

The larger-magnitude eigenvalue (the one whose eigenvector points
approximately in the x direction) is about 2000 times larger in magnitude
than the smaller, y-direction eigenvalue. Figuratively speaking, this
means that the steep walls of your valley are about 2000 times as steep
as the shallow walls.
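
To make this quantitative: for a quadratic loss ||Ax - B||^2 the Hessian is the constant matrix H = 2 A^T A, fixed-step gradient descent is stable only for lr < 2 / lambda_max(H), and the slow direction contracts by a factor of only (1 - lr * lambda_min(H)) per step. A short sketch (reusing A and t from your script) shows that this pins your observed divergence threshold at about 1.6e-5 and predicts the million-iteration crawl:

H = 2 * A.T @ A                # Hessian of ||Ax - B||^2 (constant for a quadratic)
lam = t.linalg.eigvalsh(H)     # eigenvalues in ascending order
print(lam)                     # roughly [3.84e-02, 1.25e+05]
print((2 / lam[1]).item())     # ~1.603e-5: largest stable fixed learning rate,
                               # matching the divergence you saw at lr = 1.604e-5
lr = 1.6e-5
print(1 - lr * lam[0].item())  # ~0.99999939: per-step contraction of the slow
                               # mode, i.e. ~1e6 steps just to halve the error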

If you try to speed up your y-direction convergence by increasing your
learning rate, you will first start jumping back and forth across the
valley (in the x direction), which slows down x-direction convergence
(while y-direction convergence remains slow), and if you increase the
learning rate further, you will jump back and forth across the valley
with bigger and bigger steps each time, diverging.

The following script does four things: It computes the eigenstructure
of A; it tracks the gradients and values as your loop runs; it runs
gradient descent (using pytorch’s SGD optimizer), but with separate
learning rates in the x and y directions and momentum thrown in;
and it runs the Adam optimizer, which in some sense “automatically”
accounts for the difference in scale between the x and y directions.

The two-separate-learning-rates approach provides the most satisfactory
result but, of course, requires analyzing the problem to see what is
going on and adjusting the two learning rates accordingly – something
that wouldn’t be practical in a large problem with lots of parameters.

Note that Adam, while it converges reasonably efficiently when run on
your example, can be unstable, and, indeed, shows some “glitches” in
the process of converging for your problem.

Here is the script:

import torch
print (torch.__version__)

import torch as t

A = t.tensor([
    [-2.4969e+02, -4.1511e+00],
    [-4.1511e+00, -2.0755e-01]])

B = t.tensor([-0., 10.])

#trivially solvable by lstsq
x_solved = t.linalg.lstsq(A,B)
print(x_solved)
#solution=tensor([  1.2000, -72.1824])
print("check if Ax=B", A@x_solved.solution-B)

def forward(x_):
    return (A@x_-B).pow(2).sum()

#sanity check with the lstsq solution
print("loss computed with the lstsq solution",forward(x_solved.solution))

x = t.zeros(2,requires_grad=True)
#learning_rate = 1e-7 #converging to 99.20282745361328 at T=1000000
#learning_rate = 1e-6 #converging to 92.60104370117188 at T=1000000
learning_rate = 1e-5 #converging to 46.44608688354492 at T=1000000
#learning_rate = 1.603e-5 # converging to 29.044937133789062 at T=1000000
#learning_rate = 1.604e-5 # diverging
#learning_rate = 1.605e-5 # inf
#learning_rate = 1.61e-5 # NaN
# for T in range(1000001):

print ('==========')
print ('original training loop (1001 iterations) ...')
for T in range(1001):
    loss = forward(x)
    if T % 100 == 0:
        print(T, loss.item(),end='\r')
    loss.backward()
    with t.no_grad():
        x -= learning_rate * x.grad
        x.grad = None
print('converging to',loss.item(),f'at T={T} with lr={learning_rate}')

print ('==========')
print ('condition number and eigendecomposition of matrix ...')
print ('A:')
print (A)
evals, evecs = torch.linalg.eigh (A)
print ('condition number: ', evals[0] / evals[1])
print ('eigenvectors:')
print (evecs)
print ('eigenvalues: ', evals)

with torch.no_grad():
    x.zero_()
    x.grad = None

print ('==========')
print ('look at gradients and values as original loop progresses (10 iterations) ...')
print ('learning_rate:   ', learning_rate)
for  _ in range (10):
    x.grad = None
    loss = forward (x)
    loss.backward()
    print (x.grad, x)
    with torch.no_grad():
        x -= learning_rate * x.grad
print (x)

x0 = torch.zeros (1, requires_grad = True)
x1 = torch.zeros (1, requires_grad = True)
lr0 = 1.e-5
lr1 = 0.05

opt = torch.optim.SGD ([{'params': (x0,), 'lr': lr0}, {'params': (x1,), 'lr': lr1}], momentum = 0.95)

print ('==========')
print ('SGD with two learning rates and momentum (1001 iterations) ...')
for  i in range (1001):
    opt.zero_grad()
    loss = forward (torch.cat ((x0, x1)))
    loss.backward()
    if  i % 100 == 0:  print (x0.grad, x0, x1.grad, x1)
    opt.step()
print (x0, x1)

opt = torch.optim.Adam ((x,), lr = 0.1)

with torch.no_grad():
    x.zero_()

print ('==========')
print ('Adam (50001 iterations) ...')
for  i in range (50001):
    opt.zero_grad()
    loss = forward (x)
    loss.backward()
    if  i % 1000 == 0:  print (x.grad, x)
    opt.step()
print (x)

And here is its output:

2.4.0
torch.return_types.linalg_lstsq(
solution=tensor([  1.2000, -72.1824]),
residuals=tensor([]),
rank=tensor(2),
singular_values=tensor([]))
check if Ax=B tensor([0.0000e+00, 9.5367e-07])
loss computed with the lstsq solution tensor(9.0949e-13)
==========
original training loop (1001 iterations) ...
converging to 99.89566802978516 at T=1000 with lr=1e-05
==========
condition number and eigendecomposition of matrix ...
A:
tensor([[-2.4969e+02, -4.1511e+00],
        [-4.1511e+00, -2.0755e-01]])
condition number:  tensor(1803.3197)
eigenvectors:
tensor([[-0.9999,  0.0166],
        [-0.0166, -0.9999]])
eigenvalues:  tensor([-2.4976e+02, -1.3850e-01])
==========
look at gradients and values as original loop progresses (10 iterations) ...
learning_rate:    1e-05
tensor([83.0220,  4.1510]) tensor([0., 0.], requires_grad=True)
tensor([-20.6130,   2.4271]) tensor([-8.3022e-04, -4.1510e-05], requires_grad=True)
tensor([5.0461, 2.8539]) tensor([-6.2409e-04, -6.5781e-05], requires_grad=True)
tensor([-1.3069,  2.7483]) tensor([-6.7455e-04, -9.4320e-05], requires_grad=True)
tensor([0.2661, 2.7744]) tensor([-0.0007, -0.0001], requires_grad=True)
tensor([-0.1233,  2.7679]) tensor([-0.0007, -0.0001], requires_grad=True)
tensor([-0.0269,  2.7695]) tensor([-0.0007, -0.0002], requires_grad=True)
tensor([-0.0508,  2.7691]) tensor([-0.0007, -0.0002], requires_grad=True)
tensor([-0.0449,  2.7692]) tensor([-0.0007, -0.0002], requires_grad=True)
tensor([-0.0464,  2.7692]) tensor([-0.0007, -0.0003], requires_grad=True)
tensor([-0.0007, -0.0003], requires_grad=True)
==========
SGD with two learning rates and momentum (1001 iterations) ...
tensor([83.0220]) tensor([0.], requires_grad=True) tensor([4.1510]) tensor([0.], requires_grad=True)
tensor([11.5928]) tensor([1.0018], requires_grad=True) tensor([0.6503]) tensor([-60.2610], requires_grad=True)
tensor([0.0211]) tensor([1.1988], requires_grad=True) tensor([0.0031]) tensor([-72.1111], requires_grad=True)
tensor([-0.0472]) tensor([1.2012], requires_grad=True) tensor([-0.0034]) tensor([-72.2500], requires_grad=True)
tensor([-0.0647]) tensor([1.2001], requires_grad=True) tensor([-0.0012]) tensor([-72.1857], requires_grad=True)
tensor([0.0461]) tensor([1.2000], requires_grad=True) tensor([0.0008]) tensor([-72.1821], requires_grad=True)
tensor([-7.9176e-06]) tensor([1.2000], requires_grad=True) tensor([-3.9587e-07]) tensor([-72.1824], requires_grad=True)
tensor([-3.9588e-05]) tensor([1.2000], requires_grad=True) tensor([-1.9794e-06]) tensor([-72.1824], requires_grad=True)
tensor([-3.9588e-05]) tensor([1.2000], requires_grad=True) tensor([-1.9794e-06]) tensor([-72.1824], requires_grad=True)
tensor([-3.9588e-05]) tensor([1.2000], requires_grad=True) tensor([-1.9794e-06]) tensor([-72.1824], requires_grad=True)
tensor([-3.9588e-05]) tensor([1.2000], requires_grad=True) tensor([-1.9794e-06]) tensor([-72.1824], requires_grad=True)
tensor([1.2000], requires_grad=True) tensor([-72.1824], requires_grad=True)
==========
Adam (50001 iterations) ...
tensor([83.0220,  4.1510]) tensor([0., 0.], requires_grad=True)
tensor([-2.1413,  2.2022]) tensor([  0.2300, -13.8671], requires_grad=True)
tensor([-58.6907,   0.3830]) tensor([  0.6104, -36.7625], requires_grad=True)
tensor([-1.2870,  0.7021]) tensor([  0.8864, -53.3280], requires_grad=True)
tensor([-0.0639,  0.3772]) tensor([  1.0360, -62.3241], requires_grad=True)
tensor([-0.2728,  0.1918]) tensor([  1.1149, -67.0651], requires_grad=True)
tensor([-4.4379,  0.0278]) tensor([  1.1559, -69.5333], requires_grad=True)
tensor([-0.4385,  0.0452]) tensor([  1.1773, -70.8142], requires_grad=True)
tensor([-153.7787,   -2.5309]) tensor([  1.1871, -71.4768], requires_grad=True)
tensor([-100.6192,   -1.6598]) tensor([  1.1932, -71.8185], requires_grad=True)
tensor([0.7502, 0.0197]) tensor([  1.1969, -71.9940], requires_grad=True)
tensor([0.0815, 0.0051]) tensor([  1.1984, -72.0850], requires_grad=True)
tensor([-59.0743,  -0.9807]) tensor([  1.1987, -72.1325], requires_grad=True)
tensor([9.1612, 0.1534]) tensor([  1.1997, -72.1564], requires_grad=True)
tensor([-9.0396, -0.1499]) tensor([  1.1997, -72.1691], requires_grad=True)
tensor([736.3341,  12.2488]) tensor([  1.2057, -72.1697], requires_grad=True)
tensor([-0.1483, -0.0023]) tensor([  1.2000, -72.1788], requires_grad=True)
tensor([-2.1779, -0.0362]) tensor([  1.2000, -72.1806], requires_grad=True)
tensor([0.0472, 0.0008]) tensor([  1.2000, -72.1811], requires_grad=True)
tensor([-0.4259, -0.0071]) tensor([  1.2000, -72.1816], requires_grad=True)
tensor([-13.4144,  -0.2231]) tensor([  1.1999, -72.1820], requires_grad=True)
tensor([-0.3807, -0.0063]) tensor([  1.2000, -72.1820], requires_grad=True)
tensor([0.0460, 0.0008]) tensor([  1.2000, -72.1822], requires_grad=True)
tensor([2.3753e-04, 1.1876e-05]) tensor([  1.2000, -72.1822], requires_grad=True)
tensor([-0.0608, -0.0010]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([-0.0151, -0.0002]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.4252e-04, 7.1257e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.5835e-04, 7.9174e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.5835e-04, 7.9174e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.5835e-04, 7.9174e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([7.9176e-05, 3.9587e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([7.9176e-05, 3.9587e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([7.9176e-05, 3.9587e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([27.5159,  0.4577]) tensor([  1.2002, -72.1821], requires_grad=True)
tensor([8.7094e-05, 4.3546e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([8.7094e-05, 4.3546e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([8.7094e-05, 4.3546e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([8.7094e-05, 4.3546e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([-0.0151, -0.0002]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.2668e-04, 6.3339e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.2668e-04, 6.3339e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([1.2668e-04, 6.3339e-06]) tensor([  1.2000, -72.1823], requires_grad=True)
tensor([0.0153, 0.0003]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([4.7506e-05, 2.3752e-06]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([-145.9024,   -2.4270]) tensor([  1.1989, -72.1835], requires_grad=True)
tensor([-25.7169,  -0.4278]) tensor([  1.1998, -72.1826], requires_grad=True)
tensor([0., 0.]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([0., 0.]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([0., 0.]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([0., 0.]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([0., 0.]) tensor([  1.2000, -72.1824], requires_grad=True)
tensor([  1.2000, -72.1824], requires_grad=True)

If in your actual use you can reduce your troublesome cases to linear
regression and you have the actual A matrix, just use regular linear
algebra (i.e., lstsq()) to perform the regression. If you somehow
have to use gradient descent, but know that some matrix M is a good
approximation to the “regression” matrix A that lurks inside of your
actual regression problem, you could “precondition” the actual problem
by using M to perform a change of variables that transforms your
ill-conditioned, skinny, elliptical valley into a well-conditioned,
more nearly circular one, and gradient descent will work just fine
with a comfortably large learning rate.
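
As a minimal sketch of that change of variables – using A itself as the
preconditioner here, since this toy problem exposes it directly – substitute
x = M^-1 z, so the loss becomes ||A M^-1 z - B||^2, which is essentially
||z - B||^2 (and hence perfectly conditioned) when M is close to A:

import torch as t

A = t.tensor([
    [-2.4969e+02, -4.1511e+00],
    [-4.1511e+00, -2.0755e-01]])
B = t.tensor([-0., 10.])

Minv = t.linalg.inv(A)            # preconditioner (here M = A exactly)
z = t.zeros(2, requires_grad=True)

def forward_z(z_):
    return (A @ (Minv @ z_) - B).pow(2).sum()   # effectively ||z - B||^2

lr = 0.25                         # Hessian is ~2*I, so anything below 1.0 is stable
for T in range(50):
    loss = forward_z(z)
    loss.backward()
    with t.no_grad():
        z -= lr * z.grad
        z.grad = None
print('loss:', loss.item())       # ~0 after a few dozen steps
print('x:', (Minv @ z).detach())  # ~[1.2000, -72.1824], mapped back to x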

Best.

K. Frank