Hi all,
I am aware that similar or related topics have been asked before (such as in weights-of-cross-entropy-loss-for-validation-dev-set and apply-weightedrandomsampler-to-validation-test-splits-makes-sense). However, I am still somewhat confused, because those and other earlier posts use different code snippets, which makes it hard for me to understand the complete scenario.
My case is as follows: I am working on a multi-class classification model based on BERT with custom layers. The classes are not balanced, and I am trying to check whether addressing this can improve the resulting metrics.
I found three main suggestions (besides working with the data as-is) in different articles and previous posts: undersampling, oversampling and class weights.
I am currently testing the last one, but I am not sure whether I am doing it the right way, or whether I am missing some of the reasoning behind those examples.
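For comparison with the class-weight approach, oversampling is commonly implemented by drawing training examples with probability inversely proportional to their class frequency (this is the idea behind passing per-sample weights to `torch.utils.data.WeightedRandomSampler`). The sketch below is a NumPy-only illustration of that idea on made-up toy labels, not the sampler's actual implementation; the function names are hypothetical.

```python
import numpy as np

def inverse_frequency_sample_weights(labels):
    """Hypothetical helper: weight of each example = 1 / (size of its class)."""
    labels = np.asarray(labels)
    counts = np.bincount(labels)
    return 1.0 / counts[labels]

def oversample_indices(labels, num_samples, seed=0):
    """Hypothetical helper: draw indices with replacement so every class
    is picked with (roughly) equal probability."""
    w = inverse_frequency_sample_weights(labels)
    p = w / w.sum()  # normalise the weights into a probability distribution
    rng = np.random.default_rng(seed)
    return rng.choice(len(labels), size=num_samples, replace=True, p=p)

# Toy, deliberately imbalanced labels: 90 examples of class 0, 10 of class 1
labels = np.array([0] * 90 + [1] * 10)
idx = oversample_indices(labels, num_samples=10_000)
print(np.bincount(labels[idx]) / len(idx))  # class fractions, close to 0.5 each
```

Because the per-example weights of each class sum to 1, every class is drawn with equal probability; minority examples are simply repeated more often.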
The steps that I have seen in previous examples are:
1. Split the data into training, validation and test sets (standard step, no issue here).
2. Generate the class weights for the later cross-entropy calculation. I have seen different approaches for this step, but the most common one is similar to this:
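import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
# compute the class weights from the training labels only
# (recent scikit-learn versions require keyword arguments here)
class_wts = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
print(class_wts)
# convert the class weights to a float tensor on the model's device
weights = torch.tensor(class_wts, dtype=torch.float)
weights = weights.to(device)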
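2.5. Sometimes, the weights are passed straight to the loss function right after they are computed:
import torch.nn as nn
# loss function; NLLLoss expects log-probabilities (e.g. a final LogSoftmax layer in the model)
cross_entropy = nn.NLLLoss(weight=weights)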
3. After defining the model and the hyperparameters, training starts. If the weights were not passed to a loss function earlier, the loss is created at this point instead:
def train():
    model.train()
    total_loss, total_accuracy = 0, 0
    total_preds = []
    for step, batch in enumerate(train_dataloader):
        # move the batch to the same device as the model
        batch = [r.to(device) for r in batch]
        sent_id, mask, labels = batch
        model.zero_grad()
        preds = model(sent_id, mask)
        # weighted loss (weights computed from the training labels)
        loss = cross_entropy(preds, labels)
        # or, in other examples, when the model returns raw logits:
        # criterion = torch.nn.CrossEntropyLoss(weight=weights, reduction='mean')
        # loss = criterion(outputs[1], b_labels)
        total_loss = total_loss + loss.item()
        loss.backward()
        # clip gradients to a norm of 1.0 to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        preds = preds.detach().cpu().numpy()
        total_preds.append(preds)
    avg_loss = total_loss / len(train_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
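For reference, the scikit-learn documentation describes the 'balanced' option of `compute_class_weight` as `n_samples / (n_classes * np.bincount(y))`. Below is a small NumPy sketch of that formula on hypothetical toy labels, which can be used to sanity-check the printed `class_wts`; the function name is made up for illustration.

```python
import numpy as np

def balanced_class_weights(train_labels):
    """Hypothetical helper reproducing the documented 'balanced' formula:
    n_samples / (n_classes * count_per_class), ordered by sorted class label."""
    train_labels = np.asarray(train_labels)
    classes, counts = np.unique(train_labels, return_counts=True)
    weights = len(train_labels) / (len(classes) * counts)
    return classes, weights

# Toy imbalanced training labels: 6 x class 0, 3 x class 1, 1 x class 2
train_labels = [0] * 6 + [1] * 3 + [2]
classes, weights = balanced_class_weights(train_labels)
print(classes, weights)  # weights ≈ [0.556, 1.111, 3.333]
```

With these weights, `count * weight` is the same for every class (here 10/3), so each class contributes the same total weight to the training loss. The loss indexes the weight vector by class id, so this assumes the labels are the integers 0..C-1 and that every class appears in the training split.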
Up to this point I had little trouble following the logic, but the next step is the one that confuses me.
4. After training, validation is run, and it also uses the cross_entropy function:
import time

def evaluate():
    model.eval()
    total_loss, total_accuracy = 0, 0
    total_preds = []
    t0 = time.time()  # start time, used for progress reporting
    for step, batch in enumerate(val_dataloader):
        if step % 50 == 0 and not step == 0:
            # format_time is a helper (not shown) that formats seconds as hh:mm:ss
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(val_dataloader), elapsed))
        batch = [t.to(device) for t in batch]
        sent_id, mask, labels = batch
        # no gradients are needed during evaluation
        with torch.no_grad():
            preds = model(sent_id, mask)
            loss = cross_entropy(preds, labels)  # <------- Here
            total_loss = total_loss + loss.item()
            preds = preds.detach().cpu().numpy()
            total_preds.append(preds)
    avg_loss = total_loss / len(val_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
However, this cross_entropy function is the same object that was created earlier, which means it uses the weights calculated from the training set.
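To see what the weights do to the number being reported, here is a NumPy sketch of the mean-reduced, class-weighted negative log-likelihood as described in the PyTorch docs for `NLLLoss`/`CrossEntropyLoss`: each example's loss is multiplied by the weight of its true class, and the sum is divided by the sum of those weights rather than by the batch size. The logits, labels and weights are made-up toy values, the helper names are hypothetical, and applying `log_softmax` before the NLL mirrors what `CrossEntropyLoss` does with raw logits.

```python
import numpy as np

def log_softmax(logits):
    # subtract the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

def weighted_nll(log_probs, labels, weights=None):
    """Mean-reduced NLL; with class weights, 'mean' divides by the summed
    weights of the target classes, not by the number of examples."""
    per_example = -log_probs[np.arange(len(labels)), labels]
    if weights is None:
        return per_example.mean()
    w = weights[labels]
    return (w * per_example).sum() / w.sum()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 1.5, 0.3],
                   [0.2, 0.1, 2.2]])
labels = np.array([0, 1, 2])
# e.g. 'balanced' weights computed from a 6/3/1 training split
train_weights = np.array([10 / 18, 10 / 9, 10 / 3])

log_probs = log_softmax(logits)
plain = weighted_nll(log_probs, labels)
weighted = weighted_nll(log_probs, labels, train_weights)
print(plain, weighted)
```

Two things follow from this formula: the weighted and unweighted losses are different quantities on the same predictions, so they are not directly comparable with each other, and multiplying all weights by a constant leaves the loss unchanged, so only the relative weights matter.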
In that regard, my questions are:
- Is it correct to use weights during the validation method, or should the validation method calculate the loss as usual?
- If we are supposed to use weights, should they be recalculated from the validation set, or is it correct to reuse the ones from the training set?
- I am assuming that no weights should be used for the test set, since the goal is to test the model on data that reflects the real world. Is this correct, or am I confusing the logic for this final part?
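Class weights only change the training objective; the test predictions themselves can be scored with metrics that expose minority-class performance. The NumPy sketch below (function names hypothetical) computes accuracy, per-class recall, balanced accuracy and macro-F1 from a confusion matrix; in practice, `sklearn.metrics.classification_report` or `f1_score(..., average='macro')` provide these metrics directly.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # rows = true class, columns = predicted class
    return cm

def per_class_report(y_true, y_pred, num_classes):
    cm = confusion_matrix(np.asarray(y_true), np.asarray(y_pred), num_classes)
    tp = np.diag(cm).astype(float)
    support = cm.sum(axis=1)    # number of true examples per class
    predicted = cm.sum(axis=0)  # number of predictions per class
    recall = np.divide(tp, support, out=np.zeros_like(tp), where=support > 0)
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return {
        "accuracy": tp.sum() / cm.sum(),
        "recall": recall,
        "balanced_accuracy": recall.mean(),
        "macro_f1": f1.mean(),
    }

# Toy test split: a model that always predicts the majority class 0
report = per_class_report([0] * 8 + [1, 1], [0] * 10, num_classes=2)
print(report)
```

In this toy case accuracy is 0.8, but balanced accuracy (0.5) and macro-F1 (about 0.44) show that class 1 is never recognized.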
Thanks in advance for any help and suggestion!