Solve least squares with neural network

Hi,

I have the following toy example code:

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # a single bias-free linear layer: a learnable 128x128 matrix
            self.fc1 = nn.Linear(128, 128, bias=False)

        def forward(self, x):
            x = self.fc1(x)
            return x


    x = torch.randn((128, 128))
    y = torch.randn((128, 128))

    U = torch.linalg.lstsq(x, y).solution
    pred = x @ U
    print(f'norm: {torch.norm(pred - y)}')

    net = Net()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    for i in range(1000):
        pred = net(x)
        loss = torch.norm(pred - y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f'loss: {loss.item()}')

I initialize two random matrices x and y and try to find a transformation matrix U such that x @ U = y.
The least-squares solution reaches a low norm (loss). Then I try to solve the same problem by learning U, but I can't get it to converge to a low loss, and the learned U performs poorly. Is there a way to learn a better estimate of U with a learning algorithm? I also tried defining the forward pass of Net as follows:

    def forward(self, x):
        x = x @ self.fc1.weight
        return x

But it doesn’t seem to help.
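A likely reason the alternative forward pass changes nothing: `nn.Linear(128, 128, bias=False)` computes `x @ weight.T`, so it can already represent any 128x128 matrix U (namely `weight.T`). Writing `x @ self.fc1.weight` just learns U as `weight` instead of `weight.T`, which is essentially the same optimization problem with the parameters transposed. A quick check of that equivalence (the batch size of 4 is an arbitrary choice for illustration):

```python
import torch

torch.manual_seed(0)
fc = torch.nn.Linear(128, 128, bias=False)
x = torch.randn(4, 128)

out_linear = fc(x)            # what nn.Linear computes
out_manual = x @ fc.weight.T  # the same product written out explicitly

# So in the first version of Net the effective U is fc1.weight.T
print(torch.allclose(out_linear, out_manual, atol=1e-6))
```

This is also why comparing a trained `Net` against the least-squares solution should use `net.fc1.weight.T`.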

Hi Adir!

Your x and y vectors are sampled independently. Therefore, on a
statistical basis, they have nothing to do with one another, so you can’t
really regress one against the other.

For your specific values of x and y that you have sampled, some of the
statistical variations they have will happen – just by chance – to line up,
so you will be able to regress the concrete values of y against x, but it
will be hard to do (and if you did so again with a new set of samples for
x and y, you would get entirely different values for the regression
coefficients).
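A small sketch of that last point: fit the regression on two fresh draws and see how much the coefficients move. Assumptions not taken from the thread: 512 samples instead of 128 (so the dependent case stays well-conditioned), a hypothetical noise level of 0.01, and the illustrative helper names `fit`, `rel_diff`, and `draw_dependent`.

```python
import torch

torch.manual_seed(0)
n_samples, n_features = 512, 128  # assumption: more rows than columns

def fit(a, b):
    # least-squares solution W of a @ W ≈ b
    return torch.linalg.lstsq(a, b).solution

def rel_diff(p, q):
    # relative Frobenius-norm difference between two coefficient matrices
    return (torch.linalg.norm(p - q) / torch.linalg.norm(p)).item()

# Case 1: x and y sampled independently (no real relationship)
U1 = fit(torch.randn(n_samples, n_features), torch.randn(n_samples, n_features))
U2 = fit(torch.randn(n_samples, n_features), torch.randn(n_samples, n_features))
indep = rel_diff(U1, U2)

# Case 2: t really depends on s through a fixed V0, plus a little noise
V0 = torch.randn(n_features, n_features)

def draw_dependent():
    s = torch.randn(n_samples, n_features)
    t = s @ V0 + 0.01 * torch.randn(n_samples, n_features)
    return s, t

V1 = fit(*draw_dependent())
V2 = fit(*draw_dependent())
dep = rel_diff(V1, V2)

print(f'independent: relative change between draws = {indep:.3f}')
print(f'dependent:   relative change between draws = {dep:.3e}')
```

With independent samples the two fitted matrices are essentially unrelated (relative change near sqrt(2)), while with a genuine dependence both fits land close to `V0`.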

A consequence of this is that your specific concrete regression problem
is ill-conditioned, as can be seen from the large condition number of
your matrix U.
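For reference, `torch.linalg.cond` with its default `p=None` is the 2-norm condition number, i.e. the ratio of the largest to the smallest singular value. A minimal check on a problem shaped like the original one (float64 is my choice here to limit rounding noise; the numbers will not match the float32 run in the thread):

```python
import torch

torch.manual_seed(2023)
x = torch.randn(128, 128, dtype=torch.float64)
y = torch.randn(128, 128, dtype=torch.float64)
U = torch.linalg.lstsq(x, y).solution

sv = torch.linalg.svdvals(U)        # singular values, sorted in descending order
ratio = (sv[0] / sv[-1]).item()     # sigma_max / sigma_min
cond = torch.linalg.cond(U).item()  # default p=None: 2-norm condition number
print(f'sigma_max / sigma_min = {ratio:.4g}, cond(U) = {cond:.4g}')
```

A large value means some directions are stretched far more than others, so small changes in the data can move the solution a lot along the weakly determined directions.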

I’ve tweaked and expanded the code you posted to illustrate what is
going on. First, you can train, but doing so is difficult and slow.
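One standard way to see why it is slow (an addition for illustration: it uses plain full-batch gradient descent on the squared loss, not the Adam optimizer and unsquared norm used in the thread). The Hessian of 0.5 * ||x w - y||^2 is x.T @ x, whose condition number is cond(x)^2. Gradient descent with step 1 / lambda_max shrinks the error along eigendirection i by a factor (1 - lambda_i / lambda_max) per step, so directions with small singular values barely move. The sketch builds hypothetical 64x8 design matrices with known singular values:

```python
import torch

torch.manual_seed(0)
# Hypothetical designs: orthonormal columns Q scaled by singular values d,
# so cond(x) = max(d) / min(d) and x.T @ x = diag(d**2)
Q, _ = torch.linalg.qr(torch.randn(64, 8, dtype=torch.float64))
w_true = torch.ones(8, dtype=torch.float64)

def gd_error(d, steps=100):
    x = Q * d                    # scales column j of Q by d[j]
    y = x @ w_true
    lr = 1.0 / (d.max() ** 2)    # 1 / largest eigenvalue of x.T @ x
    w = torch.zeros(8, dtype=torch.float64)
    for _ in range(steps):
        grad = x.T @ (x @ w - y)  # gradient of 0.5 * ||x w - y||^2
        w = w - lr * grad
    return torch.linalg.cond(x).item(), torch.linalg.norm(w - w_true).item()

well = gd_error(torch.ones(8, dtype=torch.float64))
ill = gd_error(torch.logspace(0, -2, 8, dtype=torch.float64))
print(f'well-conditioned: cond = {well[0]:.1f}, error after 100 steps = {well[1]:.2e}')
print(f'ill-conditioned:  cond = {ill[0]:.1f}, error after 100 steps = {ill[1]:.2e}')
```

With cond(x) = 1 the error vanishes after a single step, while with cond(x) = 100 the slowest direction has only shrunk by about 1% after 100 steps. Adaptive optimizers such as Adam help somewhat but do not remove this effect.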

I then construct a rather similar regression problem where the tensor
t does depend on the tensor s, so the regression problem makes
sense and is not ill-conditioned. Your Net model trains just fine on
this regression problem and produces an accurate result.

Here is the example script:

    import torch
    print (torch.__version__)

    _ = torch.manual_seed (2023)

    class Net (torch.nn.Module):
        def __init__ (self):
            super (Net, self).__init__()
            self.fc1 = torch.nn.Linear (128, 128, bias = False)

        def forward (self, x):
            x = self.fc1 (x)
            return x

    print ('original ill-conditioned regression')
    x = torch.randn ((128, 128))
    y = torch.randn ((128, 128))

    U = torch.linalg.lstsq (x, y).solution
    pred = x @ U
    print (f'norm: {torch.norm (pred - y)}')
    print (f'cond (U): {torch.linalg.cond (U)}')   # large condition number -- problem is ill-conditioned

    genericV = torch.randn (128, 128)
    print (f'cond (genericV): {torch.linalg.cond (genericV)}')   # "generic" value of condition number

    net = Net()
    # optimizer = torch.optim.Adam (net.parameters(), lr = 1e-3)
    optimizer = torch.optim.Adam (net.parameters(), lr = 2e-3)
    for i in range (1000001):
        pred = net (x)
        loss = torch.norm (pred - y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if  i % 100000 == 0:
            diff = (net.fc1.weight.T - U).abs().mean()
            print ('i: %7d    loss: %8.3f    cond (weight): %9.2f    diff: %9.3e' % (i, loss.item(), torch.linalg.cond (net.fc1.weight), diff))

    # create non-ill-conditioned regression
    print ('non-ill-conditioned regression -- because of noise, loss will not go to zero')
    s = torch.randn (128, 128)
    t = s @ genericV
    # add some noise
    s = s * (1.0 + 0.001 * torch.randn (128, 128))
    t = t * (1.0 + 0.001 * torch.randn (128, 128))

    V = torch.linalg.lstsq (s, t).solution
    pred = s @ V
    print (f'norm: {torch.norm (pred - t)}')       # norm (loss) is not zero
    print (f'cond (V): {torch.linalg.cond (V)}')   # problem is not ill-conditioned

    net = Net()
    optimizer = torch.optim.Adam (net.parameters(), lr = 2e-4)
    for i in range (200001):
        pred = net (s)
        loss = torch.norm (pred - t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if  i % 20000 == 0:
            diff = (net.fc1.weight.T - V).abs().mean()
            print ('i: %7d    loss: %8.3f    cond (weight): %9.2f    diff: %9.3e' % (i, loss.item(), torch.linalg.cond (net.fc1.weight), diff))

And here is its output:

    1.13.1
    original ill-conditioned regression
    norm: 0.00715411314740777
    cond (U): 434530.90625
    cond (genericV): 328.3016052246094
    i:       0    loss:  147.919    cond (weight):    648.61    diff: 9.353e+00
    i:  100000    loss:    7.867    cond (weight):  99704.01    diff: 7.269e+00
    i:  200000    loss:    6.126    cond (weight): 170475.48    diff: 5.660e+00
    i:  300000    loss:    4.770    cond (weight): 228174.14    diff: 4.407e+00
    i:  400000    loss:    3.715    cond (weight): 268940.16    diff: 3.431e+00
    i:  500000    loss:    2.893    cond (weight): 310605.19    diff: 2.671e+00
    i:  600000    loss:    2.253    cond (weight): 335553.88    diff: 2.079e+00
    i:  700000    loss:    1.757    cond (weight): 360560.00    diff: 1.618e+00
    i:  800000    loss:    1.371    cond (weight): 388058.84    diff: 1.260e+00
    i:  900000    loss:    1.070    cond (weight): 390163.25    diff: 9.805e-01
    i: 1000000    loss:    0.839    cond (weight): 402945.47    diff: 7.631e-01
    non-ill-conditioned regression -- because of noise, loss will not go to zero
    norm: 0.0005993304657749832
    cond (V): 345.5065002441406
    i:       0    loss: 1472.513    cond (weight):    423.85    diff: 8.112e-01
    i:   20000    loss:    0.515    cond (weight):   2645.33    diff: 1.038e-01
    i:   40000    loss:    0.172    cond (weight):    388.42    diff: 3.485e-02
    i:   60000    loss:    0.059    cond (weight):    333.62    diff: 1.167e-02
    i:   80000    loss:    0.024    cond (weight):    341.21    diff: 3.906e-03
    i:  100000    loss:    0.016    cond (weight):    344.06    diff: 1.308e-03
    i:  120000    loss:    0.014    cond (weight):    345.06    diff: 4.388e-04
    i:  140000    loss:    0.015    cond (weight):    345.36    diff: 1.480e-04
    i:  160000    loss:    0.015    cond (weight):    345.39    diff: 5.397e-05
    i:  180000    loss:    0.014    cond (weight):    345.44    diff: 2.335e-05
    i:  200000    loss:    0.015    cond (weight):    345.55    diff: 1.374e-05

Best.

K. Frank


Hi, many thanks for your answer! It was very helpful. The code I attached here is only an example (and now I understand why it was a bad one). What I am actually trying to learn is a U that maps between the weight matrices of the same model trained with different seeds, and I can’t get the loss down. When I look at the condition number of the U matrix obtained by least squares, I get a very high value. So are you saying that the condition number of the least-squares U indicates how stable the learning process will be? Is there anything that can be done to improve stability?