Loss and accuracy stuck, very low gradient

Hello,

I was building a neural network to try out a few PyTorch features, and when I ran it on the classic Breast Cancer dataset my algorithm just got stuck. I simplified my model as much as I could, but the accuracy still seems to be blocked.

my model:

class Modela(torch.nn.Module):
    def __init__(self):
        super(Modela, self).__init__()
        self.l1 = torch.nn.Linear(30,16)
        self.l2 = torch.nn.Linear(16,4)
        self.l3 = torch.nn.Linear(4,2)
        
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()
        
    def forward(self, x):
        out1 = self.relu(self.l1(x)) 
        out2 = self.relu(self.l2(out1))
        y_pred = self.sigmoid(self.l3(out2))
        return y_pred

my training:

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

model = Modela()
print(model)
#TRAIN
n_epoch = 50
criterion = torch.nn.CrossEntropyLoss()#torch.nn.BCELoss(reduction='sum')

def train(model,train_dataset,test_dataset,loss_fn,n_epoch=50):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    valid_losses = [] #list of the losses on valid
    valid_acc = [] #list of the accuracy on valid
    train_losses = [] #list of all the loss on training 
    train_loss_store = torch.zeros(1) # accumulates the training loss over the last 100 samples
    for e in range(n_epoch):
        print('n_epoch',e)
        for i in range(len(train_dataset)):

            # let's evaluate the performance of my model
            if (i%100 == 0):

                score = torch.zeros(1)
                valid_loss_store = torch.zeros(1)
                
                with torch.no_grad():
                    for j in range(len(test_dataset)):
                        sample = test_dataset[j]
                        X, label = sample['X'], sample['label']
                        output = model(X)
                        output = torch.unsqueeze(output,0)
                        label = torch.unsqueeze(label,0)
                        loss = loss_fn(output, label)
                        valid_loss_store += loss
                        _, predicted = torch.max(output.data, 1)
                        score += predicted.eq(label).float()
                # let's print a few results
                score = score/len(test_dataset)
                valid_acc += [score.item()]
                valid_loss_store = valid_loss_store/len(test_dataset)
                valid_losses += [valid_loss_store.item()]
                train_loss_store = train_loss_store/100
                print('acc', score, 'valid loss :',valid_loss_store, 'train_loss_store', train_loss_store )
                train_losses += [train_loss_store.item()]
                train_loss_store = torch.zeros(1)
                    
            # here we train on the training dataset        
            sample = train_dataset[i]
            X, label = sample['X'], sample['label']
            optimizer.zero_grad()
            output = model(X)

            output = torch.unsqueeze(output,0)
            label = torch.unsqueeze(label,0)
            loss = loss_fn(output, label)
            train_loss_store += loss
            loss.backward()
#             for param in model.parameters():
#                 print(i,'gradient odg',param.grad.data.sum())
            optimizer.step()
    return train_losses, valid_losses, valid_acc
           
train_losses, valid_losses, valid_acc = train(model,train_dataset, test_dataset,criterion)

When I run this last part, the accuracy sometimes climbs quickly to about 0.9, but most of the time it stays stuck at exactly the initial value, with gradients close to zero. It looks like it could be a local minimum, but in that case taking a step for every sample should let it escape. I don’t understand where this comes from.
I’ve tried changing the learning rate, and it doesn’t change much.

Best regards

b

Hi Barthelemy!

Let me speculate that you are building a binary classification
model – i.e., some sample either is or is not breast cancer, based
on some kind of input features taken from a breast-tissue sample.

(Your last layer, self.l3, has two outputs, suggesting that this
may be binary.)

I think your main problem is that you are calling self.sigmoid()
on the output of your last layer. Given that (from your code further
down) you are using nn.CrossEntropyLoss as your loss function,
you should be using the outputs of l3 as your y_pred and pass
them directly to nn.CrossEntropyLoss. This is because (maybe
a little bit inconsistently with its name) CrossEntropyLoss expects
logits (numbers running from -infinity to infinity) as its inputs, but
your sigmoid maps the logits coming out of l3 to (0, 1).

So try getting rid of the sigmoid, and see if that helps.
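
For concreteness, here is a minimal sketch of your model with only that change (keeping the two-output last layer and nn.CrossEntropyLoss); l3 now returns raw logits that you pass straight to the loss function:

class Modela(torch.nn.Module):
    def __init__(self):
        super(Modela, self).__init__()
        self.l1 = torch.nn.Linear(30, 16)
        self.l2 = torch.nn.Linear(16, 4)
        self.l3 = torch.nn.Linear(4, 2)

        self.relu = torch.nn.ReLU()

    def forward(self, x):
        out1 = self.relu(self.l1(x))
        out2 = self.relu(self.l2(out1))
        # no sigmoid here: CrossEntropyLoss applies log-softmax internally
        return self.l3(out2)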

A couple more comments:

If this is a binary classification problem, you should consider
having your last layer have a single output, and using
nn.BCEWithLogitsLoss as your loss function. This won’t
change the actual math or results of your network, but will
make it a little simpler and more efficient.

Also, it’s not clear that having three layers / two hidden layers
makes your network better. It could make it worse or harder
to train. You might try a single hidden layer with something
like 16 hidden neurons (as you have in your first hidden layer).

So your code fragment (also getting rid of the sigmoid and
changing to the binary version instead of the general multi-class
version) might look something like:

class Modela(torch.nn.Module):
    def __init__(self):
        super(Modela, self).__init__()
        self.l1 = torch.nn.Linear(30,16)
        self.l2 = torch.nn.Linear(16,1)

        self.relu = torch.nn.ReLU()
        
    def forward(self, x):
        hidden1 = self.relu(self.l1(x)) 
        y_pred = self.l2(hidden1)
        return y_pred

# and to match the single output of l2
criterion = torch.nn.BCEWithLogitsLoss()
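
One thing to watch out for with this binary version: nn.BCEWithLogitsLoss expects its target to be a float tensor with the same shape as the single-logit output, and you get the predicted class by thresholding the logit at 0 (equivalent to sigmoid > 0.5) rather than with torch.max. As a rough sketch (assuming, as in your loop, that sample['label'] holds a 0/1 class index), the evaluation part might become something like:

with torch.no_grad():
    for j in range(len(test_dataset)):
        sample = test_dataset[j]
        X, label = sample['X'], sample['label']
        logit = model(X).view(1, 1)            # single output, shape (1, 1)
        target = label.float().view(1, 1)      # BCEWithLogitsLoss wants float targets
        loss = loss_fn(logit, target)
        valid_loss_store += loss
        predicted = (logit > 0).long()         # logit > 0 is the same as sigmoid(logit) > 0.5
        score += predicted.eq(label.view(1, 1)).float().sum()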

Whether a single hidden layer will work better than two will be
problem dependent, but it’s likely to, so it’s worth trying. Also,
the best number of hidden neurons will be problem dependent,
so you should experiment some, but, given that you have 30
inputs (and one output), something like 16 hidden neurons is
a reasonable place to start.
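
If you do want to experiment with the width, one convenient way (just a sketch, with a hypothetical n_hidden argument) is to make the hidden size a constructor parameter, so you can try a few values and compare validation accuracy:

class Modela(torch.nn.Module):
    def __init__(self, n_hidden=16):
        super(Modela, self).__init__()
        self.l1 = torch.nn.Linear(30, n_hidden)
        self.l2 = torch.nn.Linear(n_hidden, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.l2(self.relu(self.l1(x)))

# e.g.
# for n_hidden in (8, 16, 32):
#     model = Modela(n_hidden)
#     ... train and compare validation accuracy ...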

Best.

K. Frank
