Why is my custom LSTM model not Early Stopping?

patience = 500
Crosval_num = 5
Accuracy_cv, Accuracy_svc = [], []
Recall_cv, Recall_svc = [], []
Precise_cv, Precise_svc = [], []
F1_cv, F1_svc = [], []
Asm_cv, Asm2_cv = [], []
Xtrainval, ytrainval = np.concatenate((Xtrain,Xtest)), np.concatenate((ytrain, ytest))
Ntrain = Xtrainval.shape[0]
Index_trval = list(range(Xtrainval.shape[0]))
labda = 0.5
Accuracy_rnn, Recall_rnn, Precise_rnn, F1_rnn = [], [], [], []

# initialize the early_stopping object
early_stopping = EarlyStopping(patience=patience, verbose=True)

for cv_i in range(Crosval_num):
    Index_val = list(range(cv_i*Ntrain//Crosval_num, (cv_i+1)*Ntrain//Crosval_num))
    Index_train = list(set(Index_trval).difference(set(Index_val)))
    Xtrain, Xtest = Xtrainval[Index_train,:], Xtrainval[Index_val,:]
    ytrain, ytest = ytrainval[Index_train], ytrainval[Index_val]
    ytrain[ytrain==-1] = 0
    ytest[ytest == -1] = 0
    Xtrain = torch.Tensor(Xtrain)
    Xtest = torch.Tensor(Xtest)
    ytrain, ytest = torch.LongTensor(ytrain), torch.LongTensor(ytest)
    N, T = Xtrain.shape[0], Xtrain.shape[1]
    print(N,T)
    batch_size = N//10
    time_step = T      
    input_size = 28     
    hidden_size = 28
    num_layers = 2
    num_classes = 2
    lr = 0.1         
#     dropout = 0.5
    n_epochs = 2000
    model = simpleLSTM(input_size, hidden_size, num_layers, num_classes)
    # loss and optimizer
    loss_all, accu_all, auc_all, f1_all, recall_all, precise_all = [], [], [], [], [], []
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 
    
    for epoch in range(n_epochs):
        for i in range(1):
            sp_idx = np.random.randint(0, Xtrain.shape[0], batch_size)
            images = Xtrain[sp_idx, :,:]
            labels = ytrain[sp_idx]
            images = images.reshape(-1, time_step, input_size) 
            
            # forward pass
            outputs, _ = model(images)
            loss = criterion(outputs, labels)
            
            # backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

            outputs, outputs_prob = model(Xtest) #TLNN
            test_loss = criterion(outputs, ytest)
            print("Validation loss", test_loss) #val loss
            
            _, predicted = torch.max(outputs_prob.data, 1)
            predicted_prob = outputs_prob[:,1]
            accu = accuracy_score(ytest, predicted)
            cof_mat = confusion_matrix(predicted, ytest)
            
            try:
                accu_all.append(accu)
                auc_all.append(roc_auc_score(ytest.detach().numpy(), predicted_prob.detach().numpy()))
                f1_all.append(f1_score(ytest.detach().numpy(), predicted))
                recall_all.append(recall_score(ytest, predicted))
                precise_all.append(precision_score(ytest, predicted))
                
            except:
                print('excepted')
                continue
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        epoch_len = len(str(n_epochs))

        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')

        print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model)
        
        

        if early_stopping.early_stop:
            print("Reached")
            print("Early stopping")
            break

    print("load the last checkpoint with the best model")
    model.load_state_dict(torch.load('checkpoint.pt'))
    f1max_idx = f1_all.index(max(f1_all))
    f1max_idx = -1
    Accuracy_rnn.append(accu_all[f1max_idx])
    Recall_rnn.append(recall_all[f1max_idx])
    Precise_rnn.append(precise_all[f1max_idx])
    F1_rnn.append(f1_all[f1max_idx])
    
Accuracy_mean = np.mean(Accuracy_rnn)
Recall_mean = np.mean(Recall_rnn)
Precise_mean = np.mean(Precise_rnn)
F1_mean = np.mean(F1_rnn)  
print("Result for TL_NN is {} \t {} \t {} \t {}".format(Accuracy_mean, Recall_mean, Precise_mean, F1_mean))

Output:

[1995/2000] train_loss: 0.85707 valid_loss: nan
Validation loss decreased (nan --> nan).  Saving model ...
Validation loss tensor(0.9948, grad_fn=<NllLossBackward0>)
[1996/2000] train_loss: 0.85705 valid_loss: nan
Validation loss decreased (nan --> nan).  Saving model ...
Validation loss tensor(0.9948, grad_fn=<NllLossBackward0>)
[1997/2000] train_loss: 0.41620 valid_loss: nan
Validation loss decreased (nan --> nan).  Saving model ...
Validation loss tensor(0.9948, grad_fn=<NllLossBackward0>)
[1998/2000] train_loss: 0.85705 valid_loss: nan
Validation loss decreased (nan --> nan).  Saving model ...
Validation loss tensor(0.9948, grad_fn=<NllLossBackward0>)
[1999/2000] train_loss: 0.85706 valid_loss: nan
Validation loss decreased (nan --> nan).  Saving model ...
load the last checkpoint with the best model
Result for TL_NN is 0.38571428571428573 	 0.8 	 0.2857142857142857 	 0.4068686868686869

This is the training and testing code for a custom LSTM model with a different gating mechanism. Any suggestions on improving the model accuracy would be appreciated. @ptrblck
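
For reference, the simpleLSTM class itself is not shown above. A minimal sketch of a wrapper with the same interface the loop expects (returning raw logits plus class probabilities), using a plain nn.LSTM as a hypothetical stand-in for the custom gating:

import torch
import torch.nn as nn

class simpleLSTM(nn.Module):
    # Hypothetical stand-in: a standard nn.LSTM plus a linear head, returning
    # (logits, probabilities) as the training/validation loop above expects.
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)            # out: (batch, time_step, hidden_size)
        logits = self.fc(out[:, -1, :])  # classify from the last time step
        probs = torch.softmax(logits, dim=1)
        return logits, probs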

Your current validation loss is getting NaN values, so I would be more concerned about debugging this issue than trying to improve the loss.
Check the validation predictions as well as the targets and try to narrow down what’s creating the invalid loss values.
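
One concrete thing to check in the posted loop: valid_losses is never appended to, so np.average(valid_losses) averages an empty list and returns NaN every epoch. With the commonly used EarlyStopping helper (which compares scores with <), a NaN score never registers as "no improvement", so the patience counter never increases and early stopping never triggers. A minimal sketch of the likely fix, assuming test_loss is the validation criterion output already computed in the loop:

# inside the inner loop, right after the validation forward pass
with torch.no_grad():                     # no gradients needed for validation
    outputs, outputs_prob = model(Xtest)
    test_loss = criterion(outputs, ytest)
valid_losses.append(test_loss.item())     # <-- this append is missing in the post

# per epoch, the average is now finite and early stopping can work
valid_loss = np.average(valid_losses)
early_stopping(valid_loss, model)

Note also that early_stopping is created once before the cross-validation loop, so its best score and patience counter carry over between folds; re-initializing it at the start of each fold keeps the folds independent. And with patience=500 and n_epochs=2000, stopping can in any case only trigger after at least 500 epochs without improvement.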
