Hi all,
I am aware that similar or related topics have been asked before (such as "Weights of cross entropy loss for validation dev set" and "Apply WeightedRandomSampler to validation test splits makes sense"). However, I am still somewhat confused, since those and other previous posts use different code snippets, which makes it hard for me to understand the complete scenario.
My case is as follows: I am working on a multiclass classification model based on BERT with customized layers. However, the classes are imbalanced, and I am trying to check whether addressing this can improve the evaluation metrics.
I found three main suggestions in different articles and previous posts (besides working with the data as-is): undersampling, oversampling, and class weights.
I am currently testing the last one, but I am not sure whether I am doing it correctly, or whether I am missing some of the reasoning behind those examples.
The steps I have seen in previous examples are:

1. Split the data into training, validation and test sets (standard step, no issue here).

2. Generate the class weights for the later cross-entropy calculation. I have seen different approaches for this step, but the one I have found most often is similar to this:
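```python
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

# compute the class weights from the training labels
# (recent scikit-learn versions require keyword arguments here)
class_wts = compute_class_weight(class_weight='balanced', classes=np.unique(train_labels), y=train_labels)
print(class_wts)

# convert class weights to tensor
weights = torch.tensor(class_wts, dtype=torch.float)
weights = weights.to(device)
```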
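For reference, the `'balanced'` option weights each class by `n_samples / (n_classes * count_of_that_class)`, so rarer classes get larger weights. A minimal sketch that checks this by hand against `compute_class_weight`, using a made-up toy label array (the `train_labels` values here are hypothetical, not real data):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical, deliberately imbalanced labels for three classes (0, 1, 2)
train_labels = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 2])
classes = np.unique(train_labels)

# Library version
sk_wts = compute_class_weight(class_weight="balanced", classes=classes, y=train_labels)

# Same heuristic by hand: n_samples / (n_classes * count_per_class)
# (np.bincount assumes the labels are the integers 0..n_classes-1)
counts = np.bincount(train_labels)
manual_wts = len(train_labels) / (len(classes) * counts)

print(manual_wts)  # the rarest class gets the largest weight
print(np.allclose(sk_wts, manual_wts))
```

With counts of 6, 3 and 1, the weights are 10/18, 10/9 and 10/3, i.e. inversely proportional to class frequency.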
2.5. Sometimes, the weights are passed to a loss function right away (note that `nn.NLLLoss` expects log-probabilities, so this assumes the model's last layer is `nn.LogSoftmax`):
```python
import torch.nn as nn

# loss function
cross_entropy = nn.NLLLoss(weight=weights)
```
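The `nn.NLLLoss` set-up here and the `nn.CrossEntropyLoss` alternative that appears as a comment in the training step compute the same weighted loss, as long as `NLLLoss` is fed log-probabilities and `CrossEntropyLoss` is fed raw logits. A small sketch on a random toy batch (the logits, labels and weight values are all made up for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 4 examples, 3 classes, raw scores (logits) from a model
logits = torch.randn(4, 3)
labels = torch.tensor([0, 2, 1, 2])
weights = torch.tensor([0.56, 1.11, 3.33])  # e.g. output of compute_class_weight

# NLLLoss expects log-probabilities, so apply log_softmax first
nll = nn.NLLLoss(weight=weights)(F.log_softmax(logits, dim=1), labels)

# CrossEntropyLoss takes raw logits and applies log_softmax internally
ce = nn.CrossEntropyLoss(weight=weights)(logits, labels)

print(nll.item(), ce.item())
print(torch.allclose(nll, ce))  # the two set-ups give the same value
```

Mixing them up (e.g. passing log-probabilities to `CrossEntropyLoss`) would apply `log_softmax` twice and silently change the loss.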
3. After defining the model and the hyperparameters, training proceeds. If the weights were not wrapped in a loss function earlier, the weighted loss is created at this point instead:
```python
def train():
    model.train()
    total_loss, total_accuracy = 0, 0
    total_preds = []

    for step, batch in enumerate(train_dataloader):
        # move the batch to the GPU/CPU device
        batch = [r.to(device) for r in batch]
        sent_id, mask, labels = batch

        model.zero_grad()
        preds = model(sent_id, mask)

        # weighted loss, using the class weights computed from the training set
        loss = cross_entropy(preds, labels)
        # or # criterion = torch.nn.CrossEntropyLoss(weight=weights, reduction='mean')
        # and then # loss = criterion(outputs[1], b_labels)
        total_loss = total_loss + loss.item()

        loss.backward()
        # clip gradients to a maximum norm of 1.0
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        preds = preds.detach().cpu().numpy()
        total_preds.append(preds)

    avg_loss = total_loss / len(train_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
```
Up to this point I had few issues following the logic, but the next step is the one that confuses me.
4. After training, validation is performed, and it also uses the same cross_entropy function:
```python
import time

def evaluate():
    model.eval()
    total_loss, total_accuracy = 0, 0
    total_preds = []
    t0 = time.time()  # start time for the progress message

    for step, batch in enumerate(val_dataloader):
        # report progress every 50 batches
        # (format_time is a helper assumed to be defined elsewhere in the script)
        if step % 50 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(val_dataloader), elapsed))

        batch = [t.to(device) for t in batch]
        sent_id, mask, labels = batch

        with torch.no_grad():
            preds = model(sent_id, mask)
            loss = cross_entropy(preds, labels)  # <-- Here
            total_loss = total_loss + loss.item()
            preds = preds.detach().cpu().numpy()
            total_preds.append(preds)

    avg_loss = total_loss / len(val_dataloader)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
```
However, this cross_entropy is the same loss object created earlier, which means the validation loss is weighted with the class weights calculated from the training set.
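To make concrete what a weighted loss reports: with `weight=` and the default `reduction='mean'`, PyTorch scales each example's loss by the weight of its true class and divides by the sum of those weights, rather than by the number of examples. A sketch on a made-up validation batch (logits, labels and weights are hypothetical) comparing the weighted and unweighted values for the same predictions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical validation batch: 5 examples, 3 classes
logits = torch.randn(5, 3)
labels = torch.tensor([0, 0, 0, 1, 2])
weights = torch.tensor([0.56, 1.11, 3.33])  # hypothetical training-set weights

# Per-example negative log-likelihood, no weighting, no reduction
per_example = F.cross_entropy(logits, labels, reduction="none")

# Unweighted mean: every example counts equally
unweighted = F.cross_entropy(logits, labels)

# Weighted mean: each example is scaled by the weight of its true class,
# then divided by the sum of those weights
weighted = F.cross_entropy(logits, labels, weight=weights)
w = weights[labels]
manual_weighted = (w * per_example).sum() / w.sum()

print(unweighted.item(), weighted.item())
print(torch.allclose(weighted, manual_weighted))
```

Because the two numbers are computed differently, a weighted and an unweighted validation loss are generally not directly comparable, even for identical predictions.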
In that regard, my questions are:
- Is it correct to use the weights in the validation step, or should the validation loss be calculated as usual (unweighted)?
- If we are supposed to use weights, should they be recalculated from the validation set, or is it correct to reuse the ones from the training set?
- I am assuming that no weights should be used for the test set, since the goal is to evaluate the model on real-world data. Is this correct, or am I confusing the logic for this final part?
Thanks in advance for any help and suggestion!