I’ve been trying to use mixup together with focal loss for multi-class classification. I patched together a multi-class focal loss from code in various open-source projects, but I can't work out how to modify it for labels that are probability distributions (e.g. produced by mixup augmentation) rather than hard one-hot labels. Here is my code for one-hot (hard) labels:
import torch
import torch.nn.functional as F

def focal_loss(p, t, alpha=None, gamma=2.0):
    # p: raw logits, shape (N, C); t: integer class indices, shape (N,)
    # alpha: optional per-class weights, shape (C,), forwarded to cross_entropy
    loss_val = F.cross_entropy(p, t, weight=alpha, reduction='none')  # (N,)
    p = F.softmax(p, dim=1)
    n_classes = p.size(1)
    # This is where I couldn't modify it for probability-type labels
    t = t.unsqueeze(1)
    shape = t.shape
    target_onehot = torch.zeros(shape[0], n_classes, *shape[2:], dtype=p.dtype, device=t.device)
    target_onehot.scatter_(1, t, 1.0)
    # p_t: predicted probability of the true class, shape (N,)
    p_t = (p * target_onehot).sum(dim=1)
    focal_weight = (1 - p_t).pow(gamma)
    # compute loss: (1 - p_t)^gamma * CE per sample, averaged over samples
    focal_loss = focal_weight * loss_val
    return torch.mean(focal_loss)
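For reference, the multi-class focal loss from Lin et al. (2017), which the function above is meant to compute, is, for a sample whose true class is $t$ and with $p = \mathrm{softmax}(\text{logits})$:

$$\mathrm{FL}(p, t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log p_t$$

One common way to extend it to a soft target $y$ with $\sum_c y_c = 1$ (an assumption here, not the only possible choice) is to weight each class's term by $y_c$:

$$\mathrm{FL}(p, y) = -\sum_{c=1}^{C} \alpha_c\, y_c\,(1 - p_c)^{\gamma}\,\log p_c$$

When $y$ is one-hot, every term except $c = t$ vanishes and this reduces to the hard-label form above. Because it is linear in $y$, the loss on a mixup target $\lambda y_a + (1-\lambda) y_b$ equals $\lambda\,\mathrm{FL}(p, y_a) + (1-\lambda)\,\mathrm{FL}(p, y_b)$.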
To use it:
num_classes = 15
pred = torch.rand((16, 50, num_classes), dtype=torch.float)  # logits
label = torch.randint(0, num_classes, (16, 50))  # one class index per position
# Expected label type for mixup (each row is a probability distribution):
# label = torch.rand((16, 50, num_classes)).softmax(dim=-1)
loss = focal_loss(pred.reshape(-1, num_classes), label.view(-1), alpha=None)
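A minimal sketch of a soft-label version, assuming the per-class weighting $-\sum_c y_c (1-p_c)^\gamma \log p_c$ is the behaviour wanted (other generalizations exist). The name `soft_focal_loss` and the mixup coefficient `lam` are hypothetical, chosen for illustration. Recent PyTorch versions (1.10+) let `F.cross_entropy` take class-probability targets directly, but the focal factor has to be applied per class before summing, so this sketch works from `log_softmax` instead.

```python
import torch
import torch.nn.functional as F


def soft_focal_loss(logits, target, gamma=2.0, alpha=None):
    """Hypothetical focal loss for soft (probability) targets, e.g. from mixup.

    logits: (N, C) raw scores; target: (N, C), each row summing to 1.
    alpha: optional per-class weight tensor of shape (C,).
    """
    log_p = F.log_softmax(logits, dim=1)          # log p_c, numerically stable
    p = log_p.exp()                               # p_c
    # Per-class focal term (1 - p_c)^gamma, weighted by the target mass y_c
    loss = -target * (1 - p).pow(gamma) * log_p   # (N, C)
    if alpha is not None:
        loss = loss * alpha.view(1, -1)           # broadcast class weights over rows
    return loss.sum(dim=1).mean()                 # sum over classes, mean over samples


# Usage with mixup-style targets built from two sets of hard labels
num_classes = 15
logits = torch.randn(800, num_classes)
y_a = torch.randint(0, num_classes, (800,))
y_b = torch.randint(0, num_classes, (800,))
lam = 0.7  # hypothetical mixup coefficient
target = (lam * F.one_hot(y_a, num_classes).float()
          + (1 - lam) * F.one_hot(y_b, num_classes).float())
loss = soft_focal_loss(logits, target)
```

With a one-hot `target` this gives the same value as the fixed `focal_loss` above, and with `gamma=0` it reduces to plain cross-entropy. Since the loss is linear in the target, mixing the targets is equivalent to mixing two hard-label focal losses with weights `lam` and `1 - lam`.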
Could anyone with more knowledge of focal loss help me with this? Many thanks in advance.