High and volatile loss, no learning at all

I need some insight into a high and volatile loss: it increases over the first few epochs and then stays in a very volatile, high range.
I have plotted the loss, the RMS error, and the R² score.

Is there an error in my code, or do I just need more training data?

import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, IterableDataset
import os
from torch import optim
from model import Network
from matplotlib import pyplot as plt
from sklearn.metrics import r2_score

path = os.getcwd()

inputs = np.load(path + '/data/in1a.npy')
outputs = np.load(path + '/data/out1a.npy')

batch_size = 400

inputs_0 = torch.Tensor(inputs[0,:])
inputs_1 = torch.Tensor(inputs[1,:])

outputs_0 = torch.Tensor(outputs[0,:])
outputs_1 = torch.Tensor(outputs[1,:])
outputs_2 = torch.Tensor(outputs[2,:])


inputs_tensor = torch.Tensor(inputs)

dataset_in = TensorDataset(inputs_0, inputs_1)
dataset_out = TensorDataset(outputs_0, outputs_1, outputs_2)

dataloader_in = DataLoader(dataset_in, batch_size=batch_size, shuffle=False)
dataloader_out = DataLoader(dataset_out, batch_size=batch_size, shuffle=False)

number_batches_in = int(len(dataset_in)/batch_size)
number_batches_out = int(len(dataset_out)/batch_size)

x = torch.empty(size=(number_batches_in, 800))
y = torch.empty(size=(number_batches_out, 1200), dtype=torch.float64)

for index, (x1, x2) in enumerate(dataloader_in):
    batch = torch.cat((x1, x2), 0)
    x[index] = batch

x_mean = torch.mean(x)
x_std = torch.std(x)
x_norm = (x - x_mean) / x_std

for index, (y1, y2, y3) in enumerate(dataloader_out):
    batch = torch.cat((y1, y2, y3), 0)
    y[index] = batch

y_mean = torch.mean(y)
y_std = torch.std(y)
y_norm = (y - y_mean) / y_std


model = Network(800, 1200, 3, 800, 200)
SAVE_PATH = "trained/model.dat"
epochs = 200
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(), lr=learning_rate, eps=1e-08)
hist_error = []
hist_loss = []
hist_r2 = []
beta = 0.5

for epoch in range(epochs):
    epoch_error = []
    epoch_loss = []
    epoch_r2 = []
    for x_batch, y_true in zip(x_norm, y_norm):
        optimizer.zero_grad()
        x_batch = torch.unsqueeze(x_batch, 0)
        y_true = torch.unsqueeze(y_true, 0)
        pred = model(x_batch)
        loss = torch.mean(torch.sum((pred - y_true)) ** 2)
        loss.backward()
        optimizer.step()
        error = torch.mean(torch.sqrt((pred - y_true) ** 2)).detach().numpy()
        r2y = torch.squeeze(y_true, 0)
        r2pred = torch.squeeze(pred, 0)
        r2 = r2_score(r2y.detach().numpy(), r2pred.detach().numpy())
        epoch_error.append(error)
        epoch_loss.append(loss.detach().numpy())
        epoch_r2.append(r2)
    hist_error.append(np.mean(epoch_error))
    hist_loss.append(np.mean(epoch_loss))
    hist_r2.append(np.mean(epoch_r2))
    print("Epoch %d -- loss %f, RMS error %f, R² score %f " % (epoch+1, hist_loss[-1], hist_error[-1], hist_r2[-1]))
torch.save(model.state_dict(), SAVE_PATH)
print("Model saved to %s" % SAVE_PATH)


f, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True)
ax1.plot(hist_error)
ax1.set_ylabel("Amplitude RMSE")
ax2.plot(hist_loss)
ax2.set_ylabel("Loss")
ax3.plot(hist_r2)
ax3.set_ylabel("R²")
ax3.set_xlabel("Epoch")

plt.show()
The model (model.py):

import torch
import torch.nn as nn
import torch.nn.functional as F


class Network(nn.Module):
    def __init__(self, input_dim, output_dim, latent_dim, layer_dim1, layer_dim2):
        """
        Parameters:
        input_dim (int): number of inputs
        output_dim (int): number of outputs
        latent_dim (int): number of latent neurons
        layer_dim1, layer_dim2 (int): number of neurons in the hidden layers
        """
        super(Network, self).__init__()
        self.latent_dim = latent_dim

        self.enc1 = nn.Linear(input_dim, layer_dim1)
        self.enc2 = nn.Linear(layer_dim1, layer_dim2)

        self.latent = nn.Linear(layer_dim2, latent_dim*2)

        self.dec1 = nn.Linear(latent_dim, layer_dim2)
        self.dec2 = nn.Linear(layer_dim2, layer_dim1)

        self.out = nn.Linear(layer_dim1, output_dim)

    def encoder(self, x):
        z = F.elu(self.enc1(x))
        z = F.elu(self.enc2(z))
        z = self.latent(z)
        self.mu = z[:,0:self.latent_dim]
        self.log_sigma = z[:,self.latent_dim:]
        self.sigma = torch.exp(self.log_sigma)

        eps = torch.randn(x.size(0), self.latent_dim)
        z_sample = self.mu + self.sigma * eps

        self.kl_loss = kl_divergence(self.mu, self.log_sigma, dim=self.latent_dim)

        return z_sample

    def decoder(self, z):
        x = F.elu(self.dec1(z))
        x = F.elu(self.dec2(x))
        return self.out(x)

    def forward(self, batch):
        self.latent_rep = self.encoder(batch)
        dec_input = self.latent_rep
        return self.decoder(dec_input)


def kl_divergence(means, log_sigma, dim, target_sigma=0.1):
    """
    Computes Kullback–Leibler divergence for arrays of mean and log(sigma)
    """
    target_sigma = torch.Tensor([target_sigma])
    inner = 1 / target_sigma**2 * means**2 + torch.exp(2 * log_sigma) / target_sigma**2 - 2 * log_sigma + 2 * torch.log(target_sigma)
    inner = torch.mean(inner, dim=0)
    return 1 / 2. * torch.mean(inner - dim)
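
For reference, the closed form this is based on is the KL divergence between a diagonal Gaussian and a zero-mean isotropic Gaussian with standard deviation t = target_sigma, summed over the d latent dimensions (note that the function above averages over the batch and the latent dimensions instead of summing):

$$
D_{\mathrm{KL}}\bigl(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\big\|\, \mathcal{N}(0, t^2 I)\bigr)
= \frac{1}{2}\sum_{i=1}^{d}\left(\frac{\mu_i^2}{t^2} + \frac{\sigma_i^2}{t^2} - 2\log\sigma_i + 2\log t - 1\right)
$$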


The first 20 epochs:

Epoch 1 -- loss 2592.617822, RMS error 0.769075, R² score -0.216357 
Epoch 2 -- loss 1020.093633, RMS error 0.767074, R² score -0.159109 
Epoch 3 -- loss 1855.906876, RMS error 0.758663, R² score -0.079950 
Epoch 4 -- loss 2494.895288, RMS error 0.760361, R² score -0.097916 
Epoch 5 -- loss 1447.344017, RMS error 0.759499, R² score -0.094597 
Epoch 6 -- loss 1500.075461, RMS error 0.760647, R² score -0.133853 
Epoch 7 -- loss 2715.783510, RMS error 0.760349, R² score -0.113949 
Epoch 8 -- loss 4040.605946, RMS error 0.760447, R² score -0.097853 
Epoch 9 -- loss 2115.793215, RMS error 0.758570, R² score -0.090718 
Epoch 10 -- loss 2927.331250, RMS error 0.758347, R² score -0.090624 
Epoch 11 -- loss 5020.661278, RMS error 0.761674, R² score -0.122575 
Epoch 12 -- loss 5662.833613, RMS error 0.759100, R² score -0.088969 
Epoch 13 -- loss 2570.623434, RMS error 0.757806, R² score -0.048262 
Epoch 14 -- loss 1901.247553, RMS error 0.758893, R² score -0.064967 
Epoch 15 -- loss 13875.499270, RMS error 0.760261, R² score -0.087031 
Epoch 16 -- loss 16458.334455, RMS error 0.758352, R² score -0.057363 
Epoch 17 -- loss 17253.359936, RMS error 0.760583, R² score -0.080694 
Epoch 18 -- loss 17851.229393, RMS error 0.760331, R² score -0.077037 
Epoch 19 -- loss 24284.866304, RMS error 0.758882, R² score -0.066919 
Epoch 20 -- loss 20997.437961, RMS error 0.760005, R² score -0.084857 

[plot: loss, RMS error, and R² score over the first 20 epochs]

Your loss seems to be oscillating. You could try reducing the LR to 1e-5 or 1e-4, or try scheduling the LR; either would be a reasonable first step.
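
For example, with ReduceLROnPlateau the optimizer's LR is cut once the mean epoch loss stops improving. A minimal sketch, reusing model, epochs, x_norm, y_norm, and np from your script (the factor and patience values are guesses to tune, and the loss shown here is plain MSE):

# Minimal sketch: halve the LR when the mean epoch loss has not improved
# for 10 epochs (factor/patience are assumptions to tune).
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                 factor=0.5, patience=10)

for epoch in range(epochs):
    epoch_loss = []
    for x_batch, y_true in zip(x_norm, y_norm):
        optimizer.zero_grad()
        pred = model(torch.unsqueeze(x_batch, 0))
        loss = torch.mean((pred - torch.unsqueeze(y_true, 0)) ** 2)  # plain MSE
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    scheduler.step(np.mean(epoch_loss))  # reduce LR if the epoch loss plateaued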

I adjusted the LR to 1e-4 and 1e-5, respectively. 1e-5 shows some improvement; I will let it run for some more epochs.

1e-4:

Epoch 1 -- loss 23092.178817, RMS error 0.796679, R² score -0.764875 
Epoch 2 -- loss 29207.602612, RMS error 0.786404, R² score -0.489875 
Epoch 3 -- loss 30938.809148, RMS error 0.781764, R² score -0.447401 
Epoch 4 -- loss 36056.395415, RMS error 0.783350, R² score -0.420444 
Epoch 5 -- loss 38970.840979, RMS error 0.779231, R² score -0.424409 
Epoch 6 -- loss 38725.751165, RMS error 0.781632, R² score -0.482981 
Epoch 7 -- loss 53787.240595, RMS error 0.780569, R² score -0.483686 
Epoch 8 -- loss 67411.768857, RMS error 0.783626, R² score -0.524781 
Epoch 9 -- loss 55415.253039, RMS error 0.782485, R² score -0.469684 
Epoch 10 -- loss 156964.512527, RMS error 0.796239, R² score -1.015255 
Epoch 11 -- loss 154036.751858, RMS error 0.797862, R² score -1.176710 
Epoch 12 -- loss 136232.229137, RMS error 0.799174, R² score -1.175171 
Epoch 13 -- loss 54970.134525, RMS error 0.780720, R² score -0.520048 
Epoch 14 -- loss 77784.615496, RMS error 0.785365, R² score -0.509746 
Epoch 15 -- loss 174066.987240, RMS error 0.798963, R² score -1.201877 
Epoch 16 -- loss 129737.510334, RMS error 0.790259, R² score -0.897164 
Epoch 17 -- loss 123497.888881, RMS error 0.791715, R² score -0.918957 
Epoch 18 -- loss 51374.109922, RMS error 0.780750, R² score -0.554175 
Epoch 19 -- loss 61803.601078, RMS error 0.782392, R² score -0.486701 
Epoch 20 -- loss 100403.409419, RMS error 0.785900, R² score -0.647144 

1e-5:

Epoch 1 -- loss 176985.034656, RMS error 0.811731, R² score -1.415746 
Epoch 2 -- loss 160285.464975, RMS error 0.810098, R² score -1.328036 
Epoch 3 -- loss 91367.396519, RMS error 0.793799, R² score -0.854228 
Epoch 4 -- loss 15116.144004, RMS error 0.782176, R² score -0.526900 
Epoch 5 -- loss 4085.806400, RMS error 0.780666, R² score -0.492350 
Epoch 6 -- loss 4382.456566, RMS error 0.781868, R² score -0.519753 
Epoch 7 -- loss 7855.515105, RMS error 0.783714, R² score -0.562923 
Epoch 8 -- loss 6698.984660, RMS error 0.782674, R² score -0.540443 
Epoch 9 -- loss 4957.788071, RMS error 0.781994, R² score -0.524670 
Epoch 10 -- loss 4914.508950, RMS error 0.782255, R² score -0.525899 
Epoch 11 -- loss 5427.704976, RMS error 0.781978, R² score -0.517649 
Epoch 12 -- loss 4490.562387, RMS error 0.781281, R² score -0.501221 
Epoch 13 -- loss 3955.365406, RMS error 0.780896, R² score -0.490551 
Epoch 14 -- loss 5101.983363, RMS error 0.781299, R² score -0.496392 
Epoch 15 -- loss 4485.993145, RMS error 0.780489, R² score -0.477653 
Epoch 16 -- loss 3481.978184, RMS error 0.780190, R² score -0.460933 
Epoch 17 -- loss 3528.847296, RMS error 0.779919, R² score -0.458042 
Epoch 18 -- loss 3748.922880, RMS error 0.779473, R² score -0.451958 
Epoch 19 -- loss 3778.684385, RMS error 0.779273, R² score -0.442150 
Epoch 20 -- loss 3355.956672, RMS error 0.778943, R² score -0.435794 

My initial thought was to simplify the autoencoder architecture and use just one hidden layer before the latent layer (sketched below the plot).
The loss there was still volatile, but the RMS error and R² score showed some improvement. However, I suspect it was just learning a lot more slowly, hence the 'nice learning curve'.
[plot: training curves for the single-hidden-layer variant]
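
The simplified variant was along these lines (a sketch with the same imports as model.py; the KL bookkeeping is omitted for brevity):

# Sketch of the single-hidden-layer variant: one hidden layer of
# layer_dim neurons on each side of the latent layer.
class SimpleNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, latent_dim, layer_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.enc1 = nn.Linear(input_dim, layer_dim)
        self.latent = nn.Linear(layer_dim, latent_dim * 2)  # mu and log_sigma
        self.dec1 = nn.Linear(latent_dim, layer_dim)
        self.out = nn.Linear(layer_dim, output_dim)

    def forward(self, x):
        z = self.latent(F.elu(self.enc1(x)))
        mu, log_sigma = z[:, :self.latent_dim], z[:, self.latent_dim:]
        z_sample = mu + torch.exp(log_sigma) * torch.randn(x.size(0), self.latent_dim)
        return self.out(F.elu(self.dec1(z_sample)))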

The 1e-5 learning rate did show some improvement! Could it simply be that 750 samples are not enough training data?

Epoch 1 -- loss 180403.743162, RMS error 0.821906, R² score -1.487336 
Epoch 2 -- loss 169104.784694, RMS error 0.813036, R² score -1.340975 
Epoch 3 -- loss 120113.906884, RMS error 0.802202, R² score -1.047514 
Epoch 4 -- loss 29157.961805, RMS error 0.785714, R² score -0.568902 
Epoch 5 -- loss 4992.730479, RMS error 0.781753, R² score -0.486485 
Epoch 6 -- loss 3627.731972, RMS error 0.782301, R² score -0.496645 
Epoch 7 -- loss 7142.956734, RMS error 0.784754, R² score -0.543794 
Epoch 8 -- loss 7338.657687, RMS error 0.784554, R² score -0.545328 
Epoch 9 -- loss 6022.291541, RMS error 0.783950, R² score -0.536973 
Epoch 10 -- loss 5052.077289, RMS error 0.783340, R² score -0.528069 
Epoch 11 -- loss 5063.218874, RMS error 0.782701, R² score -0.520601 
Epoch 12 -- loss 5061.055077, RMS error 0.782580, R² score -0.520378 
Epoch 13 -- loss 4894.807523, RMS error 0.782012, R² score -0.511028 
Epoch 14 -- loss 4616.965249, RMS error 0.781081, R² score -0.498040 
Epoch 15 -- loss 3972.096964, RMS error 0.780448, R² score -0.489568 
Epoch 16 -- loss 3658.420474, RMS error 0.779752, R² score -0.476902 
Epoch 17 -- loss 4099.164799, RMS error 0.779745, R² score -0.476791 
Epoch 18 -- loss 3530.959532, RMS error 0.778743, R² score -0.460950 
Epoch 19 -- loss 3264.816841, RMS error 0.778055, R² score -0.449600 
Epoch 20 -- loss 3814.399353, RMS error 0.778157, R² score -0.452178 
Epoch 21 -- loss 3557.507084, RMS error 0.777643, R² score -0.440340 
Epoch 22 -- loss 3242.826616, RMS error 0.777206, R² score -0.432293 
Epoch 23 -- loss 3354.426185, RMS error 0.777048, R² score -0.427852 
Epoch 24 -- loss 2832.482353, RMS error 0.776305, R² score -0.415661 
Epoch 25 -- loss 2706.315243, RMS error 0.776022, R² score -0.408117 
Epoch 26 -- loss 2885.771273, RMS error 0.776296, R² score -0.411292 
Epoch 27 -- loss 2708.342806, RMS error 0.775907, R² score -0.403748 
Epoch 28 -- loss 2699.617106, RMS error 0.775879, R² score -0.399279 
Epoch 29 -- loss 2437.361530, RMS error 0.775800, R² score -0.394906 
Epoch 30 -- loss 2474.435463, RMS error 0.775427, R² score -0.388304 
Epoch 31 -- loss 2806.154593, RMS error 0.775360, R² score -0.388209 
Epoch 32 -- loss 2485.935910, RMS error 0.774711, R² score -0.376525 
Epoch 33 -- loss 2064.171100, RMS error 0.774260, R² score -0.366088 
Epoch 34 -- loss 2174.573715, RMS error 0.774460, R² score -0.365081 
Epoch 35 -- loss 2329.067517, RMS error 0.774409, R² score -0.363911 
Epoch 36 -- loss 2105.775363, RMS error 0.774249, R² score -0.358457 
Epoch 37 -- loss 2181.093310, RMS error 0.774113, R² score -0.355447 
Epoch 38 -- loss 2079.855788, RMS error 0.773716, R² score -0.348380 
Epoch 39 -- loss 2018.169771, RMS error 0.773485, R² score -0.345159 
Epoch 40 -- loss 1909.393100, RMS error 0.773272, R² score -0.341356 
Epoch 41 -- loss 1794.751873, RMS error 0.773062, R² score -0.335986 
Epoch 42 -- loss 1841.455255, RMS error 0.773007, R² score -0.333332 
Epoch 43 -- loss 1898.677265, RMS error 0.772967, R² score -0.332230 
Epoch 44 -- loss 1897.363214, RMS error 0.772855, R² score -0.329650 
Epoch 45 -- loss 1808.457180, RMS error 0.772539, R² score -0.325173 
Epoch 46 -- loss 1641.110455, RMS error 0.772202, R² score -0.320174 
Epoch 47 -- loss 1586.679415, RMS error 0.771856, R² score -0.315941 
Epoch 48 -- loss 1734.012115, RMS error 0.771908, R² score -0.316184 
Epoch 49 -- loss 1892.380025, RMS error 0.772027, R² score -0.316132 
Epoch 50 -- loss 1864.643358, RMS error 0.771896, R² score -0.315369 
Epoch 51 -- loss 1664.668285, RMS error 0.771572, R² score -0.309368 
Epoch 52 -- loss 1562.034853, RMS error 0.771432, R² score -0.305365 
Epoch 53 -- loss 1461.419723, RMS error 0.771366, R² score -0.303571 
Epoch 54 -- loss 1506.190232, RMS error 0.771246, R² score -0.302183 
Epoch 55 -- loss 1449.853142, RMS error 0.771046, R² score -0.297950 
Epoch 56 -- loss 1423.557770, RMS error 0.770806, R² score -0.295029 
Epoch 57 -- loss 1425.172167, RMS error 0.770863, R² score -0.296160 
Epoch 58 -- loss 1442.614806, RMS error 0.770645, R² score -0.293519 
Epoch 59 -- loss 1475.569339, RMS error 0.770600, R² score -0.291277 
Epoch 60 -- loss 1267.828809, RMS error 0.770311, R² score -0.285899 
Epoch 61 -- loss 1326.803453, RMS error 0.770373, R² score -0.288218 
Epoch 62 -- loss 1370.595468, RMS error 0.770108, R² score -0.285702 
Epoch 63 -- loss 1397.048822, RMS error 0.770019, R² score -0.284827 
Epoch 64 -- loss 1338.152551, RMS error 0.769919, R² score -0.282302 
Epoch 65 -- loss 1171.416068, RMS error 0.769700, R² score -0.277587 
Epoch 66 -- loss 1145.970257, RMS error 0.769698, R² score -0.277153 
Epoch 67 -- loss 1216.038302, RMS error 0.769740, R² score -0.278022 
Epoch 68 -- loss 1302.752022, RMS error 0.769797, R² score -0.279172 
Epoch 69 -- loss 1314.311653, RMS error 0.769856, R² score -0.278785 
Epoch 70 -- loss 1193.560528, RMS error 0.769870, R² score -0.275211 
Epoch 71 -- loss 1056.640736, RMS error 0.769829, R² score -0.272258 
Epoch 72 -- loss 991.398418, RMS error 0.769805, R² score -0.269710 
Epoch 73 -- loss 1037.093583, RMS error 0.769893, R² score -0.270067 
Epoch 74 -- loss 1177.517621, RMS error 0.770055, R² score -0.272101 
Epoch 75 -- loss 1243.690292, RMS error 0.769981, R² score -0.271776 
Epoch 76 -- loss 1118.367068, RMS error 0.769664, R² score -0.268376 
Epoch 77 -- loss 985.773997, RMS error 0.769386, R² score -0.265884 
Epoch 78 -- loss 931.811749, RMS error 0.769227, R² score -0.263303 
Epoch 79 -- loss 1018.003820, RMS error 0.769397, R² score -0.264810 
Epoch 80 -- loss 1167.205052, RMS error 0.769690, R² score -0.269945 
Epoch 81 -- loss 1221.452906, RMS error 0.769536, R² score -0.268773 
Epoch 82 -- loss 1060.727332, RMS error 0.769231, R² score -0.264030 
Epoch 83 -- loss 871.762951, RMS error 0.768961, R² score -0.259052 
Epoch 84 -- loss 829.324191, RMS error 0.768826, R² score -0.256068 
Epoch 85 -- loss 906.068838, RMS error 0.768970, R² score -0.257897 
Epoch 86 -- loss 1088.497688, RMS error 0.769331, R² score -0.263284 
Epoch 87 -- loss 1042.348885, RMS error 0.769310, R² score -0.261818 
Epoch 88 -- loss 878.435139, RMS error 0.768899, R² score -0.256630 
Epoch 89 -- loss 824.462952, RMS error 0.768746, R² score -0.254029 
Epoch 90 -- loss 878.883424, RMS error 0.768902, R² score -0.254908 
Epoch 91 -- loss 943.144427, RMS error 0.768945, R² score -0.255984 
Epoch 92 -- loss 920.859744, RMS error 0.768924, R² score -0.255869 
Epoch 93 -- loss 879.214606, RMS error 0.768721, R² score -0.253799 
Epoch 94 -- loss 846.007895, RMS error 0.768595, R² score -0.251593 
Epoch 95 -- loss 798.920333, RMS error 0.768452, R² score -0.249210 
Epoch 96 -- loss 771.437864, RMS error 0.768439, R² score -0.248403 
Epoch 97 -- loss 793.744618, RMS error 0.768343, R² score -0.247484 
Epoch 98 -- loss 850.234362, RMS error 0.768503, R² score -0.249382 
Epoch 99 -- loss 847.446497, RMS error 0.768585, R² score -0.249569 
Epoch 100 -- loss 818.684796, RMS error 0.768533, R² score -0.249195 
Epoch 101 -- loss 770.633937, RMS error 0.768357, R² score -0.245861 
Epoch 102 -- loss 779.327640, RMS error 0.768412, R² score -0.245629 
Epoch 103 -- loss 774.053371, RMS error 0.768395, R² score -0.245385 
Epoch 104 -- loss 739.709661, RMS error 0.768334, R² score -0.244732 
Epoch 105 -- loss 722.971646, RMS error 0.768333, R² score -0.244295 
Epoch 106 -- loss 744.256830, RMS error 0.768487, R² score -0.244330 
Epoch 107 -- loss 745.896061, RMS error 0.768515, R² score -0.244256 
Epoch 108 -- loss 790.474475, RMS error 0.768492, R² score -0.245051 
Epoch 109 -- loss 729.078649, RMS error 0.768555, R² score -0.243058 
Epoch 110 -- loss 657.155427, RMS error 0.768426, R² score -0.240324 
Epoch 111 -- loss 604.152068, RMS error 0.768177, R² score -0.237416 
Epoch 112 -- loss 640.625870, RMS error 0.768196, R² score -0.239516 
Epoch 113 -- loss 737.739307, RMS error 0.768471, R² score -0.241899 
Epoch 114 -- loss 800.555676, RMS error 0.768475, R² score -0.242790 
Epoch 115 -- loss 674.808095, RMS error 0.768297, R² score -0.240027 
Epoch 116 -- loss 549.181369, RMS error 0.768044, R² score -0.236592 
Epoch 117 -- loss 465.940622, RMS error 0.767847, R² score -0.232217 
Epoch 118 -- loss 514.780255, RMS error 0.768084, R² score -0.234231 
Epoch 119 -- loss 951.332347, RMS error 0.768974, R² score -0.245574 
Epoch 120 -- loss 1056.144804, RMS error 0.768608, R² score -0.244664 
Epoch 121 -- loss 522.278186, RMS error 0.767809, R² score -0.234194 
Epoch 122 -- loss 316.889376, RMS error 0.767185, R² score -0.222851 
Epoch 123 -- loss 923.961721, RMS error 0.766290, R² score -0.205969 
Epoch 124 -- loss 834.801945, RMS error 0.766472, R² score -0.210059 
Epoch 125 -- loss 556.516251, RMS error 0.766309, R² score -0.207860 
Epoch 126 -- loss 663.168528, RMS error 0.766080, R² score -0.203594 
Epoch 127 -- loss 709.709608, RMS error 0.765931, R² score -0.201287 
Epoch 128 -- loss 652.907853, RMS error 0.765875, R² score -0.200382 
Epoch 129 -- loss 664.100282, RMS error 0.765741, R² score -0.198408 
Epoch 130 -- loss 666.113761, RMS error 0.765683, R² score -0.197205 
Epoch 131 -- loss 626.340995, RMS error 0.765616, R² score -0.196449 
Epoch 132 -- loss 645.584598, RMS error 0.765549, R² score -0.195088 
Epoch 133 -- loss 644.249931, RMS error 0.765525, R² score -0.194445 
Epoch 134 -- loss 636.290861, RMS error 0.765503, R² score -0.194205 
Epoch 135 -- loss 619.900504, RMS error 0.765471, R² score -0.193749 
Epoch 136 -- loss 619.187101, RMS error 0.765441, R² score -0.193200 
Epoch 137 -- loss 610.652410, RMS error 0.765407, R² score -0.192485 
Epoch 138 -- loss 605.461617, RMS error 0.765378, R² score -0.192093 
Epoch 139 -- loss 591.781291, RMS error 0.765399, R² score -0.192221 
Epoch 140 -- loss 619.212132, RMS error 0.765357, R² score -0.191269 
Epoch 141 -- loss 584.104862, RMS error 0.765372, R² score -0.191967 
Epoch 142 -- loss 580.297794, RMS error 0.765376, R² score -0.192241 
Epoch 143 -- loss 589.133627, RMS error 0.765398, R² score -0.191881 
Epoch 144 -- loss 571.670219, RMS error 0.765420, R² score -0.192453 
Epoch 145 -- loss 578.838128, RMS error 0.765469, R² score -0.192766 
Epoch 146 -- loss 573.331501, RMS error 0.765457, R² score -0.192607 
Epoch 147 -- loss 567.463478, RMS error 0.765403, R² score -0.192108 
Epoch 148 -- loss 545.784682, RMS error 0.765441, R² score -0.192057 
Epoch 149 -- loss 549.874110, RMS error 0.765498, R² score -0.192589 
Epoch 150 -- loss 567.682026, RMS error 0.765562, R² score -0.192706 
Epoch 151 -- loss 556.036767, RMS error 0.765520, R² score -0.192740 
Epoch 152 -- loss 557.172478, RMS error 0.765508, R² score -0.192277 
Epoch 153 -- loss 532.166643, RMS error 0.765587, R² score -0.193105 
Epoch 154 -- loss 547.272600, RMS error 0.765541, R² score -0.192932 
Epoch 155 -- loss 521.044357, RMS error 0.765543, R² score -0.193258 
Epoch 156 -- loss 537.808089, RMS error 0.765487, R² score -0.192314 
Epoch 157 -- loss 526.921298, RMS error 0.765474, R² score -0.192654 
Epoch 158 -- loss 530.450955, RMS error 0.765495, R² score -0.192426 
Epoch 159 -- loss 509.189383, RMS error 0.765436, R² score -0.192552 
Epoch 160 -- loss 515.526741, RMS error 0.765588, R² score -0.192712 
Epoch 161 -- loss 519.939532, RMS error 0.765672, R² score -0.193093 
Epoch 162 -- loss 512.142004, RMS error 0.765582, R² score -0.193083 
Epoch 163 -- loss 495.417688, RMS error 0.765574, R² score -0.193166 
Epoch 164 -- loss 495.554877, RMS error 0.765613, R² score -0.193212 
Epoch 165 -- loss 511.100692, RMS error 0.765699, R² score -0.193638 
Epoch 166 -- loss 489.774254, RMS error 0.765592, R² score -0.193951 
Epoch 167 -- loss 479.431730, RMS error 0.765599, R² score -0.194000 
Epoch 168 -- loss 488.013785, RMS error 0.765588, R² score -0.193666 
Epoch 169 -- loss 481.807994, RMS error 0.765563, R² score -0.193487 
Epoch 170 -- loss 483.189243, RMS error 0.765706, R² score -0.194517 
Epoch 171 -- loss 476.318890, RMS error 0.765619, R² score -0.194277 
Epoch 172 -- loss 464.659800, RMS error 0.765524, R² score -0.193604 
Epoch 173 -- loss 469.451061, RMS error 0.765607, R² score -0.194221 
Epoch 174 -- loss 474.250653, RMS error 0.765667, R² score -0.194644 
Epoch 175 -- loss 461.424347, RMS error 0.765682, R² score -0.194747 
Epoch 176 -- loss 454.457915, RMS error 0.765565, R² score -0.195052 
Epoch 177 -- loss 461.239112, RMS error 0.765433, R² score -0.193918 
Epoch 178 -- loss 460.285600, RMS error 0.765519, R² score -0.194187 
Epoch 179 -- loss 444.273921, RMS error 0.765631, R² score -0.194484 
Epoch 180 -- loss 451.771498, RMS error 0.765490, R² score -0.194358 
Epoch 181 -- loss 458.962774, RMS error 0.765564, R² score -0.194801 
Epoch 182 -- loss 437.338595, RMS error 0.765482, R² score -0.194279 
Epoch 183 -- loss 436.141755, RMS error 0.765651, R² score -0.194697 
Epoch 184 -- loss 461.081025, RMS error 0.765496, R² score -0.194096 
Epoch 185 -- loss 429.291427, RMS error 0.765465, R² score -0.193912 
Epoch 186 -- loss 435.089542, RMS error 0.765582, R² score -0.194135 
Epoch 187 -- loss 436.786213, RMS error 0.765482, R² score -0.193974 
Epoch 188 -- loss 431.941343, RMS error 0.765575, R² score -0.193780 
Epoch 189 -- loss 421.752442, RMS error 0.765412, R² score -0.193634 
Epoch 190 -- loss 429.009824, RMS error 0.765514, R² score -0.193123 
Epoch 191 -- loss 416.687172, RMS error 0.765559, R² score -0.193684 
Epoch 192 -- loss 410.507160, RMS error 0.765617, R² score -0.193785 
Epoch 193 -- loss 421.352600, RMS error 0.765600, R² score -0.193762 
Epoch 194 -- loss 413.489459, RMS error 0.765647, R² score -0.192965 
Epoch 195 -- loss 402.512734, RMS error 0.765674, R² score -0.193274 
Epoch 196 -- loss 413.870134, RMS error 0.765627, R² score -0.192739 
Epoch 197 -- loss 394.342390, RMS error 0.765494, R² score -0.192513 
Epoch 198 -- loss 399.580948, RMS error 0.765423, R² score -0.191449 
Epoch 199 -- loss 410.817597, RMS error 0.765501, R² score -0.191298 
Epoch 200 -- loss 386.061726, RMS error 0.765495, R² score -0.191793 

[plot: 200 epochs at LR 1e-5]

Having more data always helps. Why don't you try data augmentation? It may also help reduce overfitting.
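
For 1-D signal data like this, even simple perturbations of the inputs can serve as augmentation. A minimal sketch, reusing x_norm and y_norm from your script (the 0.01 noise scale is an assumption to tune, and whether noise is a label-preserving transform depends on how out1a relates to in1a):

# Sketch: perturb each normalized input with small Gaussian noise so the
# network sees a slightly different sample every epoch.
def augment(x, noise_std=0.01):
    return x + noise_std * torch.randn_like(x)

# inside the training loop:
for x_batch, y_true in zip(x_norm, y_norm):
    x_aug = torch.unsqueeze(augment(x_batch), 0)
    pred = model(x_aug)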