Failing to learn a quadratic function with PyTorch

Hi,
I am new to PyTorch, and when I tried to use it in my project I noticed that it somehow always predicts straight lines. I tried to isolate the problem, and I completely failed to approximate even a simple quadratic function with it. I'm very confused about where I'm going wrong…

    import torch
    from torch import nn
    from torch.autograd import Variable
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    import torch.optim as optim


    def f_x(x):
        return x * x / 64 - 5 * x / 4 + 25


    # Building dataset
    def build_dataset():
        # x values 30, 31, ..., 50 (21 evenly spaced points)
        x_values = np.ones((21, 1))
        for i in range(0, 21):
            x_values[i] = i + 30
        return x_values


    x_values = build_dataset()

    # Building nn
    # net = nn.Sequential(nn.Linear(1, 100), nn.ReLU(), nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, 1))
    net = nn.Sequential(nn.Linear(1, 1000), nn.ReLU(), nn.Linear(1000, 1000), nn.ReLU(), nn.Linear(1000, 1000), nn.ReLU(), nn.Linear(1000, 1))

    # parameters
    optimizer = optim.Adam(net.parameters(), lr=0.00001)
    epochs = 200


    def out(k):
        # folder_name = "Testrun1"
        # working_directory = pathlib.Path().absolute()
        # output_location = working_directory / f'{folder_name}'

        a = 30
        b = 50

        # TODO: copy graph so i only use a copy when it was still open

        import matplotlib.backends.backend_pdf as pdfp
        from pylab import plot, show, grid, xlabel, ylabel
        import matplotlib.pyplot as plt
        # pdf = pdfp.PdfPages("graph" + str(k) + ".pdf")

        t = np.linspace(a, b, 20)
        x = np.zeros(t.shape[0])
        c_fig = plt.figure()

        for j in range(len(t)):
            h = torch.tensor(np.ones(1) * t[j], dtype=torch.float32)
            x[j] = net(h)
        plt.ylim([0, 1])
        plot(t, x, linewidth=4)
        xlabel('x', fontsize=16)
        ylabel('net(x)', fontsize=16)
        grid(True)
        show()
        # pdf.savefig(c_fig)

        # pdf.close()
        plt.close(c_fig)


    def train():
        net.train()
        losses = []
        for epoch in range(1, epochs):
            x_train = Variable(torch.from_numpy(x_values)).float()
            y_train = f_x(x_train)
            y_pred = net(x_train)
            loss = torch.sum(torch.abs(y_pred - y_train))
            print("epoch #", epoch)
            print(loss.item())
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return losses


    print("training start....")
    losses = train()
    plt.plot(range(1, epochs), losses)
    plt.xlabel("epoch")
    plt.ylabel("loss train")
    plt.ylim([0, 100])
    plt.show()

    out(epochs)

ReLU is not sophisticated enough here: with a univariate input it is difficult to find good thresholds, so the optimizer tends to settle in an inferior local minimum that only uses the linear region of the ReLU activations.
Also note that explicitly adding x^2 to the network inputs simplifies the task considerably.
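
For illustration, here is a minimal sketch of that last suggestion (the two-column feature construction and the rescaling of x to [-1, 1] are my own choices, not taken from the code above): if the network receives both the input and its square as features, a single linear layer can already represent f_x exactly, and the fit becomes a plain linear regression.

    import numpy as np
    import torch
    from torch import nn, optim

    def f_x(x):
        return x * x / 64 - 5 * x / 4 + 25

    x = np.arange(30, 51, dtype=np.float32).reshape(-1, 1)
    t = (x - 40.0) / 10.0                      # rescale inputs from [30, 50] to [-1, 1]
    features = np.hstack([t, t * t])           # two input columns: t and t^2

    # With the squared term given explicitly, one linear layer can represent f_x exactly
    net = nn.Linear(2, 1)
    optimizer = optim.Adam(net.parameters(), lr=0.05)

    X = torch.from_numpy(features)
    y = torch.from_numpy(f_x(x))
    for epoch in range(1000):
        loss = ((net(X) - y) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 200 == 0:
            print(epoch, loss.item())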

Hi,
You were definitely correct that ReLU was not an optimal choice. I used this code

    activation_function = nn.SELU()
    net = nn.Sequential(nn.Linear(1, 1000), activation_function, nn.Linear(1000, 1000), activation_function, nn.Linear(1000, 1000), activation_function, nn.Linear(1000, 1))

and plugged in all 20 non-linear activation functions from here https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity

There were several with which it was possible to see a curve, but out of the 20 exactly one reached somewhat decent precision, and that was SELU. However, even the best activation function retained a loss of >= 1 regardless of the number of iterations, which I consider unacceptable given that the task is to approximate a quadratic function with hidden layers of 1000 neurons each.

Do you have another idea that might help? I tried different optimizers as well, but Adam and Rprop are the two that work best, so I can't make an improvement there…

@your note: You mean I should have two input values, one being x and the other x*x? I don't get why this is better, but I do believe you.

That function is just not easy to learn when starting from a random (noisy) initialization. It also doesn't help that the inputs and outputs (target bias = 25) are far from zero.

  1. The amount of training may be too low (learning rate and number of epochs). An lr_scheduler may also be needed to fine-tune the curve.
  2. You're not using a quadratic loss (MSE). This means that the gradient with respect to y_pred doesn't depend on the distance from the target, only on its sign (see the small sketch after this list).
  3. An excessively large network just complicates training.
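
A quick way to see point 2 (this tiny comparison is my own illustration, not code from the thread): the gradient of the summed absolute error with respect to the prediction is always ±1, while the squared-error gradient grows with the distance from the target.

    import torch

    y_pred = torch.tensor([10.0, 0.1], requires_grad=True)
    y_true = torch.tensor([0.0, 0.0])

    # Summed absolute error: the gradient is just the sign of the error
    torch.abs(y_pred - y_true).sum().backward()
    print(y_pred.grad)   # tensor([1., 1.])

    y_pred.grad = None
    # Summed squared error: the gradient is proportional to the error itself
    ((y_pred - y_true) ** 2).sum().backward()
    print(y_pred.grad)   # tensor([20.0000, 0.2000])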

I actually ran your code, with:

    net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 50), nn.Tanh(), nn.Linear(50, 1))
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    epochs = 1000
    loss = ((y_pred - y_train) ** 2).sum()

This gives a reasonably good curve. But with your big network (3 layers x 1000, tanh), I had to decrease the lr to 0.0001 (0.01 gave linear solutions), and 1000 epochs was not enough.
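
Putting those pieces together, here is how I read the suggested changes as one self-contained training loop (the reshaped x_values and the placement of scheduler.step() at the end of each epoch are my own guesses; the rest mirrors the snippet above):

    import numpy as np
    import torch
    from torch import nn, optim

    def f_x(x):
        return x * x / 64 - 5 * x / 4 + 25

    x_values = np.arange(30, 51, dtype=np.float32).reshape(-1, 1)

    net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 50), nn.Tanh(), nn.Linear(50, 1))
    optimizer = optim.Adam(net.parameters(), lr=0.01)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.99)
    epochs = 1000

    x_train = torch.from_numpy(x_values)
    y_train = f_x(x_train)
    for epoch in range(1, epochs):
        y_pred = net(x_train)
        loss = ((y_pred - y_train) ** 2).sum()   # quadratic loss instead of summed absolute error
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()                         # decay the learning rate once per epoch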

So, in general, it is about tuning hyperparameters; all smooth activation functions should in principle work.

I am very happy with this solution.

As far as I can tell, the most important problems were the wrong activation and loss functions. After these changes many nets seem to work, but your small net does the trick quickly and very accurately. I am also very happy you showed an example of the lr_scheduler; I saw them a few days ago but had forgotten about them again.

However, I do not understand your point about the target bias. What exactly is the target bias, and how did you come up with the exact number of 25?

Edit: I just noticed some very strange behaviour. I ran my code with all your changes a few times, and it always looks like the loss reaches a plateau at around 1, stays there for a long time, and then suddenly gets much better. The length of this plateau varied between 200 and 800 epochs in my tests. Do you have an idea why this happens? Especially in the run where it stayed almost constant at a loss of 1 for 800 epochs, I surely would have stopped training if the net weren't so fast.

Your final linear layer computes something of the form dot(h, w) + b, with w and b initialized to zero-centered noise; that is very far from the true function f_x(x) = P(x) + 25, so this offset may end up being modeled through the h coming from the earlier layers. But h is obtained via a non-linear transform, so the approximation suffers somewhat.
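
For reference, with the default nn.Linear initialization in recent PyTorch versions, both the weights and the bias of that last layer start as small zero-centered values (uniform within roughly ±1/sqrt(fan_in)), so for a 50-unit hidden layer they lie within about ±0.14, nowhere near 25:

    import torch
    from torch import nn

    last_layer = nn.Linear(50, 1)
    print(last_layer.bias)                  # a single value in roughly (-0.14, 0.14)
    print(last_layer.weight.abs().max())    # the weights are similarly small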

Not sure; it seems that the curve fitting starts late, as the network first has to shift things into a good region. The lr scheduling may also play a role.
Note that tanh is actually problematic here: it has tiny gradients outside roughly -5…5, and you're using non-normalized inputs. So another possible reason is that some “neurons” are effectively disabled for a long time.
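
One simple way to address that last point (the rescaling itself is my suggestion, not something from the thread): map the inputs from [30, 50] to a small zero-centered range before the first tanh layer, and apply the same transform whenever you evaluate or plot the net, e.g.

    import numpy as np
    import torch
    from torch import nn

    x_values = np.arange(30, 51, dtype=np.float32).reshape(-1, 1)

    # Rescale from [30, 50] to [-1, 1] so the tanh units stay away from their flat regions
    x_norm = (x_values - 40.0) / 10.0

    net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 50), nn.Tanh(), nn.Linear(50, 1))
    y_pred = net(torch.from_numpy(x_norm))   # use x_norm everywhere x_values was used before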
