Trying to make the neural network approximate a custom function

# I will try to verify the universal approximation theorem on an arbitrary function 

import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import torch.optim as optim

def f_x(x):
    return 2*x*x + 3*x*x # random function to learn


# Building dataset
def build_dataset():
    # Given f(x), is_f_x defines whether the function is satisfied
    data = []
    for i in range(1,100):
        data.append((i,f_x(i), 1)) # True
    for j in range(100, 201):
        data.append((j,f_x(j)+j*j, 0)) # Not true
    column_names = ["x","f_x", "is_f_x"]
    df = pd.DataFrame(data, columns=column_names)
    return df

df = build_dataset()
print ("Dataset is built!")

    
labels = df.is_f_x.values
features = df.drop(columns=['is_f_x']).values

print ("shape of features:", features.shape)
print ("shape of labels: ", labels.shape)


# Building nn
net = nn.Sequential(
    nn.Linear(features.shape[1], 100), nn.ReLU(),
    nn.Linear(100, 100), nn.ReLU(),
    nn.Linear(100, 2),
)

features_train, features_test, labels_train, labels_test = train_test_split(features, labels, shuffle=True, random_state=34)

# parameters
optimizer = optim.Adam(net.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
epochs = 200


def train():
    net.train()
    losses = []
    for epoch in range(1, epochs):
        x_train = Variable(torch.from_numpy(features_train)).float()
        y_train = Variable(torch.from_numpy(labels_train)).long()
        y_pred = net(x_train)
        loss = criterion(y_pred, y_train)
        print ("epoch #", epoch)
        print (loss.item())
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return losses


print ("training start....")
losses = train()
plt.plot(range(1, epochs), losses)
plt.xlabel("epoch")
plt.ylabel("loss train")
plt.show()

print ("testing start ... ")
x_test = Variable(torch.from_numpy(features_test)).float()
x_train = Variable(torch.from_numpy(features_train)).float()


def test():
    pred = net(x_test)
    pred = torch.max(pred, 1)[1]
    print ("Accuracy on test set: ", accuracy_score(labels_test, pred.data.numpy()))

    p_train = net(x_train)
    p_train = torch.max(p_train, 1)[1]
    print ("Accuracy on train set: ", accuracy_score(labels_train, p_train.data.numpy()))
           
test()

Dataset is built!
shape of features: (200, 2)
shape of labels: (200,)
training start....
epoch # 1
713.3623657226562
epoch # 2
231.61741638183594
epoch # 3
2397.51220703125
epoch # 4
19.050514221191406
epoch # 5
340.9588623046875
epoch # 6
608.4653930664062
epoch # 7
818.8595581054688
epoch # 8
983.2216796875
epoch # 9
1109.30078125
epoch # 10
1202.837890625
epoch # 11
1268.2791748046875
epoch # 12
1309.0794677734375
epoch # 13
1328.2281494140625
epoch # 14
1328.1885986328125
epoch # 15
1311.066162109375
epoch # 16
1278.677490234375
epoch # 17
1232.6102294921875
epoch # 18
1174.258544921875
epoch # 19
1104.8541259765625
epoch # 20
1025.4881591796875
epoch # 21
937.1278076171875
epoch # 22
840.6298217773438
epoch # 23
736.751953125
epoch # 24
626.2609252929688
epoch # 25
509.780517578125
epoch # 26
387.73150634765625
epoch # 27
260.4958190917969
epoch # 28
128.44473266601562
epoch # 29
83.62914276123047
epoch # 30
24.77330780029297
epoch # 31
40.1733283996582
epoch # 32
39.64967727661133
epoch # 33
24.587085723876953
epoch # 34
42.105003356933594
epoch # 35
97.61442565917969
epoch # 36
176.40957641601562
epoch # 37
234.59713745117188
epoch # 38
273.9774475097656
epoch # 39
296.18109130859375
epoch # 40
302.68768310546875
epoch # 41
294.8420104980469
epoch # 42
273.86785888671875
epoch # 43
240.8780975341797
epoch # 44
196.88392639160156
epoch # 45
142.802978515625
epoch # 46
79.46659088134766
epoch # 47
7.649067401885986
epoch # 48
688.4632568359375
epoch # 49
112.88740539550781
epoch # 50
156.5810546875
epoch # 51
296.35931396484375
epoch # 52
410.4731140136719
epoch # 53
501.2221374511719
epoch # 54
570.68896484375
epoch # 55
620.7679443359375
epoch # 56
653.1859741210938
epoch # 57
669.5223388671875
epoch # 58
671.2222900390625
epoch # 59
659.607421875
epoch # 60
635.8880615234375
epoch # 61
601.1698608398438
epoch # 62
556.4600830078125
epoch # 63
502.6763000488281
epoch # 64
440.6499328613281
epoch # 65
371.1325378417969
epoch # 66
294.8001708984375
epoch # 67
212.25828552246094
epoch # 68
124.04632568359375
epoch # 69
30.65134048461914
epoch # 70
643.6390991210938
epoch # 71
268.0551452636719
epoch # 72
121.89012908935547
epoch # 73
244.99267578125
epoch # 74
343.8369140625
epoch # 75
420.6217956542969
epoch # 76
476.84332275390625
epoch # 77
514.6380615234375
epoch # 78
535.780029296875
epoch # 79
541.8413696289062
epoch # 80
531.9910888671875
epoch # 81
509.8548583984375
epoch # 82
475.6784362792969
epoch # 83
430.72906494140625
epoch # 84
376.0613098144531
epoch # 85
312.5922546386719
epoch # 86
241.1291961669922
epoch # 87
162.3844451904297
epoch # 88
75.0332260131836
epoch # 89
170.226318359375
epoch # 90
32.34799575805664
epoch # 91
65.66975402832031
epoch # 92
82.4910659790039
epoch # 93
84.09539031982422
epoch # 94
72.26105499267578
epoch # 95
48.515533447265625
epoch # 96
14.012791633605957
epoch # 97
272.3330078125
epoch # 98
50.459129333496094
epoch # 99
110.71949005126953
epoch # 100
151.30946350097656
epoch # 101
173.5558319091797
epoch # 102
179.85635375976562
epoch # 103
171.81048583984375
epoch # 104
150.80650329589844
epoch # 105
118.06989288330078
epoch # 106
74.69754028320312
epoch # 107
21.80451202392578
epoch # 108
320.6127014160156
epoch # 109
32.47548294067383
epoch # 110
84.9989013671875
epoch # 111
119.59487915039062
epoch # 112
137.81689453125
epoch # 113
139.64547729492188
epoch # 114
127.34384155273438
epoch # 115
102.155517578125
epoch # 116
65.26541137695312
epoch # 117
17.089582443237305
epoch # 118
308.3915100097656
epoch # 119
37.48691177368164
epoch # 120
96.29904174804688
epoch # 121
135.45509338378906
epoch # 122
157.54776000976562
epoch # 123
164.39126586914062
epoch # 124
157.47389221191406
epoch # 125
138.11956787109375
epoch # 126
107.50679779052734
epoch # 127
66.71660614013672
epoch # 128
17.0440673828125
epoch # 129
318.98602294921875
epoch # 130
29.648714065551758
epoch # 131
80.84506225585938
epoch # 132
114.38468933105469
epoch # 133
131.8206024169922
epoch # 134
134.68080139160156
epoch # 135
124.35076904296875
epoch # 136
102.0821304321289
epoch # 137
69.0161361694336
epoch # 138
26.332563400268555
epoch # 139
165.07492065429688
epoch # 140
49.77608871459961
epoch # 141
104.92957305908203
epoch # 142
142.0511016845703
epoch # 143
162.8252410888672
epoch # 144
168.81472778320312
epoch # 145
161.43153381347656
epoch # 146
141.9481658935547
epoch # 147
111.51407623291016
epoch # 148
71.18148803710938
epoch # 149
22.130130767822266
epoch # 150
257.24371337890625
epoch # 151
34.31955337524414
epoch # 152
84.4492416381836
epoch # 153
117.15431213378906
epoch # 154
134.0120086669922
epoch # 155
136.53741455078125
epoch # 156
126.10259246826172
epoch # 157
103.947265625
epoch # 158
71.20113372802734
epoch # 159
29.021305084228516
epoch # 160
128.76654052734375
epoch # 161
52.15217208862305
epoch # 162
106.51445770263672
epoch # 163
143.05496215820312
epoch # 164
163.44993591308594
epoch # 165
169.2525177001953
epoch # 166
161.86529541015625
epoch # 167
142.55288696289062
epoch # 168
112.45751190185547
epoch # 169
72.62442016601562
epoch # 170
24.217023849487305
epoch # 171
225.94088745117188
epoch # 172
36.39259719848633
epoch # 173
85.9130859375
epoch # 174
118.19615936279297
epoch # 175
134.8114776611328
epoch # 176
137.2653350830078
epoch # 177
126.92183685302734
epoch # 178
105.0134506225586
epoch # 179
72.66358184814453
epoch # 180
31.018314361572266
epoch # 181
99.57691955566406
epoch # 182
54.042625427246094
epoch # 183
107.81678009033203
epoch # 184
143.9482879638672
epoch # 185
164.10391235351562
epoch # 186
169.82797241210938
epoch # 187
162.51536560058594
epoch # 188
143.423095703125
epoch # 189
113.68685913085938
epoch # 190
74.34576416015625
epoch # 191
26.543180465698242
epoch # 192
193.11288452148438
epoch # 193
38.73274612426758
epoch # 194
87.72059631347656
epoch # 195
119.64286804199219
epoch # 196
136.06314086914062
epoch # 197
138.47836303710938
epoch # 198
128.2442626953125
epoch # 199
106.58541107177734
testing start …
Accuracy on test set: 0.52
Accuracy on train set: 0.58

Any idea on how to make the network “learn” better?

As you can see, your loss is quite “bouncy”: it gets really low at epoch 4, then grows to a large value and shrinks again.
This is often an indicator of a learning rate that is too high.
Lower your learning rate to 1e-5 and you’ll get ~99% accuracy on both sets.
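
For reference, that is just a change to the line where the optimizer is created (the 1e-5 value is the one suggested above; the rest of the script stays unchanged):

optimizer = optim.Adam(net.parameters(), lr=1e-5)  # was lr=0.001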

I'm still getting 44% and 50% on the train and test sets respectively?

Maybe you just got unlucky.
You could add a weight init scheme to your code to stabilize the training:

def weight_init(m):
    # Xavier-initialize every linear layer and zero its bias
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        m.bias.data.zero_()

net.apply(weight_init)  # the model is called net in the code above
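
(Xavier initialization scales each layer's initial weights by its fan-in and fan-out, so activations start on a comparable scale across layers and the first optimizer steps tend to be less erratic.)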

I got a perfect score in 8 out of 10 runs.
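
Putting the two suggestions together with the code from the question would look roughly like the sketch below (net, train and weight_init are the names already defined above; 1e-5 is just the learning rate suggested earlier):

net.apply(weight_init)                             # re-initialize the linear layers
optimizer = optim.Adam(net.parameters(), lr=1e-5)  # rebuild the optimizer with the lower learning rate
losses = train()                                   # reuse the training loop from the question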
