Hi everyone,
I am working on a multi-class classification problem.
I am trying to use focal loss instead of cross-entropy, but it throws an error. Here is my implementation:
import torch
import torch.nn as nn

class WeightedFocalLoss(nn.Module):
    def __init__(self, alpha=.25, gamma=2):
        super(WeightedFocalLoss, self).__init__()
        self.alpha = alpha  # constant weighting factor
        self.gamma = gamma  # focusing parameter

    def forward(self, inputs, targets):
        CE_loss = nn.CrossEntropyLoss(inputs, targets, reduce='none')
        pt = torch.exp(-CE_loss)
        # Down-weight well-classified samples (pt close to 1)
        F_loss = self.alpha * (1-pt)**self.gamma * CE_loss
        return F_loss
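For reference, the class relies on the fact that the cross-entropy of a single sample is the negative log-probability of its true class, so the target-class probability $p_t$ can be recovered from the per-sample loss. This only works if the cross-entropy is kept per sample (no mean or sum reduction); a reduced loss would yield a single $p_t$ for the whole batch:

```latex
\begin{aligned}
\mathrm{CE} &= -\log p_t \quad\Longrightarrow\quad p_t = e^{-\mathrm{CE}} \\
\mathrm{FL} &= \alpha\,(1 - p_t)^{\gamma}\,\mathrm{CE} = -\alpha\,(1 - p_t)^{\gamma}\log p_t
\end{aligned}
```

With $\alpha = 1$ and $\gamma = 0$, FL reduces to plain cross-entropy.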
In the training script, I use it as follows:
criterion = focal_loss.WeightedFocalLoss()

# In the training loop:
output = model(input_Data)
loss = criterion(output, target)
I get the following error:
RuntimeError: bool value of Tensor with more than one value is ambiguous
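A likely cause: `nn.CrossEntropyLoss` is a module class, not a function, so `nn.CrossEntropyLoss(inputs, targets, reduce='none')` calls its constructor. The constructor's first positional parameters are `weight` and the deprecated `size_average`, so `targets` ends up as `size_average`, and the legacy `size_average`/`reduce` handling evaluates it in a boolean context. That raises this error for any tensor with more than one element. (The keyword for per-sample losses is also `reduction`, not `reduce`.) The sketch below swaps in the functional `torch.nn.functional.cross_entropy` with `reduction="none"`. It assumes raw logits of shape `(N, C)` and integer class targets of shape `(N,)`. The final mean reduction is an added choice that lets `loss.backward()` be called directly; it is not something the original code specified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedFocalLoss(nn.Module):
    """Focal loss for multi-class classification (sketch).

    Expects raw logits of shape (N, C) and integer class targets of shape (N,).
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # constant weighting factor applied to every class
        self.gamma = gamma  # focusing parameter

    def forward(self, inputs, targets):
        # Per-sample cross-entropy. F.cross_entropy is a function, so the
        # tensors are passed as data rather than as constructor arguments.
        ce_loss = F.cross_entropy(inputs, targets, reduction="none")
        # ce_loss = -log(p_t), so p_t is recovered with exp(-ce_loss).
        pt = torch.exp(-ce_loss)
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # Reduce to a scalar so loss.backward() works without a gradient argument.
        return focal.mean()


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 5, requires_grad=True)  # 8 samples, 5 classes
    labels = torch.randint(0, 5, (8,))
    criterion = WeightedFocalLoss()
    loss = criterion(logits, labels)
    loss.backward()
    print(loss.item())
```

With `alpha=1` and `gamma=0` the expression reduces to ordinary mean cross-entropy, which is a quick sanity check. Note that a scalar `alpha` only rescales the loss uniformly in the multi-class case; per-class weighting would need a tensor of weights indexed by `targets`.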