Runtime Error while using Focal Loss

Hi everyone,
I am doing multi-class classification.
I am trying to use focal loss instead of cross-entropy, but it is throwing an error. Here is my loss implementation:

class WeightedFocalLoss(nn.Module):
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        CE_loss = nn.CrossEntropyLoss(inputs, targets, reduce='none')
        pt = torch.exp(-CE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * CE_loss

        return F_loss

In the training script, I use it in the following way:

criterion = focal_loss.WeightedFocalLoss()
#In training loop -
output = model(input_Data)
loss = criterion(output, target)

I am getting the error below:

RuntimeError: bool value of Tensor with more than one value is ambiguous

The issue is in how you call nn.CrossEntropyLoss: you pass your predictions and targets to its constructor instead of first instantiating the criterion and then calling it on them. The constructor treats those tensors as configuration arguments, and an internal truth-value check on the multi-element targets tensor raises the "bool value of Tensor" error.

From the docs:

Let’s assume we have 4 classes. The predictions should be of shape torch.Size([1, 4]) and targets of shape torch.Size([1])
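As a minimal sketch of those shapes (assuming 4 classes and a batch of one), the same per-sample cross-entropy can also be computed with the functional API:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4)    # predictions: shape [1, 4] (batch, classes)
targets = torch.tensor([3])   # targets: shape [1], class indices

ce = F.cross_entropy(logits, targets, reduction='none')  # per-sample CE loss
print(ce.shape)  # torch.Size([1])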

I made a few modifications to your code and this seems to work for me.

import torch
import torch.nn as nn

class WeightedFocalLoss(nn.Module):
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # keep per-sample losses so the focal weighting is applied element-wise
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        CE_loss = self.criterion(inputs, targets)
        pt = torch.exp(-CE_loss)  # probability of the true class
        F_loss = self.alpha * (1 - pt)**self.gamma * CE_loss

        return F_loss.mean()  # reduce to a scalar so it works directly with backward()

criterion = WeightedFocalLoss()

outputs = torch.Tensor([[0, 0, 0, 1]])
targets = torch.Tensor([3]).long()

loss = criterion(outputs, targets)

print(loss)
#tensor(0.0512)
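
The same criterion also works on a batch. A minimal sketch, assuming a hypothetical batch of 8 samples with the same 4 classes, including a backward pass:

# hypothetical batched example: 8 samples, 4 classes
logits = torch.randn(8, 4, requires_grad=True)  # model outputs, shape [8, 4]
targets = torch.randint(0, 4, (8,))             # class indices, shape [8]

loss = criterion(logits, targets)  # scalar focal loss
loss.backward()                    # gradients flow back to the logits
print(loss)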