I am using the CheXpert dataset (found here on Kaggle) to build a CNN model that can predict disease conditions (e.g. cardiomegaly, pleural effusion, atelectasis, etc.) from chest X-ray images (multi-label classification). I am using PyTorch Lightning, and my code is attached to this question. I have tried several architectures, but I can't seem to get the models to perform well. I performed the overfit test (in which I try to overfit a model on one batch of data), and the models were able to overfit the single batch, showing that they are capable of fitting the data. However, regardless of the architecture I use, there is quite a difference between the training loss (which can get as low as 0.2) and the validation loss (which only gets as low as 0.49). On sensitivity and precision (the metrics I am interested in), the models perform terribly during validation. When I train the models for more epochs, I also observe that the loss values start to increase. I would appreciate any help or suggestions to help me solve this problem. Thank you.
@ptrblck
import copy
import gc
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import AUROC
# from ipykernel import kernelapp as app

from chexpertloader import ChestXrayDataset
# Seed every random number generator the script relies on (Python, NumPy,
# PyTorch CPU and CUDA) with one shared value.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# cuDNN: request deterministic kernels and switch off the autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Maps each architecture key used by this script to the name of the
# torchvision.models constructor that builds it.
_BACKBONE_BUILDERS = {
    'efficientnetb0': 'efficientnet_b0',
    'efficientnetb1': 'efficientnet_b1',
    'efficientnetb2': 'efficientnet_b2',
    'efficientnetb3': 'efficientnet_b3',
    'efficientnetb4': 'efficientnet_b4',
    'efficientnetb5': 'efficientnet_b5',
    'efficientnetb6': 'efficientnet_b6',
    'efficientnetb7': 'efficientnet_b7',
    'densenet121': 'densenet121',
    'densenet161': 'densenet161',
    'densenet169': 'densenet169',
    'densenet201': 'densenet201',
    'resnet50': 'resnet50',
    'efficientnetV2_m': 'efficientnet_v2_m',
}


class _LazyBackbones:
    """Read-only, dict-like registry that builds a backbone on lookup.

    A plain dict literal of ``models.xxx(weights=...)`` calls constructs
    every listed network (and fetches its pretrained weights) as soon as the
    module is imported, even though a run uses only one of them.  This
    registry keeps the same ``backbones[arch]`` interface but defers
    construction until the key is actually requested.

    Every lookup returns a *new* model instance, so callers that replace the
    classification head or train the network never affect later lookups.
    """

    def __init__(self, builders, weights='IMAGENET1K_V1'):
        self._builders = dict(builders)
        self._weights = weights

    def __getitem__(self, arch):
        # Unknown keys raise KeyError here, exactly as a dict would, before
        # torchvision is touched.
        builder_name = self._builders[arch]
        return getattr(models, builder_name)(weights=self._weights)

    def __contains__(self, arch):
        return arch in self._builders

    def __iter__(self):
        return iter(self._builders)

    def __len__(self):
        return len(self._builders)

    def keys(self):
        """Return the supported architecture keys."""
        return self._builders.keys()


backbones = _LazyBackbones(_BACKBONE_BUILDERS)
class LitEfficientNet(pl.LightningModule):
    """LightningModule for multi-label chest X-ray classification.

    Wraps one of the networks in ``backbones``, replaces its classification
    head with one that emits ``num_classes`` raw logits, and trains it with
    ``nn.BCEWithLogitsLoss`` (one independent sigmoid per label) using Adam.

    Args:
        arch: Key into ``backbones``, ``sizes`` and ``batch_sizes``.
        num_classes: Number of output logits (labels).
        lr: Learning rate passed to Adam.

    Raises:
        KeyError: If ``arch`` is not a key of ``backbones``.
    """

    # Per-architecture size tuple forwarded to ChestXrayDataset.
    # NOTE(review): the pairs look like (resize, crop) - confirm against
    # ChestXrayDataset.
    sizes = {
        'efficientnetb0': (256, 224), 'efficientnetb1': (256, 240),
        'efficientnetb2': (288, 288), 'efficientnetb3': (320, 300),
        'efficientnetb4': (384, 380), 'efficientnetb5': (489, 456),
        'efficientnetb6': (561, 528), 'efficientnetb7': (633, 600),
        'densenet121': (256, 256), 'densenet161': (256, 256),
        'densenet169': (256, 256), 'densenet201': (256, 256),
        'resnet50': (224, 224), 'efficientnetV2_m': (384, 384),
    }

    # Per-architecture batch size used by both data loaders.
    batch_sizes = {
        'efficientnetb0': 64, 'efficientnetb1': 64, 'efficientnetb2': 64,
        'efficientnetb3': 32, 'efficientnetb4': 20, 'efficientnetb5': 7,
        'efficientnetb6': 5, 'efficientnetb7': 2,
        'densenet121': 64, 'densenet161': 32, 'densenet169': 32,
        'densenet201': 32, 'resnet50': 32,
        'efficientnetV2_m': 16,
    }

    def __init__(self, arch, num_classes, lr):
        super().__init__()
        self.arch = arch
        self.lr = lr

        # Work on a private copy: the head is replaced and the weights are
        # trained below, and neither must leak into the shared registry
        # entry that another instance with the same arch would receive.
        self.model = copy.deepcopy(backbones[arch])

        if 'densenet' in arch:
            self.model.classifier = nn.Sequential(
                nn.Linear(self.model.classifier.in_features, 2048),
                nn.ReLU(),
                nn.Dropout(p=0.6),
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(p=0.2),
                nn.Linear(512, num_classes),
            )
        elif 'resnet' in arch:
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        elif 'efficientnet' in arch:
            # Keep the backbone's own dropout rate; only resize the output.
            self.model.classifier = nn.Sequential(
                nn.Dropout(p=self.model.classifier[0].p, inplace=True),
                nn.Linear(self.model.classifier[1].in_features, num_classes),
            )

        # Built once instead of on every step.  It applies the sigmoid
        # internally, so the network must output raw logits.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits for a batch of images."""
        return self.model(x)

    def _compute_loss(self, batch):
        """Run the network on ``batch`` and return the BCE-with-logits loss."""
        images, labels = batch
        logits = self(images)
        # BCEWithLogitsLoss needs floating-point targets; the cast is a
        # no-op when the dataset already yields labels in the logits' dtype.
        # NOTE(review): the original author listed the label order as
        # Cardiomegaly, Edema, Atelectasis, Pleural Effusion, Lung Opacity -
        # confirm against ChestXrayDataset.
        return self.criterion(logits, labels.type_as(logits))

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('train_loss', loss, sync_dist=True)
        return loss

    def train_dataloader(self):
        """Build the training DataLoader from ``CheXpert-v1.0-small/train.csv``."""
        train_csv = pd.read_csv('CheXpert-v1.0-small/train.csv')
        # Missing label entries become 0.
        # NOTE(review): any other non-binary label values in the CSV are
        # passed through unchanged - confirm ChestXrayDataset maps them into
        # [0, 1], which BCE targets require.
        train_csv.fillna(0, inplace=True)
        train_dataset = ChestXrayDataset(
            "CheXpert-v1.0-small/train", train_csv, self.sizes[self.arch], True
        )
        # Shuffle every epoch: without it each batch follows the CSV row
        # order, which biases the gradient steps and hurts generalisation.
        return torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_sizes[self.arch],
            num_workers=8,
            shuffle=True,
        )

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, sync_dist=True)
        return loss

    def val_dataloader(self):
        """Build the validation DataLoader from ``CheXpert-v1.0-small/valid.csv``."""
        validation_csv = pd.read_csv('CheXpert-v1.0-small/valid.csv')
        validation_csv.fillna(0, inplace=True)
        # Row order is randomised once here; the loader itself does not
        # shuffle.
        validation_csv = validation_csv.sample(frac=1)
        # NOTE(review): the final positional flag is the same ``True`` that
        # the training dataset receives.  If it enables training-time
        # augmentation it should probably be False here - confirm against
        # ChestXrayDataset.
        validation_dataset = ChestXrayDataset(
            "CheXpert-v1.0-small/valid", validation_csv, self.sizes[self.arch], True
        )
        return torch.utils.data.DataLoader(
            dataset=validation_dataset,
            batch_size=self.batch_sizes[self.arch],
            num_workers=8,
            shuffle=False,
        )

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
if __name__ == '__main__':
    # Script entry point: train one model per learning rate and log each run
    # to its own TensorBoard directory.
    archs = ['efficientnetV2_m']
    learning_rates = [0.001]
    num_classes = 5

    # Only the first architecture in ``archs`` is trained.
    arch = archs[0]

    for learning_rate in learning_rates:
        logger = TensorBoardLogger(f"tb_logs_binary/{arch}",name=f"{arch}_{learning_rate}_ppv_npv_sensitive")
        model = LitEfficientNet(arch, num_classes, learning_rate)

        # Keep the weights with the lowest validation loss, so a run whose
        # loss starts rising late in training does not lose its best model.
        best_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)

        trainer = Trainer(
            log_every_n_steps=1411,
            logger=logger,
            accelerator='gpu',
            devices=-1,  # use every available GPU
            # devices=1,
            # overfit_batches=10,
            max_epochs=50,
            val_check_interval=0.1,  # validate ten times per training epoch
            deterministic=True,
            fast_dev_run=False,
            callbacks=[best_checkpoint],
        )
        trainer.fit(model)

        # Drop the references before the next run so their memory can be
        # reclaimed.
        del model, trainer
        gc.collect()
`