Not sure if validation predictions were collected correctly

I am working on a binary classification problem with two inputs (image and numerical data) and one output (sigmoid). I need to run 5-fold cross-validation and plot the ROC curve for each fold. Here is the evaluation part of my code:

        model.eval()  # disable dropout, use running batch-norm statistics
        with torch.no_grad():
            # The slice below assumes valid_loader was built with shuffle=False
            valid_preds_fold = np.zeros(x_val_fold.size(0))  # (359,)
            ii = 0
            for x_img_batch, x_num_batch, y_batch in valid_loader:  # THIS LOOP
                y_val_pred = model(x_img_batch, x_num_batch)
                # sigmoid is only needed if the model returns raw logits; ROC AUC
                # is unchanged either way because sigmoid is monotonic
                valid_preds_fold[ii * batch_size:(ii + 1) * batch_size] = sigmoid(y_val_pred.cpu().numpy())[:, 0]
                ii += 1

    # y_val_fold has shape torch.Size([359, 1]); roc_curve expects 1-D labels
    fpr, tpr, thresholds = roc_curve(y_val_fold.cpu().numpy().ravel(), valid_preds_fold)
    tprs.append(np.interp(mean_fpr, fpr, tpr))
    tprs[-1][0] = 0.0
    roc_auc = auc(fpr, tpr)
    aucs.append(roc_auc)  # without this, np.std(aucs) below is computed on an empty list
    ax.plot(fpr, tpr, lw=1, alpha=0.3, label='fold %d (AUC = %0.3f)' % (i + 1, roc_auc))

ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='chance', alpha=.8)
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
ax.plot(mean_fpr, mean_tpr, color='b', label=r'mean ROC (AUC = %0.3f $\pm$ %0.2f)' % (mean_auc, std_auc), lw=2, alpha=.8)

The for loop iterating over valid_loader looks suspicious to me. What I need is to collect each batch of validation predictions (y_val_pred) into that fold's array (valid_preds_fold) and then compute fpr, tpr, thresholds = roc_curve(y_val_fold.cpu(), valid_preds_fold). I put this together from a couple of resources, but the AUC for each fold is much lower than I expected.

Any review or suggestions to improve the code above would be appreciated. Thanks!
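A minimal sketch of two ways to collect the per-batch predictions, assuming the model returns raw logits and valid_loader is built with shuffle=False. TwoInputNet, the tensor shapes and the random data are hypothetical stand-ins for the real fold. The slice-based loop from the question only lines predictions up with y_val_fold when the loader does not shuffle and uses the same batch_size; if it shuffles, predictions end up in a different order than the labels, which would plausibly drag the per-fold AUC toward 0.5. Appending each batch and concatenating avoids the index arithmetic and collects the labels in the same order:

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for one validation fold: 359 samples,
# an "image" tensor, a numerical-feature tensor and binary labels.
n_val, batch_size = 359, 32
x_img = torch.randn(n_val, 3, 8, 8)
x_num = torch.randn(n_val, 4)
y = torch.randint(0, 2, (n_val, 1)).float()


class TwoInputNet(nn.Module):
    """Toy two-input model (hypothetical) returning one raw logit per sample."""

    def __init__(self):
        super().__init__()
        self.img = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 8))
        self.num = nn.Linear(4, 8)
        self.head = nn.Linear(16, 1)

    def forward(self, img, num):
        return self.head(torch.cat([self.img(img), self.num(num)], dim=1))


model = TwoInputNet()
# shuffle=False keeps batches in dataset order, which the slice approach relies on
valid_loader = DataLoader(TensorDataset(x_img, x_num, y),
                          batch_size=batch_size, shuffle=False)

model.eval()
with torch.no_grad():
    # Approach 1: the slice-based loop from the question
    preds_sliced = np.zeros(n_val)
    for ii, (xb_img, xb_num, _) in enumerate(valid_loader):
        probs = torch.sigmoid(model(xb_img, xb_num)).cpu().numpy()[:, 0]
        # the last batch has 7 samples; numpy truncates the slice to match
        preds_sliced[ii * batch_size:(ii + 1) * batch_size] = probs

    # Approach 2: append every batch, concatenate once at the end,
    # and collect the labels from the same loader in the same order
    pred_batches, label_batches = [], []
    for xb_img, xb_num, yb in valid_loader:
        pred_batches.append(torch.sigmoid(model(xb_img, xb_num)).cpu())
        label_batches.append(yb.cpu())
    preds_cat = torch.cat(pred_batches).numpy().ravel()
    labels_cat = torch.cat(label_batches).numpy().ravel()

print(preds_sliced.shape, preds_cat.shape)
print(np.allclose(preds_sliced, preds_cat))
```

With labels_cat gathered from the same loader, roc_curve(labels_cat, preds_cat) pairs every prediction with its own label, so the result no longer depends on whether the loader shuffles.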

I'm not sure exactly how this code works, but are you using a different valid_loader for each fold?

Yes. As in standard 5-fold cross-validation, in each round I use 4 folds for training and 1 fold for validation.

I think your code looks alright, as long as you make sure the indices don't overlap and aren't reused.
What kind of issue are you seeing at the moment?
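The index condition can be checked directly. A minimal sketch, assuming the folds come from scikit-learn's StratifiedKFold (the thread doesn't say how they were built) and using random placeholder labels: within each fold the training and validation indices must be disjoint, and across the five folds every sample should land in validation exactly once.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Hypothetical placeholder labels for the full dataset (5 x 359 samples)
rng = np.random.default_rng(0)
y_all = rng.integers(0, 2, size=1795)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
seen_in_validation = np.zeros(len(y_all), dtype=int)

# Only y is needed to build stratified splits, so X can be a placeholder
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(y_all)), y_all)):
    # Within a fold, no sample may be used for both training and validation
    assert np.intersect1d(train_idx, val_idx).size == 0
    # Together, the two index sets cover the whole dataset
    assert len(train_idx) + len(val_idx) == len(y_all)
    seen_in_validation[val_idx] += 1
    print(f"fold {fold + 1}: {len(train_idx)} train / {len(val_idx)} val")

# Across folds, every sample is validated exactly once (never reused)
print(np.all(seen_in_validation == 1))
```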


I'm not getting any errors now, and I tracked down a few small issues. Thanks a lot! I just wanted to be sure that this particular part was working correctly.