Weighted Multi-label Focal Loss Implementation

I want to confirm the implementation below of a multi-label focal loss function that also accepts a class weight parameter to handle class imbalance (@ptrblck I would appreciate your feedback if possible :slight_smile: ):

import torch

class MultiLabelFocalLoss(torch.nn.Module):
    """Implementation of a multi-label focal loss function.

    Args:
        weight: class weight vector (acts as the alpha term) to be used in case of class imbalance
        gamma: hyper-parameter for the focal loss scaling
    """
    def __init__(self, weight=None, gamma=2):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.loss = torch.nn.BCELoss(reduction='none')  # expects probabilities, not logits

    def forward(self, outputs, targets):
        ce_loss = self.loss(outputs, targets)
        pt = torch.exp(-ce_loss)  # pt must come from the unweighted loss
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.weight is not None:
            # apply the alpha / class weights after the focal term,
            # so the weights do not distort pt
            focal_loss = focal_loss * self.weight
        return focal_loss.mean()  # mean over batch and classes
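As a quick sanity check of the formula (not the module itself), here is a functional sketch with made-up shapes: with gamma=0 and no weight the focal expression should reduce to plain BCE, and with gamma>0 the focal loss can only down-weight terms, so it should never exceed plain BCE.

```python
import torch

def focal_bce(outputs, targets, weight=None, gamma=2.0):
    # element-wise BCE, then focal scaling; weight applied after the focal term
    ce = torch.nn.functional.binary_cross_entropy(outputs, targets, reduction='none')
    pt = torch.exp(-ce)
    focal = (1 - pt) ** gamma * ce
    if weight is not None:
        focal = focal * weight  # per-class weights, broadcast over the batch
    return focal.mean()

# hypothetical setup: 4 samples, 3 labels; outputs must be probabilities for BCE
torch.manual_seed(0)
outputs = torch.sigmoid(torch.randn(4, 3))
targets = torch.randint(0, 2, (4, 3)).float()

plain = torch.nn.functional.binary_cross_entropy(outputs, targets)
assert torch.allclose(focal_bce(outputs, targets, gamma=0.0), plain)
assert focal_bce(outputs, targets, gamma=2.0) <= plain
```

The second assertion holds because pt lies in (0, 1], so (1 - pt)**gamma is at most 1 element-wise.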