Choosing a checkpoint metric that accounts for both validation sensitivity and accuracy on an imbalanced binary classification dataset trained with focal loss

How can I save the best model checkpoint based on a combination of validation accuracy and sensitivity? My dataset is imbalanced: 16% of the samples are class 1 and 84% are class 0. I am using focal loss with gamma=3.0 and alpha=0.25.
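For reference, here is a minimal pure-Python sketch of the per-sample binary focal loss, assuming the RetinaNet-style formulation used by `torchvision.ops.sigmoid_focal_loss`, in which `alpha` weights positive (class 1) samples and `1 - alpha` weights negative ones. The function name and the probability input are illustrative assumptions. If a focal loss implementation follows this convention, `alpha=0.25` gives the minority class the *smaller* weight.

```python
import math


def binary_focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 3.0) -> float:
    """Focal loss for one sample; p is the predicted probability of class 1, 0 < p < 1."""
    p_t = p if y == 1 else 1.0 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # alpha weights class 1, (1 - alpha) weights class 0
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# An equally uncertain prediction (p = 0.5) for a positive and a negative sample
loss_pos = binary_focal_loss(0.5, 1)
loss_neg = binary_focal_loss(0.5, 0)
print(loss_neg / loss_pos)  # about 3.0: the negative sample weighs 0.75 / 0.25 times more
```

Under this convention, a value such as alpha=0.75 would instead up-weight class 1; whether that improves sensitivity here would have to be checked on the validation set.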

This is my code for saving the best model checkpoint based on the best validation accuracy:

```python
# Save a checkpoint whenever validation accuracy improves
if epoch_val_accuracy > best_val_acc:
    print('inside if - epoch is {}, val_acc is {}, and best_val_acc is {}'.format(epoch, epoch_val_accuracy, best_val_acc))
    best_val_acc = epoch_val_accuracy
    best_epoch = epoch
    best_preds = epoch_val_preds
    best_val_labels = epoch_val_labels
    print("Saving the best model...")
    torch.save(model.state_dict(), model_path + task_name + ".pth")
```

The result is:

```text
             Predicted Low  Predicted High
Actual Low              51              24
Actual High              9               5
best val acc:  tensor(0.8619, device='cuda:0')
best epoch:  39
best preds:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
best val labels:  [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
best val specifity:  1.0
best val sensitivity:  0.07142857142857142
best cm df:               Predicted Low  Predicted High
Actual Low              75               0
Actual High             13               1
```
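The sensitivity and specificity in this log can be recomputed from the `best cm df` counts. A small sketch (the helper name is hypothetical), using the counts of the accuracy-selected checkpoint:

```python
from typing import Tuple


def sensitivity_specificity(tn: int, fp: int, fn: int, tp: int) -> Tuple[float, float]:
    """Return (sensitivity, specificity) for binary labels where 1 is the positive class."""
    sensitivity = tp / (tp + fn)  # fraction of actual 1s predicted as 1 (recall of class 1)
    specificity = tn / (tn + fp)  # fraction of actual 0s predicted as 0 (recall of class 0)
    return sensitivity, specificity


# Counts from the table above: 75 TN, 0 FP, 13 FN, 1 TP.
# With scikit-learn the same four numbers come from
# tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
sens, spec = sensitivity_specificity(tn=75, fp=0, fn=13, tp=1)
print(sens, spec)  # 0.07142857142857142 1.0
```

Accuracy barely notices the 13 missed positives because the 75 true negatives dominate it, which is why selecting on accuracy alone favours this checkpoint.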

Based on the best validation sensitivity:

```python
# Save a checkpoint whenever validation sensitivity improves
if epoch_sensitivity > best_val_sensitivity:
    best_val_acc = epoch_val_accuracy
    best_epoch = epoch
    best_preds = epoch_val_preds
    best_val_labels = epoch_val_labels
    best_val_sensitivity = epoch_sensitivity
    best_val_specifity = epoch_specifity
    print("Saving the best model...")
    torch.save(model.state_dict(), model_path + task_name + ".pth")
```

The result is:

```text
best preds:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
best val labels:  [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
best val specifity:  0.0
```
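This outcome is expected when sensitivity is the only criterion: a constant model that always predicts 1 reaches the maximum sensitivity of 1.0 while its specificity drops to 0.0. A minimal illustration using the class counts of this validation set (14 positives, 75 negatives); the label order is made up, only the counts matter:

```python
labels = [1] * 14 + [0] * 75  # same class counts as the validation set above
preds = [1] * len(labels)     # degenerate model: always predicts class 1

tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)

sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
accuracy = (tp + tn) / len(labels)
print(sensitivity, specificity, round(accuracy, 3))  # 1.0 0.0 0.157
```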

I also tried to combine accuracy and sensitivity as follows:

```python
# Save only if sensitivity improves AND validation accuracy stays above 0.7
if epoch_sensitivity > best_val_sensitivity:
    if epoch_val_accuracy > 0.7:
        best_val_acc = epoch_val_accuracy
        best_epoch = epoch
        best_preds = epoch_val_preds
        best_val_labels = epoch_val_labels
        best_val_sensitivity = epoch_sensitivity
        best_val_specifity = epoch_specifity
        print("Saving the best model...")
        best_cm_df = pd.DataFrame(cm,
                                  columns=['Predicted Low', 'Predicted High'],
                                  index=['Actual Low', 'Actual High'])
        torch.save(model.state_dict(), model_path + task_name + ".pth")
```

and got this result:

```text
best val acc:  tensor(0.7159, device='cuda:0')
best epoch:  36
best preds:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0]
best val labels:  [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
best val specifity:  0.7733333333333333
best val sensitivity:  0.2857142857142857
best cm df:               Predicted Low  Predicted High
Actual Low              58              17
Actual High             10               4
```

This depends entirely on your preferences, i.e. which outcomes are the most important to you.
That said, the F1 score aka Dice coefficient is a standard choice that could be a good start.
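For reference, the F1 score of the positive class is 2TP / (2TP + FP + FN), which equals the Dice coefficient; scikit-learn's `f1_score` computes this class-1 score by default (`average='binary'`, `pos_label=1`). A quick sketch applying it to the confusion matrices from the question (the helper name is hypothetical):

```python
def f1_positive(tp: int, fp: int, fn: int) -> float:
    """F1 (Dice) score of the positive class from confusion-matrix counts."""
    return 2 * tp / (2 * tp + fp + fn)


# Accuracy-selected checkpoint: TP=1, FP=0, FN=13
print(round(f1_positive(tp=1, fp=0, fn=13), 4))   # 0.1333
# Sensitivity + (accuracy > 0.7) checkpoint: TP=4, FP=17, FN=10
print(round(f1_positive(tp=4, fp=17, fn=10), 4))  # 0.2286
```

Unlike accuracy, this score ignores true negatives entirely, so the large majority class cannot inflate it, and it prefers the second checkpoint.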

Best regards

Thomas


Hi Thomas,

Thanks a lot for your reply. I used the macro F1 score; however, it ended up saving a model that predicts 1 for only one of the 14 positive samples I have, so the saved model has very low sensitivity.

How would you suggest combining the macro F1 score with sensitivity so that both are high enough?

```text
best val acc:  tensor(0.8619, device='cuda:0')
best epoch:  48
best preds:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
best val labels:  [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
best val specifity:  1.0
best val sensitivity:  0.07142857142857142
best cm df:               Predicted Low  Predicted High
Actual Low              75               0
Actual High             13               1
best Macro F1 score:  0.5267893660531697
```

and the code is:

```python
from sklearn.metrics import f1_score

# Macro F1: unweighted mean of the F1 scores of class 0 and class 1
epoch_macro_F1 = f1_score(epoch_val_labels, epoch_val_preds, average='macro')
if not test:
    # Save a checkpoint whenever validation macro F1 improves
    if epoch_macro_F1 > best_macro_F1:
        best_val_acc = epoch_val_accuracy
        best_epoch = epoch
        best_preds = epoch_val_preds
        best_val_labels = epoch_val_labels
        best_val_sensitivity = epoch_sensitivity
        best_val_specifity = epoch_specifity
        best_macro_F1 = epoch_macro_F1
        best_cm_df = pd.DataFrame(cm,
                                  columns=['Predicted Low', 'Predicted High'],
                                  index=['Actual Low', 'Actual High'])
        print("Saving the best model...")
        torch.save(model.state_dict(), model_path + task_name + ".pth")
```
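The reason macro F1 still picked this checkpoint shows up in the arithmetic: `average='macro'` takes the unweighted mean of the per-class F1 scores, and the majority class 0 contributes a high F1 even when class 1 is almost never predicted. A sketch reproducing the logged value from the confusion matrix above (75 TN, 0 FP, 13 FN, 1 TP):

```python
def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """F1 score of whichever class is treated as positive."""
    return 2 * tp / (2 * tp + fp + fn)


tn, fp, fn, tp = 75, 0, 13, 1

f1_class1 = f1_from_counts(tp=tp, fp=fp, fn=fn)  # class 1 as positive
f1_class0 = f1_from_counts(tp=tn, fp=fn, fn=fp)  # class 0 as positive: FP and FN swap roles
macro_f1 = (f1_class0 + f1_class1) / 2
print(round(f1_class0, 4), round(f1_class1, 4), round(macro_f1, 4))  # 0.9202 0.1333 0.5268
```

Scoring only class 1 (`average='binary'`, scikit-learn's default) would give 0.1333 for this checkpoint instead, so that variant penalizes missed positives much more directly.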

Hi Tom, I used MCC, which works much better than the macro F1 score; however, the sensitivity is still very low. Do you have any suggestions for boosting it?

```text
best val acc:  tensor(0.7516, device='cuda:0')
best epoch:  5
best preds:  [0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0]
best val labels:  [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
best val specifity:  0.8266666666666667
best val sensitivity:  0.21428571428571427
best cm df:               Predicted Low  Predicted High
Actual Low              62              13
Actual High             11               3
best Macro F1 score:  0.518918918918919
best MCC score:  0.038828658355903885
```

and the code is:

```python
from sklearn.metrics import matthews_corrcoef

# Matthews correlation coefficient on this epoch's validation predictions
epoch_MCC = matthews_corrcoef(epoch_val_labels, epoch_val_preds)
if not test:
    # Save a checkpoint whenever validation MCC improves
    if epoch_MCC > best_MCC:
        best_val_acc = epoch_val_accuracy
        best_epoch = epoch
        best_preds = epoch_val_preds
        best_val_labels = epoch_val_labels
        best_val_sensitivity = epoch_sensitivity
        best_val_specifity = epoch_specifity
        best_macro_F1 = epoch_macro_F1
        best_MCC = epoch_MCC
        best_cm_df = pd.DataFrame(cm,
                                  columns=['Predicted Low', 'Predicted High'],
                                  index=['Actual Low', 'Actual High'])
        print("Saving the best model...")
        torch.save(model.state_dict(), model_path + task_name + ".pth")
```
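For context, MCC ranges from -1 to +1, where 0 corresponds to predictions no better than random guessing. Computing it by hand from the confusion matrix above (62 TN, 13 FP, 11 FN, 3 TP) reproduces the logged value and shows how close to 0 it is. The helper name is hypothetical, and returning 0.0 for an undefined MCC is a convention chosen for this sketch:

```python
import math


def mcc_from_counts(tn: int, fp: int, fn: int, tp: int) -> float:
    """Matthews correlation coefficient from binary confusion-matrix counts."""
    numerator = tp * tn - fp * fn
    denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denominator == 0:
        return 0.0  # undefined, e.g. for a constant predictor; treated as 0 here
    return numerator / denominator


print(round(mcc_from_counts(tn=62, fp=13, fn=11, tp=3), 4))  # 0.0388
print(mcc_from_counts(tn=0, fp=75, fn=0, tp=14))             # 0.0 ("always 1" model)
```

A value this close to 0 suggests the saved checkpoint separates the classes only marginally better than chance, so the selection metric may not be the main bottleneck.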
            

Hi Mona,

Well, then you do have an opinion after all, and you can add sensitivity to the criterion (or use sensitivity with a high weight and specificity with a low weight).
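One way to express such a weighting is the score w · sensitivity + (1 − w) · specificity, which reduces to balanced accuracy when w = 0.5. The following is a sketch under stated assumptions: the function name, the default weight, and the `min_specificity` floor are hypothetical choices, not something prescribed in this thread. The floor is there because, for any w > 0.5, a constant "always predict 1" model scores exactly w and can beat a useful model. It plays the same role as the `epoch_val_accuracy > 0.7` guard in the question.

```python
def checkpoint_score(sensitivity: float, specificity: float,
                     sens_weight: float = 0.7, min_specificity: float = 0.5) -> float:
    """Weighted sensitivity/specificity; -inf if specificity is below the floor (hypothetical helper)."""
    if specificity < min_specificity:
        return float("-inf")  # reject near-degenerate models outright
    return sens_weight * sensitivity + (1.0 - sens_weight) * specificity


# Constant "always 1" model (sensitivity 1.0, specificity 0.0) is rejected by the floor
print(checkpoint_score(1.0, 0.0))                   # -inf
# Checkpoint from the question with 58 TN, 17 FP, 10 FN, 4 TP
print(round(checkpoint_score(4 / 14, 58 / 75), 4))  # 0.432
```

In the training loop this would replace the comparison, e.g. `epoch_score = checkpoint_score(epoch_sensitivity, epoch_specifity)` followed by `if epoch_score > best_score:` and the usual bookkeeping and `torch.save(...)`, with a hypothetical `best_score` initialised to `float('-inf')`.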

Validation criterion aside, for imbalanced datasets it is usually a good idea to balance them for training (but not for validation). If all of your training states show bad sensitivity, that might be a good next step. Focal loss and weighting can help to some extent, but in my experience oversampling the minority class is very effective and relatively straightforward. I typically implement this at the dataset level (i.e., I add additional “fake” members to the dataset that return copies of the minority-class members).
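A minimal sketch of dataset-level oversampling in the spirit described above, assuming a map-style dataset whose items are `(sample, label)` pairs. The class name and the `repeat_factor` parameter are hypothetical. The wrapper implements `__len__` and `__getitem__`, which is what `torch.utils.data.DataLoader` needs from a map-style dataset; in real code you would typically subclass `torch.utils.data.Dataset`. Apply it to the training set only.

```python
class OversampledDataset:
    """Wraps a map-style dataset and appends extra indices pointing at minority-class items."""

    def __init__(self, base, labels, minority_label=1, repeat_factor=4):
        self.base = base
        minority = [i for i, y in enumerate(labels) if y == minority_label]
        # Original indices first, then (repeat_factor - 1) extra copies of each minority index
        self.indices = list(range(len(base))) + minority * (repeat_factor - 1)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # "Fake" members simply return copies of existing minority-class items
        return self.base[self.indices[idx]]


# Toy training set: four class-0 items and one class-1 item
train_items = [("a", 0), ("b", 0), ("c", 0), ("d", 0), ("e", 1)]
train_labels = [y for _, y in train_items]
balanced = OversampledDataset(train_items, train_labels, repeat_factor=4)
print(len(balanced))                                         # 8
print(sum(balanced[i][1] for i in range(len(balanced))))     # 4 positives out of 8
```

For a 16% / 84% split, a `repeat_factor` of about 5 would roughly balance the classes, since 84 / 16 = 5.25.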

Best regards

Thomas
