Training loss improving but predictions getting worse

I have a very simple toy problem that is part of a larger research project. I am trying to show that a single-hidden-layer MLP can learn f(x) = x^2. We all know a shallow MLP can do this. The only caveat is that I take 64 IID inputs, each drawn from a uniform distribution on [-5, +8]. Each of the 64 inputs is squared, and the 64 squared values are summed to produce the label y. Since the inputs are IID and the same nonlinear operation is applied to every component of the 64-dimensional input, I am trying to show that all 64 components can share a single one-input, single-hidden-layer MLP that learns x^2. Importantly, for debugging, I have manually chosen parameters that are "close to optimal" for the network, by picking a linear combination of 14 ReLUs that approximates f(x) = x^2 over the domain of interest. These parameters are used for initialization.
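For reference, the quality of that hand-built ReLU approximation can be sanity-checked in a few lines of NumPy, using the same weights and biases that appear in the initialization code below (this is a standalone check, not part of the training script):

import numpy as np

# Hand-chosen hidden-layer parameters (same values as model.fc1 in the script below)
w = np.array([0.1, -0.1, 0.5, -0.5, 0.6, -0.6, 1.5, -1.5, 3.9, -3.9, 2.0, -2.0, 5.0, 1.0])
b = np.array([0.0, 0.0, -0.05, -0.05, -0.24, -0.24, -1.35, -1.35, -8.58, -8.58, -7.8, -7.8, -27.5, -7.1])

x = np.linspace(-5, 8, 1000)
# Hidden layer: ReLU(w_i * x + b_i); output: sum over the 14 units plus the 0.01 output bias
h = np.maximum(w[None, :] * x[:, None] + b[None, :], 0.0)
approx = h.sum(axis=1) + 0.01

print("max |approx - x^2| on [-5, 8]:", np.max(np.abs(approx - x**2)))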

Full working example (PyTorch Lightning, but nothing about the problem is specific to PTL):

import os
import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from torch import nn
import numpy as np
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
import torch.optim.lr_scheduler

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
torch.cuda.empty_cache()

def gen_data(in_N, t, low, high):
    # t samples of in_N IID uniform inputs; label = sum of the squared inputs
    X = np.random.uniform(low, high, (t, in_N))
    u = np.sum(X**2, axis=1)
    return X, u

X, u = gen_data(64, 100000, -5, 8)
X_train = torch.tensor(X[:90000, :])
X_validate = torch.tensor(X[90000:, :])
Y_train = torch.tensor(u[:90000].reshape(90000, 1))
Y_validate = torch.tensor(u[90000:].reshape(10000, 1))

# Create TensorDatasets for the train / validation splits
my_dataset_train = TensorDataset(X_train, Y_train)
my_dataset_val = TensorDataset(X_validate, Y_validate)

class x2(LightningModule):
    def __init__(self):
        super().__init__()
        self.dropout_ff = torch.nn.Dropout(p=0.2)
        self.fc1 = torch.nn.Linear(1, 14)   # shared 1 -> 14 hidden layer
        self.fc2 = torch.nn.Linear(14, 1)   # shared 14 -> 1 readout per input component
        self.activation1 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(64, 1)   # combines the 64 per-component outputs

    def forward(self, x):
        x = x.float()
        batch_sz = x.shape[0]
        # Flatten so the same 1-in/1-out MLP is applied to each of the 64
        # input components independently
        x = x.reshape(x.shape[0] * x.shape[1], 1)
        x = self.dropout_ff(x)
        x = self.activation1(self.fc1(x))
        x = self.dropout_ff(x)
        x = self.fc2(x)
        # Restore the batch dimension and combine the 64 per-component outputs
        x = x.reshape(batch_sz, 64)
        x = self.dropout_ff(x)
        u_pred = self.fc3(x)
        return u_pred

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.float()
        loss = torch.nn.MSELoss(reduction='mean')(self(x), y)
        self.log("loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.float()
        val_loss = torch.nn.MSELoss(reduction='mean')(self(x), y)
        self.log("val_loss", val_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True, logger=True)
        return val_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
        return [optimizer], [scheduler]

def main():
    pl.seed_everything(350)

    # Init our model
    model = x2()

    # Initialize weights to the manually constructed combination of ReLUs that
    # fits x^2 over the domain of interest
    model.fc1.weight.data = torch.tensor([[0.1], [-0.1], [0.5], [-0.5], [0.6], [-0.6], [1.5], [-1.5], [3.9], [-3.9], [2.0], [-2.0], [5.0], [1.0]])
    model.fc1.bias.data = torch.tensor([0.0, 0.0, -0.05, -0.05, -0.24, -0.24, -1.35, -1.35, -8.58, -8.58, -7.8, -7.8, -27.5, -7.1])
    model.fc2.weight.data = torch.ones(1, 14)
    model.fc2.bias.data = torch.tensor([0.01])
    model.fc3.weight.data = torch.ones(1, 64)
    model.fc3.bias.data = torch.tensor([-0.01])

    # Init DataLoaders for the synthetic dataset
    train_loader = DataLoader(my_dataset_train, batch_size=100, shuffle=True, num_workers=24)
    val_loader = DataLoader(my_dataset_val, batch_size=100, num_workers=24)

    # Instantiate the WandB logger
    wandb_logger = WandbLogger(name="x2", project='x2')
    wandb_logger.watch(model, log="all", log_freq=10)

    # Init ModelCheckpoint callback, monitoring "val_loss"
    checkpoint_callback = ModelCheckpoint(
        dirpath='./',
        save_top_k=1,
        monitor="val_loss",
        mode='min',
        save_on_train_epoch_end=False,
        filename="x2",
        save_last=True,
    )

    # Initialize a trainer
    trainer = Trainer(
        accelerator='gpu',
        strategy=DDPStrategy(find_unused_parameters=True),
        max_epochs=600,
        devices='auto',
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
    )

    # Train the model
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

if __name__ == "__main__":
    main()

The issue I am running into is that although the MSE on the training set decreases during training, the actual predictions get worse. This shows up clearly when I compare a plot of predicted Y vs. true Y from the model at the beginning of training (parameters still close to the manually initialized "optimal" values) with the same plot from the model after training.
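The plots below can be reproduced with something like this sketch (it assumes the trained model and the X_train / Y_train tensors from the script above are available on the CPU; matplotlib is not part of the original script):

import torch
import matplotlib.pyplot as plt

model.eval()                                 # inference mode: dropout is disabled
with torch.no_grad():
    u_pred = model(X_train[:2000]).numpy()   # a subset keeps the forward pass small

plt.scatter(Y_train[:2000].numpy(), u_pred, s=2)
plt.xlabel("true Y")
plt.ylabel("predicted Y")
plt.show()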

After 1 epoch, when the training loss is at its peak but the parameters are still closest to "optimal", predicted Y vs. true Y for the training data:
[plot: upred_ut_v214]

At the end of training, when the training loss is at its nadir but the parameters have drifted away from the "optimal" initialization, predicted Y vs. true Y for the training data:
[plot: upred_ut_v214_last]

Interestingly, even though this is synthetic toy data and the training and validation sets are IID, the loss dynamics during training look like this (a second model trained from He initialization is shown in blue):
[plot: training and validation loss curves]

None of this makes sense to me, so I am looking for ideas. I cannot come up with any explanation for how "learning" to minimize the MSE can show the MSE decreasing during training, while the MSE on the training set is clearly worse after training than it was before training (when the parameters were close to "optimal" from the manual init).
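One way to pin this down is to recompute the training-set MSE directly from the model's outputs, outside of the Lightning logging path, both right after initialization and again after training. A minimal sketch, assuming the model and the X_train / Y_train tensors from the script above are on the CPU:

import torch

def dataset_mse(model, X, Y, chunk=10000):
    # MSE of the model's predictions against the labels, computed in eval
    # mode (i.e. with the dropout layers disabled), in chunks to limit memory
    model.eval()
    se, n = 0.0, 0
    with torch.no_grad():
        for i in range(0, len(X), chunk):
            pred = model(X[i:i + chunk])
            se += torch.sum((pred - Y[i:i + chunk].float()) ** 2).item()
            n += len(pred)
    return se / n

print("train MSE:", dataset_mse(model, X_train, Y_train))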

The question is not really PyTorch-specific, but there are several points to consider:

  1. The model is too big, so it is overfitting
  2. The training set and the val/test sets have different distributions
  3. The loss function is not appropriate for the task
  4. The dataset is not normalized

I would recommend starting with (4).
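A minimal sketch of what (4) could look like here, standardizing both inputs and targets with statistics computed on the training split (variable names are the ones from the script above):

# Standardize inputs and targets using training-set statistics only
x_mean, x_std = X_train.mean(), X_train.std()
y_mean, y_std = Y_train.mean(), Y_train.std()

X_train_n = (X_train - x_mean) / x_std
X_val_n = (X_validate - x_mean) / x_std
Y_train_n = (Y_train - y_mean) / y_std
Y_val_n = (Y_validate - y_mean) / y_std

# Predictions are then in normalized units and must be mapped back:
# u_pred = model(X_n) * y_std + y_mean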

At only 108 total trainable parameters, I do not think it is (1), especially with dropout.
On (2), the training and val/test sets have exactly the same distribution; the synthetic data generation guarantees it.
For (3), MSE is the appropriate loss function for this regression task.
I forgot to mention that I have already tried normalizing the input data as suggested in (4); it made no significant difference to any of the results.

The initial value of the loss seems too large (on the order of 1e6).
Could you print out u_pred?

Here are the first 10 values of u_pred on the training data after the model has finished training:
[[968.4119 ]
[899.75665]
[878.20123]
[859.2173 ]
[910.5824 ]
[923.32886]
[923.5334 ]
[916.70667]
[864.8084 ]
[957.3488 ]]

min(u_pred) = 509 and max(u_pred)=1636
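For scale: with X ~ Uniform(-5, 8), E[X^2] = (a^2 + ab + b^2) / 3 = (25 - 40 + 64) / 3 ≈ 16.3, so the labels are centered around 64 * 16.3 ≈ 1045, which is at least the same ballpark as the printed u_pred values.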

The code should be fully reproducible in its current form if you have the latest version of PyTorch Lightning installed.

Interestingly, on further troubleshooting, I have found that this may simply be an issue with PyTorch Lightning: when I train an identical model in plain PyTorch, everything works as expected.
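For comparison, a minimal plain-PyTorch loop for the same model looks roughly like the sketch below (this is an illustration, not the exact script used; it assumes the x2 module and the DataLoaders defined above, runs on a single device, and omits the manual weight initialization):

import torch

model = x2()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
loss_fn = torch.nn.MSELoss()

for epoch in range(600):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y.float())
        loss.backward()
        optimizer.step()
    scheduler.step()

    model.eval()
    with torch.no_grad():
        val_loss = sum(loss_fn(model(x), y.float()).item() for x, y in val_loader)
    print(f"epoch {epoch}: val_loss {val_loss / len(val_loader):.4f}")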