Focal loss in pytorch

I have a binary NLP classification problem and my data is heavily imbalanced: class 1 represents only 2% of the data. For training I oversample class 1, so my training class distribution is 55%-45%. I have built a CNN; my last few layers and the loss function are below.

self.batch_norm2 = nn.BatchNorm1d(num_filters)
self.fc2 = nn.Linear(np.sum(num_filters), fc2_neurons)
self.batch_norm3 = nn.BatchNorm1d(fc2_neurons)
self.fc3 = nn.Linear(fc2_neurons, 1)  # changed on 6 March - BCE with logits loss



In my evaluation function I call the loss as follows:

loss=BCE_With_LogitsLoss(torch.squeeze(probs), labels.float())
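As a side note, `nn.BCEWithLogitsLoss` works on raw logits (which is why `fc3` above has no sigmoid) and folds the sigmoid into the loss in a numerically stable way. A stdlib-only sketch of the elementwise formula it computes (these are my own toy functions, not the PyTorch source):

```python
import math

def bce_with_logits(x: float, z: float) -> float:
    """Stable binary cross-entropy on a raw logit x with target z in {0, 1},
    the same elementwise formula nn.BCEWithLogitsLoss applies:
        loss = max(x, 0) - x*z + log(1 + exp(-|x|))
    """
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))

def bce_naive(x: float, z: float) -> float:
    """Sigmoid followed by plain BCE -- breaks down for large |x|."""
    p = 1.0 / (1.0 + math.exp(-x))
    return -(z * math.log(p) + (1 - z) * math.log(1 - p))

# For moderate logits the stable formula gives the expected value...
print(round(bce_with_logits(2.0, 1.0), 6))  # 0.126928

# ...and it stays finite for extreme logits, where the naive
# sigmoid-then-log version saturates to p == 1.0 and hits log(0):
print(bce_with_logits(100.0, 0.0))  # 100.0
try:
    bce_naive(100.0, 0.0)
except ValueError:
    print("naive sigmoid+BCE fails on extreme logits")
```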

I was suggested to use focal loss over here.

Please consider using Focal loss:
Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár Focal Loss for Dense Object Detection (ICCV 2017).

Is there any PyTorch implementation of the same? I found a few but am not sure which are correct.
one example -
another example -

My questions:

  1. Is there any specific implementation that I should use?
  2. Should I use focal loss even though I am oversampling?
  3. How do I modify my code to use the correct implementation?
  4. Do we need to use pos_weight along with it?
  1. I would probably use Kornia’s implementation from here (CC @edgarriba)
  2. I’m unsure what your current oversampling strategy is as each sample contains multiple targets (each pixel is associated with a class)
  3. I’m unsure how to understand the question, but you should be able to use the linked loss implementation directly in your code after installing kornia
  4. I don’t think focal loss uses pos_weights
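For reference, the loss from the Lin et al. paper is short enough to sketch directly. This is a minimal pure-Python illustration of the formula FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) for a single logit, not kornia’s implementation:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def binary_focal_loss(logit: float, target: float,
                      alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Binary focal loss (Lin et al., ICCV 2017) on one raw logit:
        FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
    where p_t is the predicted probability of the true class and
    alpha_t plays the role of a class weight (so a separate
    pos_weight is usually unnecessary).
    """
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p
    alpha_t = alpha if target == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# gamma down-weights easy examples: an easy positive (p_t ~ 0.9)
# contributes far less than a hard one (p_t ~ 0.1).
easy = binary_focal_loss(math.log(9), 1.0)   # p ~ 0.9
hard = binary_focal_loss(-math.log(9), 1.0)  # p ~ 0.1
print(hard / easy > 100)  # True: the hard example dominates
```

With alpha=1 and gamma=0 this reduces to plain binary cross-entropy, which is a handy sanity check when comparing against `nn.BCEWithLogitsLoss`.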

I appreciate the answer.

  1. Thanks for the link.
  2. Trying to clarify: this is an NLP problem and I don’t have any images as input. Let’s say my master data has 100,000 examples of class 0 and 20,000 of class 1; my training data then has 10,000 of class 0 and 10,000 of class 1, where I oversampled from class 1. Each example is some text and the associated y can be either 0 or 1. My test data still has a 98%-2% distribution. In such a case of oversampling, should I still use focal loss?
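For comparison, the imbalance above could also be handled without duplicating samples by passing `pos_weight` to `nn.BCEWithLogitsLoss`, which scales the loss of each positive example. A stdlib sketch of the arithmetic using the class counts quoted above (variable names are mine):

```python
# Class counts from the master data described above.
n_class0 = 100_000  # negatives
n_class1 = 20_000   # positives

# nn.BCEWithLogitsLoss(pos_weight=...) multiplies the loss term of each
# positive example by this factor, so positives and negatives contribute
# roughly equally to the total loss without any resampling.
pos_weight = n_class0 / n_class1
print(pos_weight)  # 5.0

# The weighted positives now balance the negatives:
print(n_class1 * pos_weight == n_class0)  # True

# In PyTorch this would look like (not executed here):
#   criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]))
```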

I’m not familiar with using focal loss for NLP use cases and don’t know if it would work (it’s definitely worth an experiment, so let us know how it went).


I will definitely update with my results.

Quick question: why do the two lines related to BCE_With_LogitsLoss work, while the similar two lines with focal loss fail?

loss=BCE_With_LogitsLoss(torch.squeeze(probs), labels.float())

Focal loss: the line below gives an error. Is it because input and target are mandatory arguments for focal loss? Why are they not mandatory for BCE_With_LogitsLoss?

loss=focal_loss(torch.squeeze(probs), labels.float())

I don’t know where binary_focal_loss_with_logits is coming from, as it doesn’t seem to be defined in kornia.
I also think that both losses expect to get a target tensor as well as an input, and I don’t know which error you are seeing.

BCE_With_LogitsLoss comes directly from PyTorch, as shown in the line below.

I will post the error with focal loss soon.

So far I am not seeing a huge improvement compared to nn.BCEWithLogitsLoss.