Output is the same regardless of input after training an MLP with one hidden layer

Hello, I implemented an MLP with one hidden layer to classify wine data into one of three categories (the cultivar that produced it). The network builds and runs, but I am running into issues: after training, the test outputs all approach the same value regardless of the input. It seems as though my network has learned the proportion of each class in the training dataset and is just outputting that. I am using a cross-entropy loss with one-hot encoded targets. I have read that issues like this can be caused by the learning rate, the biases, or normalization, but I am not sure which of these applies to my situation. Note that this model is at an early stage, so many hyperparameters (e.g. the number of hidden units) have not been tuned, although I did try a few different learning rates to no avail.

Wine Data

Network code:

import torch
from torch import nn
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from keras.utils import to_categorical

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        
        # Inputs to hidden linear combination
        self.hidden = nn.Linear(13, 25)
        # hidden to output layer, 3 classes - one for each cultivar
        self.output = nn.Linear(25, 3)
        
        # Defining activation functions
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        z1 = self.hidden(x)
        out1 = self.sigmoid(z1)
        z2 = self.output(out1)
        out2 = self.sigmoid(z2)
        
        return out2

Data preprocessing code:

# Importing data
data = np.loadtxt(r'D:\Documents\Code and Data\Data\wine.data', delimiter = ",", dtype = np.float32, skiprows = 1)  # raw string so the backslashes are not treated as escapes
X = data[:,1:]
X = X / X.max(axis=0) # normalizing input data
y = data[:,[0]]
n_samples = data.shape[0]

# Splitting data randomly into train, crossval, and test sets
X_train, X_test, y_train, y_test  = train_test_split(X, y, test_size=0.2, random_state=1)

X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=1)

print("Number of samples in each dataset:")
print(X_train.shape[0])
print(X_test.shape[0])
print(X_val.shape[0])
print()

X_train = torch.from_numpy(X_train).cuda()
X_test = torch.from_numpy(X_test).cuda()
X_val = torch.from_numpy(X_val).cuda()

#one-hot encoding targets
y_train_encoded = to_categorical(y_train)[:,1:4]
y_test_encoded = to_categorical(y_test)[:,1:4]
y_val_encoded = to_categorical(y_val)[:,1:4]
print("example of one-hot encoded target data")
print(y_val_encoded)

y_train = torch.from_numpy(y_train_encoded).long().cuda()
y_test = torch.from_numpy(y_test_encoded).long().cuda()
y_val = torch.from_numpy(y_val_encoded).long().cuda()

Model training code:

model = Network()

device = torch.device("cuda:0")
model.to(device)
#X_train.to(device)

criterion = nn.CrossEntropyLoss() #cross-entropy loss

optimizer = torch.optim.SGD(model.parameters(), lr = 0.003) # plain SGD; no momentum argument is passed here

model.eval()
y_pred = model(X_test)
before_train = criterion(y_pred.squeeze(), torch.max(y_test, 1)[1])
print("Test loss pre training: " + str(before_train.item()))

# Training model
epochs = 10000
for epoch in range(epochs):
    optimizer.zero_grad()
    
    output = model.forward(X_train)
    
    loss = criterion(output.squeeze(), torch.max(y_train, 1)[1]) 
    
    if epoch % 500 == 0:
        print('Epoch: {} train loss: {}'.format(epoch,loss.item()))
    
    loss.backward()
    optimizer.step()

Training output:

Test loss pre training: 1.0950652360916138
Epoch: 0 train loss: 1.1017485857009888
Epoch: 500 train loss: 1.095704436302185
Epoch: 1000 train loss: 1.0918421745300293
Epoch: 1500 train loss: 1.0893988609313965
Epoch: 2000 train loss: 1.087828516960144
Epoch: 2500 train loss: 1.0867860317230225
Epoch: 3000 train loss: 1.0860623121261597
Epoch: 3500 train loss: 1.0855324268341064
Epoch: 4000 train loss: 1.085121512413025
Epoch: 4500 train loss: 1.0847841501235962
Epoch: 5000 train loss: 1.0844918489456177
Epoch: 5500 train loss: 1.0842280387878418
Epoch: 6000 train loss: 1.0839811563491821
Epoch: 6500 train loss: 1.0837444067001343
Epoch: 7000 train loss: 1.0835126638412476
Epoch: 7500 train loss: 1.0832834243774414
Epoch: 8000 train loss: 1.0830546617507935
Epoch: 8500 train loss: 1.0828248262405396
Epoch: 9000 train loss: 1.0825930833816528
Epoch: 9500 train loss: 1.0823582410812378

Test Code:

model.eval()
y_pred = model(X_test)
after_train = criterion(y_pred.squeeze(), torch.max(y_test, 1)[1]) 
print("predictions: " + str(y_pred))
print("target array: " + str(y_test))
print('Test loss after Training' , after_train.item())

Test output:

predictions: tensor([[0.4429, 0.6842, 0.3533],
        [0.4493, 0.6928, 0.3438],
        [0.4559, 0.6864, 0.3451],
        [0.4557, 0.6917, 0.3467],
        [0.4561, 0.6837, 0.3427],
        [0.4406, 0.6829, 0.3536],
        [0.4414, 0.6887, 0.3462],
        [0.4609, 0.6829, 0.3408],
        [0.4443, 0.6800, 0.3530],
        [0.4455, 0.6904, 0.3460],
        [0.4573, 0.6860, 0.3449],
        [0.4492, 0.6893, 0.3433],
        [0.4509, 0.6908, 0.3420],
        [0.4620, 0.6854, 0.3405],
        [0.4546, 0.6923, 0.3423],
        [0.4506, 0.6910, 0.3422],
        [0.4516, 0.6875, 0.3502],
        [0.4565, 0.6869, 0.3425],
        [0.4405, 0.6894, 0.3443],
        [0.4617, 0.6828, 0.3431],
        [0.4575, 0.6867, 0.3435],
        [0.4532, 0.6894, 0.3397],
        [0.4546, 0.6897, 0.3491],
        [0.4555, 0.6911, 0.3414],
        [0.4575, 0.6861, 0.3400],
        [0.4399, 0.6846, 0.3522],
        [0.4616, 0.6849, 0.3410],
        [0.4568, 0.6863, 0.3436],
        [0.4562, 0.6862, 0.3426],
        [0.4418, 0.6794, 0.3565],
        [0.4536, 0.6903, 0.3425],
        [0.4507, 0.6790, 0.3497],
        [0.4539, 0.6801, 0.3546],
        [0.4602, 0.6830, 0.3407],
        [0.4524, 0.6881, 0.3416],
        [0.4549, 0.6892, 0.3444]], device='cuda:0', grad_fn=<SigmoidBackward>)
target array: tensor([[0, 0, 1],
        [0, 1, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 1, 0],
        [1, 0, 0],
        [0, 0, 1],
        [1, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0],
        [0, 1, 0]], device='cuda:0')
Test loss after Training 1.088274598121643

Hi Daniel!

Don’t apply a Sigmoid to the output of your last Linear layer.
That is, have forward() return z2 directly rather than self.sigmoid(z2), for example:

        return z2

CrossEntropyLoss wants the output you pass to it to be raw-score
logits that run from -inf to inf. The Sigmoid compresses these
values to run from 0.0 to 1.0 (so they look like probabilities), and,
in effect, you would have to train your network to undo the damage
caused by the Sigmoid – something that will be hard (impossible?)
to do with a single-hidden-layer network.
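A minimal sketch of what "raw-score logits" means in practice, using made-up logits rather than the wine model. It checks that CrossEntropyLoss is log_softmax followed by the negative log-likelihood, and then shows one concrete cost of the Sigmoid: if the values fed to CrossEntropyLoss are stuck in (0, 1), even a perfect 1-versus-0 separation for three classes leaves a loss of about 0.551 (that is, log(1 + 2/e)), so the loss can never get close to zero.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn

criterion = nn.CrossEntropyLoss()

# Hypothetical raw logits for 2 samples and 3 classes, with integer targets in {0, 1, 2}
logits = torch.tensor([[2.0, -1.0, 0.5], [-0.3, 1.2, 0.1]])
targets = torch.tensor([0, 1])

# CrossEntropyLoss applies log_softmax internally and then takes the negative log-likelihood
loss = criterion(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(loss.item(), manual.item())  # the two values agree

# Sigmoid outputs live in (0, 1). The best they can do is 1 for the target class
# and 0 for the others, and that still leaves a sizeable loss:
best_sigmoid_case = torch.tensor([[1.0, 0.0, 0.0]])
floor = criterion(best_sigmoid_case, torch.tensor([0]))
print(floor.item(), math.log(1 + 2 / math.e))  # both about 0.551
```

With unbounded logits the target score can be pushed far above the others, which is what lets the loss approach zero.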

I assume that your to_categorical() line is converting your targets to one-hot format.

But the torch.max(y_train, 1)[1] that you pass to criterion
converts your one-hot encoded targets back to integer class labels.
This is correct for CrossEntropyLoss (as CrossEntropyLoss does
not work with one-hot targets), but it seems unnecessary to convert
your targets to one-hot, and then back again.

(As an aside, torch.max()[1] is an idiom for torch.argmax().
I believe your version of pytorch is new enough that it supports
torch.argmax(), so it would likely be a bit more readable to use
that.)
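For a quick check of the idiom, the sketch below uses a small hand-made one-hot tensor (not the actual wine targets). torch.max() along a dimension returns a (values, indices) pair, so [1] picks out the indices, which are exactly what torch.argmax() returns.

```python
import torch

# Hypothetical one-hot targets for three samples
y_onehot = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]])

labels_max = torch.max(y_onehot, 1)[1]          # indices half of the (values, indices) result
labels_argmax = torch.argmax(y_onehot, dim=1)   # same thing, stated directly
print(labels_max, labels_argmax)                # tensor([2, 1, 0]) tensor([2, 1, 0])
```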

Best.

K. Frank

So, removing the sigmoid in the output layer worked: the network now discriminates between the classes! My output is obviously no longer on a 0-to-1 scale (it can now be negative or greater than 1), but I don't think this matters much, since taking the maximum of the output still returns the correct prediction on the test data. Could you expand on how to use the cross-entropy loss without one-hot encoding? I thought it was necessary for multi-class problems. My current targets are integers ranging from 1 to 3.

Hi Daniel!

Yes, this is correct. Taking the argmax() of the output of your final
Linear layer gives you the integer label of the class that your model
is predicting to be the most likely. (Just to be precise, you are taking
the argmax() rather than the max()). The fact that you are applying
argmax() to logits instead of probabilities doesn’t change the fact
that you are determining the label of the most likely class.
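The reason argmax() gives the same answer on logits and on probabilities is that softmax is monotonic within each row: it never changes which entry is largest. A small sketch with made-up logits; if you do want probabilities for reporting, you can apply softmax at inference time only, while still feeding the raw logits to CrossEntropyLoss during training.

```python
import torch

# Hypothetical logits for two samples and three classes
logits = torch.tensor([[1.7, -0.4, 3.1], [-2.0, 0.9, 0.2]])
probs = torch.softmax(logits, dim=1)  # each row now sums to 1

pred_from_logits = logits.argmax(dim=1)
pred_from_probs = probs.argmax(dim=1)
print(pred_from_logits, pred_from_probs)  # tensor([2, 1]) tensor([2, 1])
print(probs.sum(dim=1))
```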

No, one-hot encoding is not necessary. Leaving pytorch aside, you
might elect to specify your labelled classes with one-hot encoding or
with integer class labels (or some other way). It’s the same information,
just packaged differently.

Pytorch’s CrossEntropyLoss requires that you pass in integer class
labels for its target. (Note, in your code you are passing in integer
class labels – you convert your one-hot labels to integer labels right
in your call to criterion().)

For CrossEntropyLoss you want your integer class labels to range
over {0, 1, 2} (so you can simply subtract 1 from your current labels).
In general, CrossEntropyLoss takes class labels that range over
[0, nClass - 1].
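A minimal end-to-end sketch, assuming synthetic data in place of wine.data (the features and the class-dependent shift are invented for illustration). It keeps the question's layer sizes and hidden Sigmoid, drops the output Sigmoid, and turns 1-2-3 labels into 0-1-2 labels by subtracting 1. Note that y = data[:,[0]] in the question keeps a column of shape (N, 1), so the sketch squeezes it to the 1-D LongTensor that CrossEntropyLoss expects. (Newer PyTorch releases, 1.10 and later, also accept float class-probability targets, but integer labels are the usual choice.)

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the wine data: 13 features, labels stored as 1, 2, 3
# in an (N, 1) column, like y = data[:, [0]] in the question.
n = 150
y_123 = torch.randint(1, 4, (n, 1)).float()
X = torch.randn(n, 13) + y_123  # shift the features by class so the classes are learnable

# CrossEntropyLoss wants a 1-D LongTensor of class indices in [0, nClass - 1]
y = (y_123.squeeze(1) - 1).long()

model = nn.Sequential(
    nn.Linear(13, 25),
    nn.Sigmoid(),
    nn.Linear(25, 3),  # no Sigmoid here: the outputs are raw logits
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

initial_loss = criterion(model(X), y).item()
for epoch in range(500):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()

final_loss = loss.item()
accuracy = (model(X).argmax(dim=1) == y).float().mean().item()
print(initial_loss, final_loss, accuracy)
```

No one-hot encoding appears anywhere: the integer labels go straight into criterion().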

As an aside, I don’t know precisely what your:

y_train_encoded = to_categorical(y_train)[:,1:4]

is doing, but I assume that it is converting your initial 1-2-3 integer
labels to one-hot encoding. Then your torch.max(y_train, 1)[1]
(in effect, argmax()) is converting your one-hot labels to 0-1-2
integer labels. Of course, you can more simply convert 1-2-3 labels
to 0-1-2 labels by subtracting 1.
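To see that the round trip and the subtraction agree, the sketch below uses torch.nn.functional.one_hot as a stand-in for keras to_categorical, assuming (as to_categorical does by default) one column per integer from 0 up to the largest label. With 1-2-3 labels that gives four columns, column 0 is always empty, and the question's [:,1:4] slice drops it.

```python
import torch
import torch.nn.functional as F

# Hypothetical labels as read from the file: 1, 2, 3 in an (N, 1) column
y = torch.tensor([[3.0], [2.0], [1.0], [2.0]])
labels = y.squeeze(1).long()

# Stand-in for to_categorical(y): columns for classes 0..3, so column 0 is unused
one_hot_4 = F.one_hot(labels, num_classes=4)
one_hot = one_hot_4[:, 1:4]  # same slicing as to_categorical(y_train)[:,1:4]

via_one_hot = torch.argmax(one_hot, dim=1)  # what torch.max(y_train, 1)[1] produced
via_subtract = labels - 1                   # the direct route
print(via_one_hot, via_subtract)            # tensor([2, 1, 0, 1]) tensor([2, 1, 0, 1])
```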

Best.

K. Frank