Multi-class BERT Model: Class weights and where to use them

Hi all,

I am aware that similar topics have been asked before (such as weights-of-cross-entropy-loss-for-validation-dev-set and apply-weightedrandomsampler-to-validation-test-splits-makes-sense). However, I am still confused, because those and earlier posts use different code snippets, which makes it hard for me to understand the complete scenario.

My case is as follows: I am working on a multi-class classification model based on BERT with customized layers. The classes are not balanced, and I am trying to check whether handling the imbalance improves the result metrics.
In different articles and previous posts I found three main suggestions (besides using the data as-is): undersampling, oversampling, and class weights.
I am currently testing the last one, but I am not sure whether I am doing it the right way, or whether I am missing some of the reasoning behind those examples.
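For comparison with class weights, oversampling is often done in PyTorch with `torch.utils.data.WeightedRandomSampler`. The following is a minimal sketch, assuming hypothetical toy features and a two-class label tensor in place of the BERT inputs; each sample gets the inverse frequency of its class as its sampling weight, so both classes are drawn about equally often.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, WeightedRandomSampler

torch.manual_seed(0)

# hypothetical imbalanced training data: 8 samples of class 0, 2 of class 1
train_labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
features = torch.randn(len(train_labels), 4)
dataset = TensorDataset(features, train_labels)

# per-class weight = 1 / class count, then one weight per sample
class_counts = torch.bincount(train_labels)
class_weights = 1.0 / class_counts.float()
sample_weights = class_weights[train_labels]

# drawing with replacement repeats minority-class samples (oversampling)
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
loader = DataLoader(dataset, batch_size=5, sampler=sampler)

def count_classes(loader, n_epochs=200):
    # count how often each class is drawn over several epochs
    counts = torch.zeros(2, dtype=torch.long)
    for _ in range(n_epochs):
        for _, y in loader:
            counts += torch.bincount(y, minlength=2)
    return counts

counts = count_classes(loader)
print(counts)  # roughly equal counts for both classes
```

In this sketch the sampler is attached only to the training `DataLoader`; the unweighted loss can then be used, since the batches themselves are rebalanced.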

The steps that I have seen in previous examples are:

  1. Split the data into train, validation and test sets (standard step, no issue here)

  2. Generate the class weights for the later cross-entropy calculation. I have seen different approaches for this step, but the one I have found most often is similar to this:

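```python
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

# compute the class weights from the training labels only
# (recent scikit-learn versions require `classes` and `y` as keyword arguments)
class_wts = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)

# convert class weights to a tensor on the same device as the model
weights = torch.tensor(class_wts, dtype=torch.float)
weights = weights.to(device)
```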
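What `'balanced'` computes can be reproduced by hand, which makes it clearer why rare classes get larger weights. The scikit-learn documentation gives the formula `n_samples / (n_classes * np.bincount(y))`; the sketch below reimplements it with NumPy only, using a hypothetical three-class label array.

```python
import numpy as np

def balanced_class_weights(labels):
    # same formula as scikit-learn's 'balanced' mode:
    # n_samples / (n_classes * count_of_each_class)
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    return classes, weights

# hypothetical training labels: 6 x class 0, 2 x class 1, 4 x class 2
train_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2])
classes, weights = balanced_class_weights(train_labels)
print(classes, weights)  # weights: 12/18, 12/6, 12/12 -> about 0.667, 2.0, 1.0
```

A class that appears exactly `n_samples / n_classes` times gets weight 1; rarer classes get weights above 1 and frequent classes weights below 1.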

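2.5. Sometimes, the weights are passed straight to the loss function at this point:

```python
import torch.nn as nn

# loss function (NLLLoss expects log-probabilities, so the model
# must end with a LogSoftmax layer)
cross_entropy = nn.NLLLoss(weight=weights)
```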
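Both losses that appear in this thread (`nn.NLLLoss` here and `nn.CrossEntropyLoss` in the commented-out alternative below) compute the same weighted loss, as long as each gets the input it expects. A small sketch with hypothetical random logits; it also shows that with `reduction='mean'` the weighted sum is divided by the sum of the weights of the target classes, not by the batch size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)               # raw scores: 4 samples, 3 classes
labels = torch.tensor([0, 1, 2, 1])
weights = torch.tensor([0.5, 2.0, 1.0])  # hypothetical class weights

# CrossEntropyLoss takes raw logits ...
ce = nn.CrossEntropyLoss(weight=weights)(logits, labels)
# ... NLLLoss takes log-probabilities (e.g. the output of LogSoftmax)
nll = nn.NLLLoss(weight=weights)(F.log_softmax(logits, dim=1), labels)

# manual weighted mean: sum(w_i * loss_i) / sum(w_i)
per_sample = F.cross_entropy(logits, labels, reduction='none')
w = weights[labels]
manual = (w * per_sample).sum() / w.sum()

print(ce.item(), nll.item(), manual.item())  # all three agree
```

Passing logits to `NLLLoss` (or log-probabilities to `CrossEntropyLoss`) does not raise an error but silently computes a different quantity, so it is worth checking which one matches the model's last layer.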
  3. After defining the model and the hyperparameters, proceed with the training. If the weights were not passed to a loss function before, the weighted loss is created at this point instead (see the commented-out alternative in the loop):
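```python
def train():
    model.train()
    total_loss = 0
    total_preds = []
    for step, batch in enumerate(train_dataloader):
        # move every tensor in the batch to the model's device
        batch = [r.to(device) for r in batch]
        sent_id, mask, labels = batch
        model.zero_grad()
        preds = model(sent_id, mask)
        loss = cross_entropy(preds, labels)
        # or # criterion = torch.nn.CrossEntropyLoss(weight=weights,reduction='mean')
        # and then # loss = criterion(outputs[1], b_labels)
        # (CrossEntropyLoss expects raw logits, NLLLoss expects log-probabilities)
        total_loss = total_loss + loss.item()
        loss.backward()
        # clip gradients after backward() and before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()  # `optimizer` is defined together with the model
        total_preds.append(preds.detach().cpu().numpy())
    avg_loss = total_loss / len(train_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
```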
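Because the BERT loop above depends on objects defined elsewhere, here is a self-contained sketch of the same training order (zero gradients, forward, weighted loss, backward, clip, step). The model is a hypothetical stand-in (a linear layer plus `LogSoftmax`, so its output suits `NLLLoss`) trained on hypothetical toy data, not the author's BERT architecture.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# hypothetical stand-in for the BERT model: outputs log-probabilities
model = nn.Sequential(nn.Linear(8, 3), nn.LogSoftmax(dim=1))

# hypothetical imbalanced data: 60 x class 0, 10 x class 1, 10 x class 2
labels = torch.tensor([0] * 60 + [1] * 10 + [2] * 10)
features = torch.randn(len(labels), 8) + labels.unsqueeze(1).float()
loader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

# 'balanced' weights computed from the training labels
counts = torch.bincount(labels).float()
weights = len(labels) / (len(counts) * counts)
cross_entropy = nn.NLLLoss(weight=weights)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

def train_one_epoch():
    model.train()
    total_loss = 0.0
    for x, y in loader:
        optimizer.zero_grad()
        loss = cross_entropy(model(x), y)
        loss.backward()
        # clipping only has an effect between backward() and step()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

losses = [train_one_epoch() for _ in range(30)]
print(losses[0], losses[-1])  # the average weighted loss should decrease
```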

Up to this point I had few issues following the logic, but the next step is the one that confuses me.

  4. After training, validation is run, and it also uses the same cross_entropy function:
```python
def evaluate():
    model.eval()
    total_loss = 0
    total_preds = []
    for step, batch in enumerate(val_dataloader):
        if step % 50 == 0 and not step == 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(val_dataloader)))
        batch = [t.to(device) for t in batch]
        sent_id, mask, labels = batch
        with torch.no_grad():
            preds = model(sent_id, mask)
            loss = cross_entropy(preds, labels)  #<------- Here
            total_loss = total_loss + loss.item()
            total_preds.append(preds.cpu().numpy())
    avg_loss = total_loss / len(val_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
```
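To see what the weights do to the validation loss, both versions can be computed on the same predictions. A minimal sketch with a hypothetical validation batch where the model always favors the majority class 0; the weighted loss comes out higher because the errors on the rare classes count more.

```python
import torch
import torch.nn as nn

def validation_losses(log_probs, labels, weights):
    # weighted: same criterion as in training, so it is comparable to the training loss
    weighted = nn.NLLLoss(weight=weights)(log_probs, labels)
    # unweighted: every validation sample counts equally
    unweighted = nn.NLLLoss()(log_probs, labels)
    return weighted.item(), unweighted.item()

# hypothetical batch: the model predicts [0.8, 0.1, 0.1] for every sample
log_probs = torch.log(torch.tensor([[0.8, 0.1, 0.1]] * 5))
labels = torch.tensor([0, 0, 0, 1, 2])
weights = torch.tensor([0.5, 2.0, 2.0])  # hypothetical training-set weights

weighted, unweighted = validation_losses(log_probs, labels, weights)
print(weighted, unweighted)  # about 1.735 and 1.055
```

Here the unweighted loss is (3 * -ln 0.8 + 2 * -ln 0.1) / 5 and the weighted one is (1.5 * -ln 0.8 + 4 * -ln 0.1) / 5.5.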

However, cross_entropy is the same object that was created earlier, which means that it uses the weights calculated from the training dataset.
In that regard, my questions are:

  1. Is it correct to use weights during the validation method, or should the validation method calculate the loss as usual?
  2. If we are supposed to use weights, should they be recalculated from the validation dataset? Or is it correct to use the ones from the training dataset?
  3. I am assuming that no weights should be used for the test dataset, since the goal is to test the model on real-world data. Is this correct, or am I confusing the logic in this final part?

Thanks in advance for any help and suggestion!

  1. Yes, the weighted loss can be reused to get a signal about overfitting as @KFrank explains in this post.
  2. Both datasets should come from the same domain and have the same or similar distributions as also explained in the linked post.
  3. For the final run on the test data, you should calculate the metric that most closely represents your target. E.g., the accuracy on an imbalanced test set can be high while your model is useless and predicts only the majority class (the accuracy paradox).
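The accuracy paradox mentioned above is easy to reproduce with scikit-learn's metrics. The sketch uses a hypothetical test set with nine samples of class 0 and one of class 1, and a "model" that always predicts class 0; balanced accuracy and macro F1 are two common alternatives, chosen here for illustration rather than as the recommended metric for every use case.

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score

# hypothetical imbalanced test set: 9 x class 0, 1 x class 1
y_true = [0] * 9 + [1]
# a useless model that always predicts the majority class
y_pred = [0] * 10

acc = accuracy_score(y_true, y_pred)               # 9 / 10 = 0.9
bal_acc = balanced_accuracy_score(y_true, y_pred)  # mean recall: (1.0 + 0.0) / 2 = 0.5
# zero_division=0 scores the never-predicted class as 0 instead of warning
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)  # about 0.47

print(acc, bal_acc, macro_f1)
```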

I think I now have a clearer idea of the differences in the steps between oversampling and the use of class weights.

Thanks for the reply and links!