Custom loss function for specific label

Hello,

I would like to create a custom loss function that takes into consideration only one of the labels.

I have three labels (0, 1, 2) and I would like to consider only 0.

This is my code:

import torch
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")

        outputs = model(**inputs)
        logits = outputs.get("logits")

        label_tensor = labels.view(-1)                                 # shape (N,)
        output_tensor = logits.view(-1, self.model.config.num_labels)  # shape (N, num_labels)

        # Keep only the positions whose gold label is 0.
        mask = label_tensor.eq(0)

        # This line raises a RuntimeError: the (N,) mask is not broadcastable
        # against the (N, num_labels) logits.
        loss = torch.mean(torch.abs(torch.masked_select(output_tensor, mask) - torch.masked_select(label_tensor, mask)))
        return (loss, outputs) if return_outputs else loss

This is not working: it fails with a broadcasting error in the masked_select call.

Any ideas?

The following works (I do not know whether it makes sense, but at least there are no errors):

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")

        outputs = model(**inputs)
        logits = outputs.get("logits")

        label_tensor = labels.view(-1)
        # Predicted class index per example.
        output_tensor = torch.argmax(logits, dim=-1)

        # Keep only the positions whose gold label is 0.
        mask = label_tensor.eq(0)

        # Mean absolute difference between predicted and gold class indices.
        loss = (torch.masked_select(output_tensor, mask) - torch.masked_select(label_tensor, mask)).float().abs().mean()
        loss.requires_grad = True

        return (loss, outputs) if return_outputs else loss

Thanks a million!

The code looks wrong: torch.argmax is not differentiable, so calling it on the logits detaches the loss from the computation graph. That would normally raise a valid error saying that backward() cannot be called on the loss tensor, and setting loss.requires_grad = True does not re-attach the graph; it only masks the error.
The model parameters will therefore not receive any valid gradients.
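
A quick standalone demonstration of what happens:

import torch

logits = torch.randn(4, 3, requires_grad=True)
preds = torch.argmax(logits, dim=-1)  # integer tensor; the graph is cut here
print(preds.grad_fn)                  # None: nothing to backpropagate through
loss = preds.float().mean()
loss.requires_grad = True             # silences the error...
loss.backward()
print(logits.grad)                    # ...but gradients never reach the logits: None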

Thank you. So perhaps the entire approach does not make sense. Or is there a workaround?

What I am doing right now is using weighted Cross Entropy Loss:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Balanced weights: inversely proportional to each class's frequency.
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(df['labels']), y=df['labels'])
# Normalise by the smallest weight, round up, then bump the rarest class by one more.
c_w = np.fromiter((np.ceil(i/min(class_weights)) for i in class_weights), dtype=np.float32)
c_w[np.argmax(c_w)] = np.max(c_w) + 1
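
For illustration with made-up counts of [900, 80, 20] for labels 0, 1, 2: compute_class_weight gives roughly [0.37, 4.17, 16.67], the ceil-normalisation turns that into [1, 12, 45], and the final bump makes it [1, 12, 46].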

import torch
import torch.nn as nn
from transformers import Trainer

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss: cross entropy with the three per-class weights from above
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor(c_w, device=model.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

But this still fits primarily the majority class.

However, I am interested in one of the minority classes.

If I increase the weight of the target class above a certain threshold, the model just breaks.

I am trying upsampling…
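
Something along these lines (just a sketch, assuming the data sits in the same pandas DataFrame df with the labels column used above):

import pandas as pd
from sklearn.utils import resample

majority_size = df['labels'].value_counts().max()

# Resample every class with replacement up to the size of the majority class.
upsampled_parts = [
    resample(group, replace=True, n_samples=majority_size, random_state=42)
    for _, group in df.groupby('labels')
]
df_upsampled = pd.concat(upsampled_parts).sample(frac=1, random_state=42)  # shuffle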

Upsampling seems to do the trick in this case; training is much more stable. Now we will see how well the model generalises.

My guess (I do not know if this makes any sense) is that the evaluation batch size was too small to sample the minority classes properly. But I have very limited VRAM.
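
To judge generalisation I will look at per-class metrics rather than overall accuracy; a minimal sketch (golds and preds are placeholder names for the true labels and the model's predicted labels):

from sklearn.metrics import classification_report

# golds and preds: flat arrays of class ids (0, 1, 2); placeholder names.
print(classification_report(golds, preds, labels=[0, 1, 2], digits=3))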