I have a very simple toy problem that is part of a larger research project. I am trying to show that a single-hidden-layer MLP can learn f(x) = x^2. We all know a shallow MLP can do this. The one caveat is that I take 64 different IID inputs, each drawn from a uniform distribution on [-5, +8]. Each of the 64 inputs is squared, and the 64 squared values are summed to produce the label y. Since the inputs are IID and the same nonlinear operation is applied to every coordinate of the 64-dimensional input, I am trying to show that all 64 coordinates can share a single one-input MLP that learns x^2. Importantly, for debugging, I have manually chosen parameters that are "close to optimal" for the network: I simply picked a linear combination of 14 ReLUs that approximates f(x) = x^2 over the input domain, and I use these parameters as the initialization.
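As a sanity check on that construction (my own addition, not part of the original setup), the hand-picked ReLU combination can be evaluated directly in NumPy to confirm it tracks x^2 over [-5, 8]; the weights and biases below are the exact values used for initialization in the code that follows:

```python
import numpy as np

# Slopes and offsets of the 14 hand-picked ReLUs (copied from the init below)
w = np.array([0.1, -0.1, 0.5, -0.5, 0.6, -0.6, 1.5,
              -1.5, 3.9, -3.9, 2.0, -2.0, 5.0, 1.0])
b = np.array([0, 0, -0.05, -0.05, -0.24, -0.24, -1.35,
              -1.35, -8.58, -8.58, -7.8, -7.8, -27.5, -7.1])

def f(x):
    # Sum of the 14 ReLUs plus the fc2 bias of 0.01
    return np.maximum(w * x[:, None] + b, 0.0).sum(axis=1) + 0.01

xs = np.linspace(-5, 8, 1000)
err = np.abs(f(xs) - xs ** 2)
print(err.max())  # worst-case deviation of the piecewise-linear fit from x^2
```

On my spot checks the fit stays within roughly unit error of x^2 across the domain, which is what "close to optimal" is meant to convey.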
Code with a full working example (PyTorch Lightning, but nothing in the problem is specific to PTL):
import os

import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies.ddp import DDPStrategy
from torch.utils.data import DataLoader, TensorDataset

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
torch.cuda.empty_cache()


def gen_data(in_N, t, low, high):
    # t samples, each with in_N IID coordinates from U[low, high];
    # the label is the sum of the squared coordinates.
    X = np.random.uniform(low, high, (t, in_N))
    u = np.sum(X ** 2, axis=1)
    return X, u


X, u = gen_data(64, 100000, -5, 8)
X_train = torch.tensor(X[:90000, :])
X_validate = torch.tensor(X[90000:, :])
Y_train = torch.tensor(u[:90000].reshape(90000, 1))
Y_validate = torch.tensor(u[90000:].reshape(10000, 1))

# Create datasets for PyTorch
my_dataset_train = TensorDataset(X_train, Y_train)
my_dataset_val = TensorDataset(X_validate, Y_validate)


class x2(LightningModule):
    def __init__(self):
        super().__init__()
        self.dropout_ff = torch.nn.Dropout(p=0.2)
        self.fc1 = torch.nn.Linear(1, 14)   # shared 1 -> 14 hidden layer
        self.fc2 = torch.nn.Linear(14, 1)   # 14 -> 1 readout of the shared MLP
        self.activation1 = torch.nn.ReLU()
        self.fc3 = torch.nn.Linear(64, 1)   # combines the 64 per-coordinate outputs

    def forward(self, x):
        x = x.float()
        batch_sz = x.shape[0]
        # Flatten so the same 1-input MLP is applied to every scalar coordinate
        x = x.reshape(batch_sz * 64, 1)
        x = self.dropout_ff(x)
        x = self.activation1(self.fc1(x))
        x = self.dropout_ff(x)
        x = self.fc2(x)
        x = x.reshape(batch_sz, 64)
        x = self.dropout_ff(x)
        u_pred = self.fc3(x)
        return u_pred

    def training_step(self, batch, batch_idx):
        x, y = batch
        y = y.float()
        loss = torch.nn.MSELoss(reduction='mean')(self(x), y)
        self.log("loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.float()
        val_loss = torch.nn.MSELoss(reduction='mean')(self(x), y)
        self.log("val_loss", val_loss, on_epoch=True, on_step=False, sync_dist=True,
                 prog_bar=True, logger=True)
        return val_loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
        return [optimizer], [scheduler]


def main():
    pl.seed_everything(350)

    # Init our model
    model = x2()

    # Initialize weights at the manually constructed ReLU combination that
    # fits x^2 over the domain of interest
    model.fc1.weight.data = torch.tensor(
        [[0.1], [-0.1], [0.5], [-0.5], [0.6], [-0.6], [1.5],
         [-1.5], [3.9], [-3.9], [2.0], [-2.0], [5.0], [1.0]])
    model.fc1.bias.data = torch.tensor(
        [0, 0, -0.05, -0.05, -0.24, -0.24, -1.35,
         -1.35, -8.58, -8.58, -7.8, -7.8, -27.5, -7.1])
    model.fc2.weight.data = torch.ones(1, 14)   # plain sum of the 14 ReLUs
    model.fc2.bias.data = torch.tensor([0.01])
    model.fc3.weight.data = torch.ones(1, 64)   # plain sum over the 64 coordinates
    model.fc3.bias.data = torch.tensor([-0.01])

    # Init DataLoaders
    train_loader = DataLoader(my_dataset_train, batch_size=100, shuffle=True, num_workers=24)
    val_loader = DataLoader(my_dataset_val, batch_size=100, num_workers=24)

    # Instantiate the WandB logger
    wandb_logger = WandbLogger(name="x2", project='x2')
    wandb_logger.watch(model, log="all", log_freq=10)

    # Init ModelCheckpoint callback, monitoring "val_loss"
    checkpoint_callback = ModelCheckpoint(dirpath='./', save_top_k=1, monitor="val_loss",
                                          mode='min', save_on_train_epoch_end=False,
                                          filename="x2", save_last=True)

    # Initialize a trainer
    trainer = Trainer(accelerator='gpu',
                      strategy=DDPStrategy(find_unused_parameters=True),
                      max_epochs=600, devices='auto',
                      logger=wandb_logger, callbacks=[checkpoint_callback])

    # Train the model
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


if __name__ == "__main__":
    main()
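For reference when reading the losses below, the scale of the labels can be checked analytically: for x ~ U[-5, 8], E[x^2] = (8^3 - (-5)^3) / (3 * 13) = 637/39 ≈ 16.33, so each label y concentrates around 64 * 16.33 ≈ 1045. A quick empirical check (my own addition, not part of the original code):

```python
import numpy as np

rng = np.random.default_rng(0)
# Same distribution as gen_data(64, 100000, -5, 8)
X = rng.uniform(-5, 8, (100000, 64))
y = (X ** 2).sum(axis=1)

print(y.mean())  # expected to sit near 64 * 637/39 ≈ 1045.3
```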
The issue I am running into is that although the logged MSE decreases during training (on the training set), the actual predictions get worse. This shows up when I compare a plot of predicted Y vs. true Y from the model at the start of training (parameters close to the manually initialized "optimal" parameters) against the same plot for the model after training.
After 1 epoch, where the training loss is at its peak but the parameters are closest to "optimal", predicted Y vs. true Y for the training data:
At the end of training, where the training loss is at its nadir but the parameters have drifted from the "optimal" initialization, predicted Y vs. true Y for the training data:
Interestingly, even though this is toy synthetic data with IID training and validation sets, the loss dynamics during training look like this (a second model trained from He initialization is shown in blue):
None of this makes sense to me, so I am trying to source ideas. I cannot come up with any mechanism by which "learning" to minimize MSE shows a decreasing MSE during training, while the MSE on that same training set is clearly worse after training than before it (when the parameters were close to "optimal" from the manual initialization).
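One diagnostic I have been considering (my own sketch; the class below is a hypothetical Lightning-free re-implementation of the network above, not code from the project): because the model contains dropout layers, the loss the trainer logs is computed in `model.train()` mode on a randomly masked network, while the prediction plots come from `model.eval()` mode with dropout disabled, so the two numbers are measured on effectively different functions. Comparing them at the manual initialization shows how large that gap can be:

```python
import torch

class X2(torch.nn.Module):
    """Minimal re-implementation of the question's network (no Lightning)."""
    def __init__(self, p=0.2):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=p)
        self.fc1 = torch.nn.Linear(1, 14)
        self.fc2 = torch.nn.Linear(14, 1)
        self.fc3 = torch.nn.Linear(64, 1)

    def forward(self, x):
        b = x.shape[0]
        x = x.reshape(b * 64, 1)          # shared 1-input MLP per coordinate
        x = self.dropout(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x).reshape(b, 64)
        x = self.dropout(x)
        return self.fc3(x)

torch.manual_seed(0)
model = X2()
with torch.no_grad():
    # The same manual "close to optimal" initialization as in the question
    model.fc1.weight.copy_(torch.tensor([[0.1], [-0.1], [0.5], [-0.5], [0.6], [-0.6], [1.5],
                                         [-1.5], [3.9], [-3.9], [2.0], [-2.0], [5.0], [1.0]]))
    model.fc1.bias.copy_(torch.tensor([0, 0, -0.05, -0.05, -0.24, -0.24, -1.35,
                                       -1.35, -8.58, -8.58, -7.8, -7.8, -27.5, -7.1]))
    model.fc2.weight.fill_(1.0)
    model.fc2.bias.fill_(0.01)
    model.fc3.weight.fill_(1.0)
    model.fc3.bias.fill_(-0.01)

X = torch.rand(1000, 64) * 13 - 5                 # U[-5, 8], as in gen_data
y = (X ** 2).sum(dim=1, keepdim=True)

mse = torch.nn.MSELoss()
model.train()
with torch.no_grad():
    train_mode_mse = mse(model(X), y).item()      # what the logged training loss sees
model.eval()
with torch.no_grad():
    eval_mode_mse = mse(model(X), y).item()       # what the prediction plots see

print(train_mode_mse, eval_mode_mse)
```

In my sketch the train-mode loss is far larger than the eval-mode loss at the same parameters, since input dropout both masks coordinates and rescales the survivors before the squaring nonlinearity, so a decreasing train-mode loss need not imply improving eval-mode predictions.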