Convert a per-class weight dict to a per-sample weight vector

For each class I have a weight (because some classes are underrepresented), and I would like to apply that weight to the corresponding samples in the batch when computing the loss.

The weight parameter of PyTorch loss functions expects a per-sample weight vector. So my question is: how do you efficiently convert a class-weights dictionary (as in the code below) to a per-sample weight vector?

In the code below, true_labels is the target vector with values 0, 1, or 2. Is there an efficient way to implement class2sample_weights?

Two probably inefficient approaches I considered are:

  1. convert true_labels to a NumPy array and apply the dict's key-to-value mapping to obtain the sample weight vector;

  2. create a float tensor of ones (sized to the number of samples), call it weights; then loop over the class labels in true_labels, and in each iteration find the indices belonging to that class and set the corresponding entries of weights to that class's weight.

Is there a better approach?

```python
# key is the class label, value is the weight
class_weights = {0: 1/10, 1: 8/10, 2: 1/10}
sample_weights = class2sample_weights(true_labels, class_weights)

F.nll_loss(F.log_softmax(pred, dim=1), true_labels, weight=sample_weights)
```
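For reference, the conversion can be done without a per-sample Python loop by building a lookup tensor once and using advanced indexing. A minimal sketch, assuming the class labels are contiguous integers starting at 0 (as in the example above):

```python
import torch

def class2sample_weights(true_labels, class_weights):
    # Build a lookup tensor where index i holds the weight of class i
    # (assumes labels are contiguous integers 0..nclasses-1).
    lookup = torch.tensor([class_weights[c] for c in sorted(class_weights)],
                          dtype=torch.float)
    # Advanced indexing maps every label to its class weight in one shot.
    return lookup[true_labels]

true_labels = torch.tensor([0, 1, 2, 1, 0])
class_weights = {0: 1/10, 1: 8/10, 2: 1/10}
print(class2sample_weights(true_labels, class_weights))
# tensor([0.1000, 0.8000, 0.1000, 0.8000, 0.1000])
```

The indexing step runs entirely in PyTorch, so it stays on the GPU if true_labels does (provided lookup is moved to the same device).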

As far as I understand, the loss function expects a per-class weight, not a per-sample one.

The size of the weight vector is [nclasses].
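In other words, you can pass the class weights directly and skip the per-sample conversion entirely. A sketch with made-up predictions, using the weights from the question:

```python
import torch
import torch.nn.functional as F

# Per-class weights: index i holds the weight for class i, shape [nclasses].
weight = torch.tensor([1/10, 8/10, 1/10])

pred = torch.randn(5, 3)               # (batch, nclasses) raw scores
true_labels = torch.tensor([0, 1, 2, 1, 0])

# nll_loss scales each sample's loss by the weight of its target class.
loss = F.nll_loss(F.log_softmax(pred, dim=1), true_labels, weight=weight)
print(loss)
```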

Oh, you are right. I could have sworn I saw it require a per-sample weight vector! Weird.

Thanks for the catch!

PS: it would still be interesting to know whether there is an efficient way to map each value in a tensor to its corresponding value in a dict.
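One sketch for that more general case: when the dict keys are small non-negative integers (not necessarily contiguous), you can scatter the dict into a dense lookup tensor and index into it. The helper name is hypothetical:

```python
import torch

def map_with_dict(values, mapping):
    # Hypothetical helper: map each integer in `values` through `mapping`
    # without looping over samples. Builds a dense lookup tensor sized
    # to the largest key (assumes small non-negative integer keys).
    lookup = torch.empty(max(mapping) + 1, dtype=torch.float)
    for k, v in mapping.items():
        lookup[k] = v
    return lookup[values]

print(map_with_dict(torch.tensor([2, 0, 2]), {0: 0.5, 2: 1.5}))
# tensor([1.5000, 0.5000, 1.5000])
```

The loop here is over dict entries (a few classes), not over samples, so the per-sample work is a single vectorized indexing operation.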
