Hi,
Here is my implementation. It builds on the documentation for `torch.nn.functional.log_softmax`
(https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.log_softmax). So far, it has worked well on a class-imbalance problem.
Please let me know if it works for you.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from unnormalized logits.

    For each element the loss is ``-weight[y] * (1 - p_y) ** gamma * log(p_y)``,
    where ``p = softmax(logits)`` over the class dimension and ``y`` is the
    target class index. The reduction is delegated to
    :func:`torch.nn.functional.nll_loss`, so ``'mean'`` divides by the summed
    weights of the non-ignored targets, exactly as ``cross_entropy`` does.
    With ``gamma=0`` the result equals :func:`torch.nn.functional.cross_entropy`.

    Args:
        weight: optional per-class rescaling weights, forwarded to ``nll_loss``.
            Kept as a non-persistent buffer so it moves with the module on
            ``.to(device)`` / ``.cuda()`` without adding a ``state_dict`` entry.
        gamma: focusing parameter; larger values down-weight confidently
            classified examples. For ``gamma < 1`` the gradient of
            ``(1 - p) ** gamma`` is unbounded as ``p -> 1``, which can yield
            inf/NaN gradients when ``p`` rounds to exactly 1.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``, forwarded to ``nll_loss``.
        ignore_index: target value that contributes neither loss nor gradient,
            forwarded to ``nll_loss`` (same default, -100).
    """

    def __init__(self, weight=None,
                 gamma=2., reduction='none', ignore_index=-100):
        super().__init__()
        # as_tensor is a no-op for tensors and turns a list of floats into one.
        self.register_buffer(
            'weight',
            None if weight is None else torch.as_tensor(weight),
            persistent=False,
        )
        self.gamma = gamma
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, input_tensor, target_tensor):
        """Return the focal loss of logits ``input_tensor`` w.r.t. ``target_tensor``.

        Args:
            input_tensor: unnormalized logits in a layout ``nll_loss`` accepts:
                ``(C,)``, ``(N, C)`` or ``(N, C, d1, d2, ...)``.
            target_tensor: integer class indices shaped ``()``, ``(N,)`` or
                ``(N, d1, d2, ...)`` respectively.

        Returns:
            The per-element losses (``reduction='none'``) or their reduction.
        """
        # nll_loss expects the class axis at dim 1 (dim 0 for unbatched input);
        # normalizing over dim=-1 would be wrong for e.g. (N, C, H, W) input.
        class_dim = 1 if input_tensor.dim() > 1 else 0
        log_prob = F.log_softmax(input_tensor, dim=class_dim)
        prob = torch.exp(log_prob)
        # Scale every class's log-prob by (1 - p)^gamma; nll_loss then selects
        # the target entry, negates it, applies `weight`, and reduces.
        focal_log_prob = ((1 - prob) ** self.gamma) * log_prob
        return F.nll_loss(
            focal_log_prob,
            target_tensor,
            weight=self.weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )