I’m working on a classification problem with 5000 classes. I have a ground truth vector of shape (1000) instead of a single value; the values in this target vector are the possible classes. The predicted vector has shape (1x5000) and holds the softmax scores for all the classes. Which loss function is recommended for this use case?
Since you’ve used softmax, it should be a multi-class classification problem (rather than multi-label).
Cross-entropy loss is the most commonly used loss for such cases.
However, what about your target vector?
It should have the same shape as the predicted outputs, ideally one-hot encoded: 1 at the index of the correct class and 0s everywhere else.
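A minimal sketch of building such a one-hot target, assuming PyTorch 1.10 or newer (where CrossEntropyLoss accepts class probabilities as targets). The logits and the class index are made-up values. It also checks that, for a single hard label, the one-hot target and the plain class index give the same loss.

```python
import torch
import torch.nn.functional as F

num_classes = 5
logits = torch.tensor([[0.2, 1.1, 2.3, 2.2, 0.9]])  # hypothetical raw model output, shape (1, 5)
class_index = torch.tensor([2])                     # the single correct class (0-based)

# One-hot target: 1.0 at the correct class, 0.0 elsewhere, same shape as logits
one_hot = F.one_hot(class_index, num_classes=num_classes).float()

loss_fn = torch.nn.CrossEntropyLoss()
loss_from_one_hot = loss_fn(logits, one_hot)    # probability-target form
loss_from_index = loss_fn(logits, class_index)  # class-index form
print(one_hot)
print(torch.allclose(loss_from_one_hot, loss_from_index))
```

So for a hard label the one-hot form is a convenience; it becomes necessary once the target holds probabilities rather than a single class.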
The ground truth vector is the target vector. For example:
predicted_vector = tensor([0.0669, 0.1336, 0.3400, 0.3392, 0.1203])
ground_truth = tensor([3, 2, 5])
In this illustration, a typical argmax operation would declare class 3 the predicted class (0.34), but I want the model to be rewarded if the argmax class is any of 3, 2, or 5.
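The "reward if the argmax is any of these classes" idea can be sketched as an evaluation check. This is only a metric, not a loss: argmax is not differentiable, so training still needs a differentiable loss. The 0-based mapping of the classes is an assumption (the post appears to count classes from 1).

```python
import torch

predicted_vector = torch.tensor([0.0669, 0.1336, 0.3400, 0.3392, 0.1203])
# Acceptable classes as 0-based indices (assumption: classes 3, 2, 5 counted from 1
# map to indices 2, 1, 4)
acceptable = torch.tensor([2, 1, 4])

predicted_class = predicted_vector.argmax()        # index of the largest score
hit = bool((acceptable == predicted_class).any())  # True if the argmax is any acceptable class
print(predicted_class.item(), hit)
```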
I didn’t quite understand you on this part -
Hi Saswat (and Srishti)!
It’s not clear what your use case is.
The first question, as Srishti mentioned, is whether you have a multi-label,
multi-class problem or a multi-class (single-label) problem.
Conceptually, is a given sample in exactly one class (say, class 3), but for
training purposes, predicting class 2 or 5 is still okay so you don’t want to
penalize your model too heavily?
Or, conceptually, can a given sample be in multiple classes at the same
time, and in your example your sample really is in classes 3 and 2 and 5,
and you would like to penalize your model if it doesn’t predict all three (and
also penalize it if it were additionally to predict other classes, for example,
class 4 or class 7)?
In the first case you have a single-label, multi-class problem, but with
probabilistic ("soft") labels, and you should use CrossEntropyLoss
(and not use softmax()). In your example your (soft) target might be
a probability of 0.7 for class 3, a probability of 0.2 for class 2, and a
probability of 0.1 for class 5 (and zero for everything else).
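With soft labels, cross-entropy is the sum over classes of -p_k * log(softmax(z)_k), where p is the target distribution and z the logits. A minimal sketch, assuming PyTorch 1.10 or newer, random logits, and class numbers used directly as 0-based indices, checks that the built-in function matches that formula for the 0.7 / 0.2 / 0.1 target.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 6
logits = torch.randn(1, num_classes)  # hypothetical raw model output for one sample

# Soft target from the example: 0.7 for class 3, 0.2 for class 2, 0.1 for class 5
soft_target = torch.zeros(1, num_classes)
soft_target[0, 3] = 0.7
soft_target[0, 2] = 0.2
soft_target[0, 5] = 0.1

builtin = F.cross_entropy(logits, soft_target)
# Same quantity written out: -sum_k p_k * log softmax(logits)_k, averaged over the batch
manual = -(soft_target * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(torch.allclose(builtin, manual))
```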
In the second case you have a multi-label, multi-class problem, and you
should use BCEWithLogitsLoss (and no sigmoid()).
In this case your multi-label target might be 1.0 for classes 3, 2, and 5
(and 0.0 for all other classes). You can also use probabilities for targets
with BCEWithLogitsLoss, e.g., perhaps, 0.9 for class 3, 0.8 for class 2,
and 0.7 for class 5 (and 0.0 or close to 0.0 for all other classes).
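A minimal sketch of the multi-label case, assuming 10 classes and random logits. Each class gets its own independent sigmoid inside BCEWithLogitsLoss, so the target is a multi-hot (or per-class probability) vector that does not need to sum to 1. The 0.5 decision threshold is an assumption, not something the loss prescribes.

```python
import torch

torch.manual_seed(0)
num_classes = 10
logits = torch.randn(1, num_classes)  # hypothetical raw scores, one per class (no sigmoid applied)

# Multi-hot target: 1.0 for classes 3, 2 and 5, 0.0 for all others
multi_hot = torch.zeros(1, num_classes)
multi_hot[0, [3, 2, 5]] = 1.0

# Probabilistic variant: 0.9 for class 3, 0.8 for class 2, 0.7 for class 5
soft_multi = torch.zeros(1, num_classes)
soft_multi[0, [3, 2, 5]] = torch.tensor([0.9, 0.8, 0.7])

loss_fn = torch.nn.BCEWithLogitsLoss()  # applies sigmoid internally, per class
print(loss_fn(logits, multi_hot).item(), loss_fn(logits, soft_multi).item())

# At inference each class is decided independently (0.5 threshold is an assumption)
predicted = torch.sigmoid(logits) > 0.5
```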
Start by working out, conceptually, what kind of classification problem you
have, and then drill down into which loss function you should use and how you
should structure / interpret your target data.
This was really elaborate and helpful. I’m dealing with the first case that you mentioned. Could you please give some pointers on how to implement the CE loss function with soft labels?
As @KFrank has mentioned, your target vector should look like:
target = torch.tensor([0.1, 0.2, 0, 0, 0, 0.7]) # Let's say we have 6 classes only
Essentially, it should contain the probability corresponding to each class.
The output tensor of your model (raw logits) can then be used to calculate the cross-entropy loss:
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(output, target)
Note that you do not need to include an explicit softmax layer in your model. Per the docs, CrossEntropyLoss expects its input to contain the unnormalized logits for each class.
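A minimal training sketch with soft targets, assuming PyTorch 1.10 or newer. The linear model, random inputs, and the target rows are hypothetical; the point is the shapes: logits and targets are both (batch_size, num_classes), the targets are float probabilities with each row summing to 1, and the model ends without a softmax.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes, batch_size, num_features = 6, 4, 8

model = nn.Linear(num_features, num_classes)  # hypothetical model; note: no softmax layer
inputs = torch.randn(batch_size, num_features)

# One row of class probabilities per sample; each row sums to 1
targets = torch.tensor([
    [0.1, 0.2, 0.0, 0.0, 0.0, 0.7],
    [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.3, 0.3, 0.4],
])

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(100):
    optimizer.zero_grad()
    logits = model(inputs)           # raw scores, shape (batch_size, num_classes)
    loss = loss_fn(logits, targets)  # targets: float probabilities, same shape as logits
    loss.backward()
    optimizer.step()

print(loss.item())
```

With soft targets the loss cannot reach zero: its minimum is the entropy of the target rows, reached when the predicted distribution matches them.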
Got it @srishti-git1110. Thanks!
@KFrank I have a silly follow-up question on this.
You’ve asked me not to use softmax. Is that just because
CrossEntropyLoss applies softmax internally, or is there some other reason?
@Saswat This is the reason.
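To illustrate why the extra softmax hurts: CrossEntropyLoss applies log_softmax to its input, so a softmax in the model is applied twice. The inputs are then squashed into [0, 1], and even a perfectly confident prediction cannot push the loss below log(1 + (C - 1)/e), which is roughly 7.5 for C = 5000 classes. A small sketch with made-up logits compares the two:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[4.0, 1.0, 0.0]])  # hypothetical raw scores, one sample
target = torch.tensor([0])                 # class 0 is correct

correct = F.cross_entropy(logits, target)
# Mistake: softmax first, then cross_entropy applies log_softmax again
double = F.cross_entropy(F.softmax(logits, dim=1), target)
print(correct.item(), double.item())
```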
@KFrank If all classes in the target are equally likely, won’t the model get penalized for predicting only one of them and not predicting others?
For example, if the target vector is
[0,1,1,1,0,1,0,0,0,0], predicting any one of classes 1, 2, 3, or 5 is fine, and the model does so at some point, won’t “not predicting” the other classes send negative feedback that readjusts the weights? I tried normalizing the target class values (to make them soft, since all have equal probability), so the new target looks like [0,0.25,0.25,0.25,0,0.25,0,0,0,0], but to no avail.
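A small check of what the normalized target asks for. Cross-entropy against a target p is minimized when the predicted distribution equals p, so with equal weight on classes 1, 2, 3, and 5 the loss is lowest when the probability is spread evenly over those four classes; a prediction concentrated on a single one of them is penalized. The two predicted distributions below are hypothetical, compared directly as probabilities rather than logits for clarity.

```python
import torch

# Normalized target from the post: equal weight on classes 1, 2, 3 and 5
target = torch.tensor([[0, 0.25, 0.25, 0.25, 0, 0.25, 0, 0, 0, 0]])

def ce(probs):
    """Cross-entropy of `target` against a predicted probability distribution."""
    return -(target * torch.log(probs.clamp(min=1e-12))).sum()

# Case A: prediction spread evenly over the four acceptable classes
spread = target.clone()
# Case B: almost all mass on class 1 only (rows sum to 1)
one_class = torch.full((1, 10), 0.01 / 9)
one_class[0, 1] = 0.99

print(ce(spread).item(), ce(one_class).item())
```

Case A gives log(4), the minimum possible for this target, while case B is several times larger: a soft target encodes "spread the probability this way", not "any one of these is fine".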