Runtime Error while using Focal Loss

The problem seems to be the shapes of your predictions and targets.

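From the nn.CrossEntropyLoss docs: when the targets are class indices, the input should have shape (N, C) and the target shape (N), where each target value is a class index in the range [0, C-1].

For example, with 4 classes and a single sample, the predictions should have shape torch.Size([1, 4]) and the targets shape torch.Size([1]).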
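A quick way to check the expected shapes and dtypes before plugging them into a custom loss is shown below. This is a minimal sketch; the logit values are arbitrary, and it assumes one sample with 4 classes as in the example above.

```python
import torch
import torch.nn as nn

# 4 classes, a batch of 1 sample
logits = torch.tensor([[0.0, 0.0, 0.0, 1.0]])  # shape [N, C] = [1, 4], raw float scores
target = torch.tensor([3])                     # shape [N] = [1], class index (int64)

print(logits.shape, target.shape, target.dtype)
# torch.Size([1, 4]) torch.Size([1]) torch.int64

# reduction="none" returns one loss value per sample instead of a scalar
criterion = nn.CrossEntropyLoss(reduction="none")
per_sample = criterion(logits, target)
print(per_sample.shape)
# torch.Size([1])
```

If the target has an extra dimension (e.g. shape [1, 1]) or the logits are missing the batch dimension, this is where the mismatch shows up.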

I made a few modifications to your code and this seems to work for me.

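import torch
import torch.nn as nn

class WeightedFocalLoss(nn.Module):
    def __init__(self, alpha=.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction='none' keeps one CE value per sample. The deprecated `reduce`
        # argument expects a bool, so reduce='none' is treated as True and averages.
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        # inputs: [N, C] raw logits, targets: [N] class indices (int64)
        CE_loss = self.criterion(inputs, targets)
        # exp(-CE) is the predicted probability of the true class (p_t)
        pt = torch.exp(-CE_loss)
        # down-weight well-classified samples by (1 - p_t)^gamma
        F_loss = self.alpha * (1 - pt) ** self.gamma * CE_loss
        # average over the batch to get a scalar loss
        return F_loss.mean()

criterion = WeightedFocalLoss()

outputs = torch.tensor([[0., 0., 0., 1.]])
targets = torch.tensor([3])

loss = criterion(outputs, targets)

print(loss)
#tensor(0.0512)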
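To see why torch.exp(-CE_loss) can be used as p_t, and how the loss behaves on more than one sample, here is a small sketch using torch.nn.functional directly. The 3-sample batch is made up for illustration; alpha and gamma match the defaults above.

```python
import torch
import torch.nn.functional as F

# Made-up batch: 3 samples, 4 classes
logits = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                       [2.0, 0.5, 0.1, 0.0],
                       [0.0, 3.0, 0.0, 0.0]])
targets = torch.tensor([3, 0, 2])

# Per-sample cross entropy: CE_i = -log softmax(logits_i)[target_i]
ce = F.cross_entropy(logits, targets, reduction="none")

# p_t taken directly from the softmax vs. recovered from the CE value
pt_from_softmax = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
pt_from_ce = torch.exp(-ce)
print(torch.allclose(pt_from_softmax, pt_from_ce))  # True

# Focal loss per sample, then averaged into a scalar
alpha, gamma = 0.25, 2.0
focal = alpha * (1 - pt_from_ce) ** gamma * ce  # shape [3]
print(focal.mean())
```

The first row reproduces the single-sample example, so focal[0] comes out to about 0.0512. With reduction="none" the loss is per sample, so it has to be reduced (here with .mean()) before calling backward().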