Is this a correct implementation and use of focal loss for binary classification on vision transformer output? If it is correct, why are all train and val preds still stuck at zero?

I am using the following code snippet for focal loss for binary classification on the output of a vision transformer. In my case, the Vision Transformer outputs two values, so I apply a sigmoid to the difference of the two outputs, as shown below. Could you please confirm whether this is correct?
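
Taking the difference of the two outputs is a sound idea: for two raw scores z0 and z1, the softmax probability of class 1 is exp(z1) / (exp(z0) + exp(z1)) = 1 / (1 + exp(-(z1 - z0))) = sigmoid(z1 - z0), so z1 - z0 already behaves as a binary logit. The small pure-Python check below illustrates this identity; the score pairs are hypothetical values, not taken from the model.

```python
import math

def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))

def softmax2(z0, z1):
    """Two-class softmax, shifted by the max for numerical stability."""
    m = max(z0, z1)
    e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
    return e0 / (e0 + e1), e1 / (e0 + e1)

# Hypothetical raw outputs of a 2-unit head: (score for class 0, score for class 1)
pairs = [(2.0, -1.0), (0.3, 0.3), (-4.0, 1.5)]
for z0, z1 in pairs:
    p1_softmax = softmax2(z0, z1)[1]   # softmax probability of class 1
    p1_sigmoid = sigmoid(z1 - z0)      # sigmoid of the score difference
    print(f"z0={z0:5.1f} z1={z1:5.1f}  softmax: {p1_softmax:.6f}  sigmoid(z1 - z0): {p1_sigmoid:.6f}")
```

The sigmoid is only needed when you want the probability itself; the difference z1 - z0 is the logit.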

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss) # probability of the true class; prevents nans when probability 0
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

criterion = FocalLoss()

m = nn.Sigmoid()
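
In the FocalLoss class above, alpha multiplies every sample's loss by the same constant, so it only rescales the loss and does not reweight the classes. The original focal loss formulation (Lin et al., "Focal Loss for Dense Object Detection") uses a class-dependent factor alpha_t = alpha for positives and 1 - alpha for negatives, which is also what torchvision.ops.sigmoid_focal_loss implements. Below is a minimal sketch of that variant, assuming raw logits of shape [N] (no sigmoid applied beforehand) and float 0/1 targets; the function name binary_focal_loss_with_logits is hypothetical.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss with a class-dependent alpha_t (hypothetical helper).

    logits:  raw scores, shape [N]; no sigmoid applied beforehand
    targets: float tensor of 0.0 / 1.0, shape [N]
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    pt = torch.exp(-bce)                                      # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)   # alpha for 1s, 1 - alpha for 0s
    return (alpha_t * (1 - pt) ** gamma * bce).mean()

# Usage with hypothetical values
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
targets = torch.tensor([1.0, 0.0, 0.0, 1.0])
print(binary_focal_loss_with_logits(logits, targets))
```

In the training loop this would be called as binary_focal_loss_with_logits(output[:, 1] - output[:, 0], labels.float()). Note that the default alpha=0.25 gives the positive class the smaller weight; with class 1 as the minority here, alpha is worth tuning rather than keeping the default.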

and then, inside the training loop:

if train:
    print('training...')
    torch.autograd.set_detect_anomaly(True)
    for i_batch, sample_batched in enumerate(dataloader_train):
        #pdb.set_trace()
        feats = torch.stack(sample_batched['image'])
        labels = torch.as_tensor(sample_batched['label']).cuda()
        print('feats shape: ', feats.shape)
        print('labels shape: ', labels.shape)
        output = model(feats)
        loss = criterion(m(output[:,1]-output[:,0]), labels.float())
        #loss = criterion(output, labels)
        print('train loss is: ', loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = (output.argmax(dim=1) == labels).float().mean()
        train_preds = output.argmax(dim=1)
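
F.binary_cross_entropy_with_logits (used inside FocalLoss) applies a sigmoid internally, so passing it m(output[:,1]-output[:,0]) applies the sigmoid twice. The pure-Python check below shows that the probability the loss then effectively works with is confined between sigmoid(0) = 0.5 and sigmoid(1) ≈ 0.731, whatever the raw difference d is; for a label-0 sample the BCE term therefore can never drop below -log(0.5) ≈ 0.693. This is a sketch to illustrate the mismatch, not a full diagnosis of the all-zero predictions; the usual approach is to pass the raw difference to a *_with_logits loss and apply the sigmoid only when you need probabilities.

```python
import math

def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))

# d is a hypothetical score difference output[:, 1] - output[:, 0]
for d in (-20.0, -2.0, 0.0, 2.0, 20.0):
    intended = sigmoid(d)            # probability if d itself is passed to the *_with_logits loss
    effective = sigmoid(sigmoid(d))  # probability the loss uses after nn.Sigmoid + *_with_logits
    print(f"d={d:6.1f}  intended p={intended:.4f}  effective p={effective:.4f}")

print("effective p is confined to", (sigmoid(0.0), sigmoid(1.0)))
```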

This is the output I get:

train_epoch_accuracy:  0.84375
not test
Evaluating...
epoch is:  49
evaluating...
epoch val acc:  tensor(0.8541, device='cuda:0')
val_epoch_accuracy:  0.8426966292134831
best val acc:  tensor(0.8541, device='cuda:0')
best epoch:  0
best preds:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
best val labels:  [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]

As you can see, all predicted values are 0, which is the majority class (class 0 makes up 84% of the data).
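
A model that predicts 0 for every sample scores exactly the share of class-0 labels, so accuracy alone cannot reveal the collapse. Using the validation labels printed above (89 samples, 14 of them positive), a constant-zero predictor reaches 75/89 ≈ 0.8427, which matches the val_epoch_accuracy in the log; the check below makes the arithmetic explicit. Tracking recall on class 1 (or a confusion matrix) exposes the problem directly.

```python
# Validation labels copied from the "best val labels" log above
val_labels = [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
all_zero_preds = [0] * len(val_labels)   # what the model currently predicts

accuracy = sum(p == y for p, y in zip(all_zero_preds, val_labels)) / len(val_labels)
majority_fraction = val_labels.count(0) / len(val_labels)
positives_found = sum(p == 1 and y == 1 for p, y in zip(all_zero_preds, val_labels))

print(f"{len(val_labels)} samples, {val_labels.count(1)} positives")
print(f"accuracy of all-zero predictions: {accuracy:.4f} (majority fraction {majority_fraction:.4f})")
print(f"positives detected: {positives_found} of {val_labels.count(1)}")
```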

I can’t comment on the correctness of your custom focal loss implementation, as I usually use the multi-class implementation from, e.g., kornia.

As described in the great post by @KFrank here (and also mentioned by me in an answer to another of your questions), you would either use nn.BCEWithLogitsLoss for binary classification or, e.g., nn.CrossEntropyLoss if you are treating the use case as a 2-class multi-class classification.
torch.argmax(output, dim=1) would return the predictions for the latter use case, so you are currently mixing both approaches.
If your output has the shape [batch_size, 1] (as it should, given you are using F.binary_cross_entropy_with_logits), you would have to apply a threshold to get the predictions, e.g. via:

preds = output > 0.0
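
The two consistent setups described above can be sketched for a 2-unit head, assuming output has shape [batch_size, 2] and labels holds int64 class indices (random stand-in values below). With two logits z0 and z1, using z1 - z0 as the binary logit makes nn.BCEWithLogitsLoss numerically equal to nn.CrossEntropyLoss, and thresholding that logit at 0 gives the same predictions as argmax(dim=1). The point is to pick one loss and derive the predictions from the same quantity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 8
output = torch.randn(batch_size, 2)            # stand-in for the model output: two logits per sample
labels = torch.randint(0, 2, (batch_size,))    # stand-in int64 class indices (0 or 1)

# Option 1: treat it as 2-class multi-class classification
ce_loss = nn.CrossEntropyLoss()(output, labels)    # raw logits + int64 targets
preds_multiclass = output.argmax(dim=1)

# Option 2: treat it as binary classification with one logit per sample
logit = output[:, 1] - output[:, 0]                # shape [batch_size], no nn.Sigmoid applied
bce_loss = nn.BCEWithLogitsLoss()(logit, labels.float())
preds_binary = (logit > 0.0).long()                # threshold the logit at 0 (probability 0.5)

print(f"CE loss: {ce_loss.item():.6f}  BCE loss: {bce_loss.item():.6f}")
print("same predictions:", torch.equal(preds_multiclass, preds_binary))
```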

Thank you so much for mentioning kornia. I am using it as per your suggestion, and I feel more confident using it, given that it is a well-maintained package.

import kornia

# output: raw logits of shape [batch_size, 2]; labels: int64 class indices of shape [batch_size]
kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
loss = kornia.losses.focal_loss(output, labels, **kwargs)
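
kornia.losses.focal_loss is the multi-class form: it takes raw logits of shape [batch_size, C] (here C = 2) and int64 class indices of shape [batch_size], so predictions consistent with it come from output.argmax(dim=1), as in the original training loop. To show what such a loss computes, here is a plain-PyTorch sketch of the standard multi-class focal loss FL = -alpha * (1 - p_t)^gamma * log(p_t), averaged over the batch. It is not kornia's source code; the function name and the class_weights argument are hypothetical additions. With a single scalar alpha every sample is scaled equally, so per-class weighting would have to come from something like class_weights.

```python
import torch
import torch.nn.functional as F

def multiclass_focal_loss(logits, target, alpha=0.25, gamma=2.0, class_weights=None):
    """Hypothetical helper: FL = -alpha * w_y * (1 - p_y)**gamma * log(p_y), batch mean.

    logits: raw scores, shape [N, C]; target: int64 class indices, shape [N].
    class_weights: optional hypothetical per-class weights, shape [C].
    """
    log_p = F.log_softmax(logits, dim=1)
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log-probability of the true class
    pt = log_pt.exp()
    loss = -alpha * (1.0 - pt) ** gamma * log_pt
    if class_weights is not None:
        loss = loss * class_weights[target]                   # per-sample weight of its true class
    return loss.mean()

torch.manual_seed(0)
output = torch.randn(8, 2)             # stand-in for the ViT output [batch_size, 2]
labels = torch.randint(0, 2, (8,))     # stand-in int64 class indices
loss = multiclass_focal_loss(output, labels)
preds = output.argmax(dim=1)           # predictions consistent with the multi-class loss
print(loss.item(), preds.tolist())
```

With gamma=0 and alpha=1 the sketch reduces to ordinary cross-entropy, which is a quick sanity check for any focal loss implementation.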