Performance problem with binary classification + Totally confused :(

I recently migrated to PyTorch from TensorFlow, and now I'm facing a very stupid and embarrassing issue.
I’m trying to do a binary classification on an Xray dataset. The directory structure is as follows:

-test
---Normal
---Pneumonia
-train
---Normal
---Pneumonia

So I'm using ImageFolder from torchvision. Here are the augmentations and the dataset class:

import numpy as np
import cv2
import torch
import torchvision
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader

train_data = torchvision.datasets.ImageFolder('/kaggle/input/covid19-xray-dataset-train-test-sets/xray_dataset_covid19/train/')
val_data = torchvision.datasets.ImageFolder('/kaggle/input/covid19-xray-dataset-train-test-sets/xray_dataset_covid19/test/')

IMG_SIZE = 512

aug = A.Compose([A.Resize(IMG_SIZE, IMG_SIZE),
                 A.RandomCrop(IMG_SIZE, IMG_SIZE),
                 A.HorizontalFlip(p=0.5),
                 A.VerticalFlip(p=0.5),
                 A.Rotate(10),
                 A.Blur(),
                 A.RandomGamma(),
                 A.Sharpen(),
                 A.GaussNoise(p=0.1),
                 A.CLAHE(),
                 A.Normalize(mean=0, std=1),
                 ToTensorV2()])

class DataReader(Dataset):
    def __init__(self, dataset, transform):
        
        self.dataset = dataset
        self.transform = transform
        
    def __getitem__(self, index):
        image = self.dataset[index][0]
        label = self.dataset[index][1]
        
        image = np.array(image)
        # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transform(image=image)['image']
        
        return image, label
    
    def __len__(self):
        return len(self.dataset)
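
A quick sanity check of a single sample from this pipeline (just a sketch using the objects defined above):

sample_img, sample_label = DataReader(train_data, aug)[0]
print(sample_img.shape, sample_img.dtype, sample_label)
# e.g. torch.Size([3, 512, 512]) torch.float32 0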

The model class is as follows (I omitted the on_validation_epoch_end and on_train_epoch_end; nothing crazy going on there):

import torch.nn as nn
import torchmetrics
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models
from pytorch_lightning import LightningModule

class XrayModel(LightningModule):
    
    def __init__(self):
        
        super(XrayModel, self).__init__()

        self.model = models.efficientnet_b0(pretrained=True)
        self.model.classifier[1] = nn.Linear(in_features=1280, out_features=1)
        
        self.lr = 1e-4
        self.batch_size = 32
        self.numworker = 2
        
        self.criterion = nn.BCEWithLogitsLoss()
        self.acc = torchmetrics.Accuracy(task='binary')
        
        self.trainacc, self.valacc = [], []
        self.trainloss, self.valloss = [], []
    
    def forward(self, x):
        x = self.model(x)
        
        return x
    
    
    def configure_optimizers(self):
        opt = torch.optim.AdamW(params=self.parameters(), lr=self.lr)
        scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.75, patience=5)
        return {'optimizer': opt, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
    
    
    def train_dataloader(self):

        train_loader = DataLoader(DataReader(train_data, aug), shuffle=True,
                                  batch_size=self.batch_size, num_workers=self.numworker)

        return train_loader
    
    
    def training_step(self, batch, batch_idx):
        image, label = batch
        pred = self(image)
        loss = self.criterion(pred.flatten(), label.float())  
        acc = self.acc(pred.flatten(), label)

        return loss
        
    def val_dataloader(self):
        val_loader = DataLoader(DataReader(val_data, aug), shuffle=False, batch_size=self.batch_size, num_workers=self.numworker)
        return val_loader

    
    def validation_step(self, batch, batch_idx):
        image, label = batch
        pred = self(image)
        loss = self.criterion(pred.flatten(), label.float())
        acc = self.acc(pred.flatten(), label)
        
        self.log('val_loss', loss)
        self.log('val_acc', acc)

        return loss
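
For completeness, I kick off training with something like this (the epoch count here is a placeholder, not necessarily what I used):

import pytorch_lightning as pl

model = XrayModel()
trainer = pl.Trainer(max_epochs=20, accelerator='auto', devices=1)
trainer.fit(model)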

The problem: it's as if the model never learns class '1' and only predicts class '0'. Or maybe I'm doing something wrong when getting the predictions?
I mean, the accuracy for both training and validation looks somewhat good, but the confusion matrix and the ROC AUC score are stuck at 0.5! Here are the confusion matrix and the code:

val_loader = DataLoader(DataReader(val_data, aug), shuffle=False, batch_size=20)
y_pred, y_true, y_probs   = [], [], []

with torch.no_grad():
    for batch in val_loader:
        x, labels = batch
        outputs = model(x)
        _, predicted = torch.max(outputs, 1)
        y_pred.extend(predicted.tolist())
        y_true.extend(labels.tolist())
        
        probabilities = torch.softmax(outputs, dim=1)
        y_probs.extend(probabilities.tolist())
    
confusion = confusion_matrix(y_true, y_pred)
auroc = roc_auc_score(y_true, y_probs)
print(auroc)
print(confusion)

>>> 0.5000
>>> [[20, 0],
     [20, 0]]

If the actual label is '0', the model predicts it 100% correctly. But if it's '1', it gets 0%.
Maybe I'm doing something wrong in the DataReader class when returning the label?
Maybe I should one-hot encode them?


@Sand_Glokta

(Coming from TensorFlow myself, I think both are good, but torch is just better.)

I wonder if it's possible that everything runs fine and just the confusion matrix is calculated incorrectly?

Can you print the model output and the labels so that we can see the numbers and the format?

Reading the docs for torcheval.metrics.functional.binary_confusion_matrix (TorchEval main documentation), it seems PyTorch has a simple built-in way to calculate it; I'd rather try using that.
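
Something like this, as a minimal sketch assuming your model and val_loader from above (sigmoid maps the single logit to a probability, so the default 0.5 threshold applies):

import torch
from torcheval.metrics.functional import binary_confusion_matrix

model.eval()  # turn off dropout / batchnorm updates for evaluation
logits, targets = [], []
with torch.no_grad():
    for x, y in val_loader:
        logits.append(model(x).flatten())  # [B, 1] -> [B]
        targets.append(y)

probs = torch.sigmoid(torch.cat(logits))  # single-logit head -> probabilities
print(binary_confusion_matrix(probs, torch.cat(targets)))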


Hey! Thanks for the reply!
So I tried the method you sent, and it made me realize the shape of the outputs is weird; that's the problem!

LABELS: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
        
OUTPUTS: tensor([[-3.7406], [-4.1787], [-5.2584], [-3.9865],
                [-3.6685],[-3.9812], [-4.3075], [-4.7298],
                [-5.6243], [-5.1365], [-5.6333], [-4.7504],
                [-5.3502], [-3.3726], [-5.2753], [-5.6162],
                [-3.6993], [-5.3350],[-5.4266],[ 0.2309],
                [ 5.6679],[ 6.5182],[ 5.2023],[-0.9134],
                [ 3.1126],[ 1.0085],[ 6.1053],[ 5.4332],
                [ 4.5336],[ 4.8902],[ 3.9891],[ 4.4946],
                [ 6.1804],[ 5.9176],[ 6.5229],[ 5.6799],
                [ 7.0696],[ 4.8820],[ 5.4711],[ 5.0669]])

So basically it's a tensor of shape [40, 1], which needs to be flattened, and that gets the job done.
Thanks again!
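
For anyone who lands here later: with an [N, 1] output, torch.max(outputs, 1) always returns index 0, and softmax over a single logit is always 1.0, which is exactly why my confusion matrix collapsed into one column. A small self-contained illustration:

import torch

outputs = torch.tensor([[-3.7], [5.7]])  # shape [2, 1]: one logit per sample
print(torch.max(outputs, 1).indices)     # tensor([0, 0]): argmax over a size-1 dim
print(torch.softmax(outputs, dim=1))     # tensor([[1.], [1.]]): softmax of one value
print((torch.sigmoid(outputs.flatten()) > 0.5).long())  # tensor([0, 1]): the fix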

nn.BCEWithLogitsLoss will raise an error if the input and target shapes do not match:

criterion = nn.BCEWithLogitsLoss()
x = torch.randn(20, 1, requires_grad=True)
y = torch.randint(0, 2, (20,)).float()

loss = criterion(x, y)
# ValueError: Target size (torch.Size([20])) must be the same as input size (torch.Size([20, 1]))

Based on your previous description, I would guess your model is overfitting to the majority class.
Did you check the class frequencies and whether you are dealing with an imbalanced dataset?
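
A quick way to check, as a sketch (ImageFolder exposes the per-sample class indices via .targets, so this assumes the train_data/val_data from your first post):

from collections import Counter

print(train_data.classes)           # e.g. ['Normal', 'Pneumonia']
print(Counter(train_data.targets))  # samples per class index in the training set
print(Counter(val_data.targets))    # samples per class index in the validation set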


I'd also add to this, at least from my newbie perspective: the data should be shuffled, and it doesn't seem to be, judging from the output labels.


Thanks for the reply!
Yes, the classes are balanced, each with 20 samples. The problem was that I didn't understand how BCEWithLogitsLoss expects and handles the logits. Now even the roc_curve from sklearn works:


from torcheval.metrics.functional import binary_confusion_matrix, binary_auroc
out = outputs.flatten()

print(binary_confusion_matrix(out, labels))
print(binary_auroc(out, labels))

>>> tensor([[20,  0],
            [ 1, 19]])
>>> tensor(1., dtype=torch.float64)


from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(labels, out)
(fpr, tpr, thresholds)
>>> (array([0., 0., 0., 1.]),
     array([0.  , 0.05, 1.  , 1.  ]),
     array([ 8.758759 ,  7.758759 , -1.8544512, -5.5356064], dtype=float32))

I think what Patrick means is that we could pass the LABELS and OUTPUTS you logged straight to the loss, and they should either pass or be rejected. It's important to know exactly why something fails, imho. You seem to understand it now, though, which is enough.

I may try what happens passing LABELS and OUTPUT to BCE loss once I get to a laptop.
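
In case it saves the trip, a quick sketch with two of the logged values plugged in:

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
outputs = torch.tensor([[-3.7406], [5.6679]])  # shape [2, 1], taken from the log above
labels = torch.tensor([0., 1.])                # shape [2]

# criterion(outputs, labels)  # raises ValueError: shape mismatch, as shown above
print(criterion(outputs.flatten(), labels))    # works once the output is flattened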

Glad it works, though. I wonder whether any real-world software actually uses classification networks to check for diseases in images, or whether this kind of thing never gets deployed.


Hi Mah!

Commercial AI / neural-network medical-image tools have recently become available and are being used in the field. As just one random example, this press release describes an AI-based breast-cancer screening service.

Quoting a little:

“Saige-Dx is a sophisticated, AI deep-learning algorithm that has been trained on more data than any one radiologist would see in a lifetime. Saige-Dx’s pivotal study showed that radiologists who used Saige-Dx improved their performance, something no other mammography AI tool has demonstrated.”

(This is a press release from the company that makes the tool, so, of course, they want to make it sound good. But there are studies, generally paid for by companies that want to get their tools approved for medical use, that show that such tools can be useful, at least when used to assist trained, human radiologists.)

Best.

K. Frank


Interesting, thank you! I'm open to discussing it some day. I'd love to do something with real benefit, but under a GPL or maybe Apache 2.0 license, with no commercial benefit.
