How to build a model to diagnose chest conditions from the CheXpert dataset

I am using the CheXpert dataset (found here on Kaggle) to build a CNN model that can predict disease conditions (e.g. cardiomegaly, pleural effusion, atelectasis, etc.) from chest X-ray images (multi-label classification). I am using PyTorch Lightning, and my code is attached to this question. I have tried several architectures, but I can't seem to get the models to perform well. I performed the overfit test (in which I try to overfit a model on a single batch of data), and the models were able to overfit that batch — showing that they are capable of fitting the data. However, regardless of the architecture I use, there is a large gap between the training loss (which can get as low as 0.2) and the validation loss (which can get as low as 0.49). On sensitivity and precision (the metrics I am interested in), the models perform terribly during validation. When I trained the models for more epochs, I also observed that the loss values start to increase. I would appreciate any help or suggestions to solve this problem. Thank you.

import gc
import random
import time
from collections.abc import Mapping

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import AUROC

from chexpertloader import ChestXrayDataset
# from ipykernel import kernelapp as app

seed = 123
# Actually apply the seed: previously `seed` was never used, so the cudnn
# determinism flags below could not make runs reproducible on their own.
# workers=True also seeds DataLoader worker processes.
pl.seed_everything(seed, workers=True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


class _LazyBackbones(Mapping):
    """Read-only ``name -> model`` mapping that builds backbones on demand.

    Indexing calls the stored torchvision constructor with ``weights`` and
    returns a *new* model every time. As a result:

    * only the architecture actually requested is built (and its weights
      downloaded), instead of all of them whenever this module is imported --
      which includes every spawned DataLoader worker on spawn-based platforms;
    * each ``LitEfficientNet`` gets its own backbone, instead of all
      instances for the same ``arch`` sharing (and continuing to train) one
      module-level model across runs of a hyper-parameter sweep.
    """

    def __init__(self, factories, weights='IMAGENET1K_V1'):
        self._factories = dict(factories)
        self._weights = weights

    def __getitem__(self, name):
        return self._factories[name](weights=self._weights)

    def __contains__(self, name):
        # Override the Mapping default, which would call __getitem__ and
        # therefore build (and download) a model just to test membership.
        return name in self._factories

    def __iter__(self):
        return iter(self._factories)

    def __len__(self):
        return len(self._factories)


# torchvision constructors keyed by the architecture names used in this
# script; ``backbones[arch]`` returns a freshly built ImageNet-pretrained model.
backbones = _LazyBackbones({
    'efficientnetb0': models.efficientnet_b0,
    'efficientnetb1': models.efficientnet_b1,
    'efficientnetb2': models.efficientnet_b2,
    'efficientnetb3': models.efficientnet_b3,
    'efficientnetb4': models.efficientnet_b4,
    'efficientnetb5': models.efficientnet_b5,
    'efficientnetb6': models.efficientnet_b6,
    'efficientnetb7': models.efficientnet_b7,
    'densenet121': models.densenet121,
    'densenet161': models.densenet161,
    'densenet169': models.densenet169,
    'densenet201': models.densenet201,
    'resnet50': models.resnet50,
    'efficientnetV2_m': models.efficientnet_v2_m,
})

class LitEfficientNet(pl.LightningModule):
    """Multi-label CheXpert classifier on top of a pretrained torchvision backbone.

    Args:
        arch: Key into the module-level ``backbones`` mapping and into
            ``self.sizes`` / ``self.batch_sizes`` (e.g. ``'densenet121'``).
        num_classes: Number of output logits, one per finding.
        lr: Learning rate used by the Adam optimizer.
        uncertain_label: Value that CheXpert's ``-1`` ("uncertain") labels are
            mapped to before building the datasets. ``1.0`` is the "U-Ones"
            policy and ``0.0`` the "U-Zeros" policy. ``None`` leaves the CSV
            values untouched, as the original code did.
    """

    def __init__(self, arch, num_classes, lr, uncertain_label=1.0):
        super().__init__()
        # The pasted original read `self.arch = arch = lr`, which stored the
        # learning rate as the architecture name (breaking every
        # `self.sizes[self.arch]` lookup) and never kept `lr` at all.
        self.arch = arch
        self.lr = lr
        self.uncertain_label = uncertain_label
        # Per-architecture image sizes handed to ChestXrayDataset.
        # NOTE(review): presumably (resize, crop) -- confirm in chexpertloader.
        self.sizes = {
            'efficientnetb0': (256, 224), 'efficientnetb1': (256, 240), 'efficientnetb2': (288, 288), 'efficientnetb3': (320, 300),
            'efficientnetb4': (384, 380), 'efficientnetb5': (489, 456), 'efficientnetb6': (561, 528), 'efficientnetb7': (633, 600),
            'densenet121': (256, 256), 'densenet161': (256, 256), 'densenet169': (256, 256), 'densenet201': (256, 256),
            'resnet50': (224, 224), 'efficientnetV2_m': (384, 384),
        }
        # Per-architecture batch sizes (presumably chosen to fit GPU memory).
        self.batch_sizes = {
            'efficientnetb0': 64, 'efficientnetb1': 64, 'efficientnetb2': 64, 'efficientnetb3': 32,
            'efficientnetb4': 20, 'efficientnetb5': 7, 'efficientnetb6': 5, 'efficientnetb7': 2,
            'densenet121': 64, 'densenet161': 32, 'densenet169': 32, 'densenet201': 32, 'resnet50': 32,
            # NOTE(review): this entry is missing from the pasted listing even
            # though the __main__ driver trains efficientnetV2_m, which would
            # raise KeyError in the dataloaders; 16 is a placeholder -- tune it.
            'efficientnetV2_m': 16,
        }
        self.model = backbones[arch]
        if 'densenet' in arch:
            # Without activations between them, stacked Linear layers collapse
            # into a single linear map. ReLU restores a non-linear head, and
            # Dropout regularises it against the train/val gap.
            # NOTE(review): the original lines here were lost in the paste; tune p.
            self.model.classifier = nn.Sequential(
                nn.Linear(self.model.classifier.in_features, 2048),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.5),
                nn.Linear(2048, 512),
                nn.ReLU(inplace=True),
                nn.Dropout(p=0.5),
                nn.Linear(512, num_classes),
            )
        elif 'resnet' in arch:
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        elif 'efficientnet' in arch:
            # Keep torchvision's dropout probability, swap the final layer.
            self.model.classifier = nn.Sequential(
                nn.Dropout(p=self.model.classifier[0].p, inplace=True),
                nn.Linear(self.model.classifier[1].in_features, num_classes),
            )
        # Shared by both steps; it consumes raw logits, so no Sigmoid is needed
        # in forward/training_step/validation_step.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return raw (pre-sigmoid) logits for a batch of images."""
        return self.model(x)

    def _load_csv(self, csv_path):
        """Read a CheXpert CSV and clean its label cells.

        Blank cells become 0, as before. ``-1`` cells (uncertain findings, per
        the CheXpert label convention) are mapped to ``self.uncertain_label``
        unless it is ``None``. BCEWithLogitsLoss expects targets in [0, 1], and
        ``fillna(0)`` alone leaves the ``-1`` values in place.
        """
        frame = pd.read_csv(csv_path)
        frame.fillna(0, inplace=True)
        if self.uncertain_label is not None:
            # NOTE(review): if ChestXrayDataset already remaps -1 itself, this
            # pre-empts that choice -- check chexpertloader and keep one place.
            frame = frame.replace(-1, self.uncertain_label)
        return frame

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE-with-logits loss for one training batch."""
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('train_loss', loss, sync_dist=True)
        return loss

    def train_dataloader(self):
        """Build the training DataLoader from CheXpert-v1.0-small/train.csv."""
        train_csv = self._load_csv('CheXpert-v1.0-small/train.csv')
        train_dataset = ChestXrayDataset("CheXpert-v1.0-small/train", train_csv, self.sizes[self.arch], True)
        return torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_sizes[self.arch],
            num_workers=8,
            # Was shuffle=False. train.csv lists images grouped by patient, so
            # unshuffled batches are highly correlated and every epoch replays
            # the same order, which hurts generalization.
            shuffle=True,
        )

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        self.log('val_loss', loss, sync_dist=True)
        return loss

    def val_dataloader(self):
        """Build the (unshuffled) validation DataLoader from valid.csv."""
        # The previous unseeded `sample(frac=1)` shuffle is dropped: order
        # does not matter for validation and it made runs non-reproducible.
        validation_csv = self._load_csv('CheXpert-v1.0-small/valid.csv')
        # NOTE(review): the fourth argument is True here, as for training; if it
        # enables augmentation in ChestXrayDataset, validation should pass False.
        validation_dataset = ChestXrayDataset("CheXpert-v1.0-small/valid", validation_csv, self.sizes[self.arch], True)
        return torch.utils.data.DataLoader(
            dataset=validation_dataset,
            batch_size=self.batch_sizes[self.arch],
            num_workers=8,
            shuffle=False,
        )

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
if __name__ == '__main__':
    archs = ['efficientnetV2_m']
    learning_rates = [0.001]
    num_classes = 5

    # Sweep every (architecture, learning rate) pair. With the single-element
    # lists above this is exactly one run, as before.
    for arch in archs:
        for learning_rate in learning_rates:
            logger = TensorBoardLogger(f"tb_logs_binary/{arch}", name=f"{arch}_{learning_rate}_ppv_npv_sensitive")
            model = LitEfficientNet(arch, num_classes, learning_rate)
            # NOTE(review): the original Trainer(...) arguments and the
            # trainer.fit() call were lost from the pasted listing. This is a
            # minimal reconstruction -- add max_epochs/devices as needed.
            trainer = Trainer(
                # devices=1,
                # overfit_batches=10,
                logger=logger,
                callbacks=[
                    # Keep the checkpoint with the lowest validation loss, so a
                    # later rise in val_loss (overfitting) does not lose the
                    # best model.
                    ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1),
                    LearningRateMonitor(logging_interval='epoch'),
                ],
            )
            trainer.fit(model)
            # Release this run's model/trainer before starting the next one.
            del model, trainer
            gc.collect()


Have you considered hyperparameter tuning? You could start with the combination that gives you the best result, then fine-tune the parameters individually until you get the best possible result. I have not used PyTorch Lightning before, but you didn't mention whether you have already used PyTorch Lightning's performance and bottleneck profiler. That might be something to look into to get you started.

1 Like

Thank you so much for your response, @tiramisuNcustard. I have not used the performance and bottleneck profiler before, but I will check it out. Also, for hyperparameter tuning, which hyperparameters would you recommend I start with, and are there any tools I can use to automate the tuning process? Thank you once again.

@Issah_Samori, I have used Ray Tune with PyTorch before and I would recommend it (see the first link at the end). For CNNs, I have tuned the learning rate, batch size, momentum, number of epochs, and the number of nodes in the hidden fully connected layers in the past. You can use the example in the second link below to get started with hyperparameter tuning in Ray Tune.

1 Like

Thanks @tiramisuNcustard. I will check it out.