Neural network not learning at all

I am training a MLP on a tabular dataset, the pendigits dataset. Problem is that training loss and accuracy are more or less stable, while validation and test loss and accuracy are completely constant. The pendigits dataset contains 10 classes. My code is exactly the same with other experiments that I did for example on MNIST or CIFAR10 that work correctly. The only things that change are the dataset from MNIST/CIFAR10 to pendigits and the NN, from a ResNet-18 to a simple MLP. Below the training function and the network:

def train(net, loaders, optimizer, criterion, epochs=100, dev=dev, save_param = True, model_name="only-pendigits"):
    torch.manual_seed(myseed)
    try:
        net = net.to(dev)
        print(net)
        # Initialize history
        history_loss = {"train": [], "val": [], "test": []}
        history_accuracy = {"train": [], "val": [], "test": []}
        # Process each epoch
        for epoch in range(epochs):
            # Initialize epoch variables
            sum_loss = {"train": 0, "val": 0, "test": 0}
            sum_accuracy = {"train": 0, "val": 0, "test": 0}
            # Process each split
            for split in ["train", "val", "test"]:
                # Process each batch
                for (input, labels) in loaders[split]:
                    # Move to CUDA
                    input = input.to(dev)
                    labels = labels.to(dev)
                    # Reset gradients
                    optimizer.zero_grad()
                    # Compute output
                    pred = net(input)
                    #labels = labels.long()
                    loss = criterion(pred, labels)
                    # Update loss
                    sum_loss[split] += loss.item()
                    # Check parameter update
                    if split == "train":
                        # Compute gradients
                        loss.backward()
                        # Optimize
                        optimizer.step()
                    # Compute accuracy
                    _,pred_labels = pred.max(1)
                    batch_accuracy = (pred_labels == labels).sum().item()/input.size(0)
                    # Update accuracy
                    sum_accuracy[split] += batch_accuracy
                scheduler.step()
            # Compute epoch loss/accuracy
            epoch_loss = {split: sum_loss[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            epoch_accuracy = {split: sum_accuracy[split]/len(loaders[split]) for split in ["train", "val", "test"]}
            # Update history
            for split in ["train", "val", "test"]:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
            # Print info
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},",
                  f"LR={optimizer.param_groups[0]['lr']:.5f},")
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot loss
        plt.title("Loss")
        for split in ["train", "val", "test"]:
            plt.plot(history_loss[split], label=split)
        plt.legend()
        plt.show()
        # Plot accuracy
        plt.title("Accuracy")
        for split in ["train", "val", "test"]:
            plt.plot(history_accuracy[split], label=split)
        plt.legend()
        plt.show()

Network:

#RETE TESTO
class TextNN(nn.Module):

    #Constructor
    def __init__(self):
    # Call parent contructor
        super().__init__()
        torch.manual_seed(myseed)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(16, 128) #16 sono le colonne in input
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, 32)
        self.linear4 = nn.Linear(32, 10)
    
    def forward(self, tab):
        tab = self.linear1(tab)
        tab = self.relu(tab)
        tab = self.linear2(tab)
        tab = self.relu(tab)
        tab = self.linear3(tab)
        tab = self.relu(tab)
        tab = self.linear4(tab)

        return tab

model = TextNN()
print(model)

Is it possible that the model is too simple that it does not learn anything? I do not think so. I think that there is some error in training (but the function is exactly the same with the function I use for MNIST or CIFAR10 that works correctly), or in the data loading. Below is how I load the dataset:

pentrain = pd.read_csv("pendigits.tr.csv")
pentest = pd.read_csv("pendigits.te.csv")

class TextDataset(Dataset):
    """Tabular and Image dataset."""

    def __init__(self, excel_file, transform=None):
        self.excel_file = excel_file
        #self.tabular = pd.read_csv(excel_file)
        self.tabular = excel_file
        self.transform = transform

    def __len__(self):
        return len(self.tabular)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tabular = self.tabular.iloc[idx, 0:]

        y = tabular["class"]


        tabular = tabular[['input1', 'input2', 'input3', 'input4', 'input5', 'input6', 'input7',
       'input8', 'input9', 'input10', 'input11', 'input12', 'input13',
       'input14', 'input15', 'input16']]
        tabular = tabular.tolist()
        tabular = torch.FloatTensor(tabular)
        
        if self.transform:
            tabular = self.transform(tabular)

        return tabular, y

penditrain = TextDataset(excel_file=pentrain, transform=None)

train_size = int(0.80 * len(penditrain))
val_size = int((len(penditrain) - train_size))

pentrain, penval = random_split(penditrain, (train_size, val_size))

pentest = TextDataset(excel_file=pentest, transform=None)

All is loaded correctly, indeed if I print an example:

text_x, label_x = pentrain[0]
print(text_x.shape, label_x)
text_x

torch.Size([16]) 1
tensor([ 48.,  74.,  88.,  95., 100., 100.,  78.,  75.,  66.,  49.,  64.,  23.,
         32.,   0.,   0.,   1.])

And these are my dataloaders:

#Define generators
generator=torch.Generator()
generator.manual_seed(myseed)

# Define loaders
from torch.utils.data import DataLoader
train_loader = DataLoader(pentrain, batch_size=128, num_workers=2, drop_last=True, shuffle=True, generator=generator)
val_loader   = DataLoader(penval,   batch_size=128, num_workers=2, drop_last=False, shuffle=False, generator=generator)
test_loader  = DataLoader(pentest,  batch_size=128, num_workers=2, drop_last=False, shuffle=False, generator=generator)

I have been stuck with this problem for 2 days, and I do not know what the problem is…

The dataset seems to use integers as data inputs. For this reason, I would think a transformer with an embedding layer would be more appropriate, followed by a Linear layer to get your class predictions.

You could try Linear layers, but might want to normalize your input data to be in between 0 and 1 or -1 to 1.

I already tried with normalization, but anything changes. Values are different, but they remain stable. I think that there is a problem because network does never change values of loss or accuracy

What loss function are you using? What optimizer and hyperparameters?

These are my optimized, loss, and scheduler. Also scheduling is working as I want, indeed it starts with a low value, then it increases, and at the end of training it decreases.

# Define an optimizier
import torch.optim as optim
optimizer = optim.Adam([x for x in model.parameters() if x.requires_grad], lr=0.0001)
# Define a loss 
criterion = nn.CrossEntropyLoss()
#scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 0.01, epochs=2, steps_per_epoch=312)

Can you try printing the model inputs and check if they are what you need?

I’m a bit confused with the above since this was a print out of label_x, yet, the label should only be 1 integer denoting the class. And the above has 16 elements.

One more possible issue, have you tried wrapping y in torch.tensor(y)?

Above are the input features. Pendigits dataset has 16 input features (“input1”, “input2”, …, “input16”) and a label (“class”).
Let’s check:

pentrain.columns

Index(['input1', 'input2', 'input3', 'input4', 'input5', 'input6', 'input7',
       'input8', 'input9', 'input10', 'input11', 'input12', 'input13',
       'input14', 'input15', 'input16', 'class'],
      dtype='object')

This is a sample from pentrain:

pentrain[0]

(tensor([ 48.,  74.,  88.,  95., 100., 100.,  78.,  75.,  66.,  49.,  64.,  23.,
          32.,   0.,   0.,   1.]),
 1)

You can see that firstly are printed the features, and then it is printed the label, because my dataset class (TextDataset()) returns features and label.

I have also made this change, but anything changes.

Basically, if I write print(list(net.parameters())) at the beginning of each epoch, I see that weights does never change, and for this reason loss and accuracy remain constant. Why weights are not changing?

How balanced is are the dataset classes? Also, can you print the epoch loss and accuracy for the first 10 epochs?

Classes are perfectly balanced. This is the train set:

Class    N. of samples
  2          780
  4          780
  0          780
  1          779
  7          778
  6          720
  5          720
  8          719
  9          719
3    719

Now I am using a simpler MLP (just to speed up things):

Epoch 1: TrL=3.9838, TrA=0.1578, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00041,
Epoch 2: TrL=3.9803, TrA=0.1578, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00042,
Epoch 3: TrL=3.9815, TrA=0.1581, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00046,
Epoch 4: TrL=3.9900, TrA=0.1574, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00050,
Epoch 5: TrL=3.9904, TrA=0.1579, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00055,
Epoch 6: TrL=3.9784, TrA=0.1591, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00062,
Epoch 7: TrL=3.9828, TrA=0.1586, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00070,
Epoch 8: TrL=3.9865, TrA=0.1583, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00079,
Epoch 9: TrL=3.9782, TrA=0.1591, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00089,
Epoch 10: TrL=3.9839, TrA=0.1583, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00100,
Epoch 11: TrL=3.9860, TrA=0.1585, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00112,
Epoch 12: TrL=3.9781, TrA=0.1583, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00126,
Epoch 13: TrL=3.9837, TrA=0.1578, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00140,
Epoch 14: TrL=3.9778, TrA=0.1583, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00156,
Epoch 15: TrL=3.9838, TrA=0.1571, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00172,
Epoch 16: TrL=3.9851, TrA=0.1571, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00189,
Epoch 17: TrL=3.9794, TrA=0.1579, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00207,
Epoch 18: TrL=3.9773, TrA=0.1579, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00226,
Epoch 19: TrL=3.9830, TrA=0.1585, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00245,
Epoch 20: TrL=3.9766, TrA=0.1576, VL=3.9959, VA=0.1623, TeL=4.0766, TeA=0.1564, LR=0.00266,

Here TrL and TrA are training loss and accuracies, and so on for Val and Test metrics. LR is the learning rate, that is correctly changing according to the scheduler

An MLP is likely not an ideal model for this dataset.

In an image, you have values at each pixel. But here, you have each pixel assigned an integer as pressed in some sequential order. Linear layers tend not to handle sequential or positional information very well. This dataset seems better suited for either a Conv1d with proper normalization and trainable embedding(could just be an embedding_dim=2), or a Transformer(best) with embedding. Even an RNN might be better than linear layers in this case.

1 Like

I think that your opinion is really interesting. I think I will try with a different model like a RNN, as you suggested. However, I am not interested in training on pendigits. I want to train an MLP on a tabular dataset with 10 classes, and I started experimenting on pendigits because it is a famous UCI tabular dataset with 10 classes. Do you know any other famous toy tabular datasets with 10 classes? I underline tabular and 10 classes, so not MNIST or CIFAR10 for example. Moreover, any other famous dataset with more classes is appreciated because I can remove the additional classes in order to have a 10 classes dataset.

Here is one:

Note that all of the vectors in the above dataset are orthogonal and do not have a sequential dependency.

Here is another:

Thank you for the reply. However, these datasets are not good for me, because titanic is a 2-class dataset, while House is a regression problem. I will search other datasets.

I found this: Handwritten Digit Recognition with scikit-learn
It uses the digits dataset of scikit-learn rather than the pendigits dataset. However, also in this case, with my code accuracy is stable at 10%, and loss does not decreases. This link instead shows that an MLP classifier is able to classify this dataset, and that just after few epochs values are good.

I ran the following snippet without any issues and it showed decreasing loss when overfit to preset data:

import torch
import torch.nn as nn


class TextNN(nn.Module):

    # Constructor
    def __init__(self):
        # Call parent contructor
        super().__init__()
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(16, 128)  # 16 sono le colonne in input
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, 32)
        self.linear4 = nn.Linear(32, 10)

    def forward(self, tab):
        tab = self.linear1(tab)
        tab = self.relu(tab)
        tab = self.linear2(tab)
        tab = self.relu(tab)
        tab = self.linear3(tab)
        tab = self.relu(tab)
        tab = self.linear4(tab)

        return tab


model = TextNN()

import torch.optim as optim
optimizer = optim.Adam([x for x in model.parameters() if x.requires_grad], lr=0.0001)
# Define a loss
criterion = nn.CrossEntropyLoss()
#scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 0.01, epochs=2, steps_per_epoch=312)

dummy_input = torch.randint(0, 100, (128,16))
dummy_labels = torch.randint(0, 10, (128,))

def normalize(x):
    return x/100.

dummy_input=normalize(dummy_input)

while True:
    optimizer.zero_grad()
    output=model(dummy_input)

    loss = criterion(output, dummy_labels)
    loss.backward()
    optimizer.step()
    print(loss.item())

Yes, I also have a function to normalize my data. However, my point is that weights are not changing. Even if I agree with your previous observation about the sequential order and the positional information of the data, I can not understand why weights are completely stable. At least, even if neural network is bad, weights need to change, but this not happens, leading me to think that there is something strange or in my train function or in how the data are passed to the model.

It should still improve, even if it’s not the ideal model. I’m not sure what the issue is. But I do note your accuracy is higher than 10%, which is what it should be statistically if there are 10 classes with balanced data and random guessing.

1 Like