How should I compute the accuracy for a multilabel dataset?

I am wondering how I should compute the accuracy for a multi-label classification task.
Could someone please provide a toy example?

For multiclass classification, you should have an output tensor of size (batch, num_classes), while the target label tensor should be a LongTensor of size (batch,), where batch is the number of samples.
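As a small illustration of those shapes, here is a sketch with a hypothetical batch of 4 samples and 3 classes. The random logits stand in for a model's output; `batch` and `num_classes` are illustrative names, not part of any API.

```python
import torch

# Hypothetical toy setup: 4 samples, 3 classes
batch, num_classes = 4, 3

# Model output (logits): one row of class scores per sample
out = torch.randn(batch, num_classes)

# Targets: one integer class index per sample, dtype int64 (a LongTensor)
target = torch.tensor([0, 2, 1, 2])

print(out.shape)     # torch.Size([4, 3])
print(target.shape)  # torch.Size([4])
print(target.dtype)  # torch.int64
```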

To compute the accuracy, you should first apply a softmax so that you have the probability of each class for each sample, i.e.:

probs = torch.softmax(out, dim=1)
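A quick way to see what the softmax step does: with `dim=1`, each row of logits becomes a probability distribution over the classes. The logit values below are made up for illustration.

```python
import torch

# Hypothetical logits for 2 samples and 3 classes
out = torch.tensor([[2.0, 1.0, 0.1],
                    [0.5, 2.5, 0.3]])

# Softmax along dim=1 normalizes each row (each sample) separately
probs = torch.softmax(out, dim=1)

# Every row sums to 1, up to floating-point error
print(probs.sum(dim=1))
```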

Then you should select the most probable class for each sample, i.e.:

winners = probs.argmax(dim=1)
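For example, with some hand-written probabilities (illustrative values), `argmax(dim=1)` returns one int64 class index per row, which has the same shape and dtype as the target tensor described above:

```python
import torch

# Hypothetical class probabilities for 2 samples and 3 classes
probs = torch.tensor([[0.66, 0.24, 0.10],
                      [0.11, 0.81, 0.08]])

# Index of the most probable class in each row
winners = probs.argmax(dim=1)

print(winners)        # tensor([0, 1])
print(winners.dtype)  # torch.int64
```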

Now you can compare target with winners:

corrects = (winners == target)

And take the average over all samples:

accuracy = corrects.sum().float() / float(target.size(0))
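Putting the steps together on a tiny hand-made example (the logits and targets are invented so the result is easy to check by eye): three of the four predictions match, so the accuracy is 0.75. Taking the mean of the boolean tensor cast to float gives the same number.

```python
import torch

# Hypothetical toy data: 4 samples, 3 classes
out = torch.tensor([[2.0, 0.1, 0.3],   # predicts class 0
                    [0.2, 0.4, 1.5],   # predicts class 2
                    [1.2, 0.9, 0.1],   # predicts class 0
                    [0.3, 2.2, 0.1]])  # predicts class 1
target = torch.tensor([0, 2, 1, 1])

probs = torch.softmax(out, dim=1)
winners = probs.argmax(dim=1)
corrects = (winners == target)  # bool tensor: [True, True, False, True]

accuracy = corrects.sum().float() / float(target.size(0))
# Equivalent: the mean of the boolean tensor cast to float
accuracy_mean = corrects.float().mean()

print(accuracy.item())       # 0.75
print(accuracy_mean.item())  # 0.75
```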

If you want to do it all in one line:

accuracy = (torch.softmax(out, dim=1).argmax(dim=1) == target).sum().float() / float(target.size(0))
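When evaluating over several batches of different sizes, one option is to accumulate the number of correct predictions and the number of samples, then divide once at the end, instead of averaging per-batch accuracies. This is a sketch under that assumption; `count_correct` is a hypothetical helper and the list of batches stands in for a real data loader.

```python
import torch

def count_correct(out: torch.Tensor, target: torch.Tensor) -> int:
    # Same computation as the one-liner, but returns the raw count of
    # correct predictions so it can be summed across batches
    return (torch.softmax(out, dim=1).argmax(dim=1) == target).sum().item()

# Hypothetical "data loader": two batches of different sizes, 2 classes
batches = [
    (torch.tensor([[2.0, 0.1], [0.3, 1.0]]), torch.tensor([0, 1])),  # 2/2 correct
    (torch.tensor([[0.9, 0.2], [0.1, 0.8], [1.5, 0.4]]),
     torch.tensor([1, 1, 0])),                                       # 2/3 correct
]

correct, total = 0, 0
with torch.no_grad():  # no gradients needed for evaluation
    for out, target in batches:
        correct += count_correct(out, target)
        total += target.size(0)

accuracy = correct / total
print(accuracy)  # 0.8 (4 correct out of 5)
```

Averaging the two per-batch accuracies (1.0 and 0.667) would give about 0.833 instead, because the smaller batch would be over-weighted.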

For simplicity, let's say we process one sample at a time (a batch size of 1) and have 20 classes.

So the output will have shape torch.Size([1, 20]), am I correct?

But in multi-label classification, a single sample can belong to several classes at once. When you do winners = probs.argmax(dim=1), you consider only one class, which I don't think is correct.
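One common way around the single-argmax limitation (a different technique from the softmax/argmax recipe above) is to score each label independently with a sigmoid and threshold it. The sketch below assumes multi-hot targets and a 0.5 threshold, both of which are assumptions you may need to adapt. It reports two common metrics: exact-match (subset) accuracy, where a sample counts only if all its labels are right, and per-label accuracy, the fraction of individual label decisions that are right.

```python
import torch

# Hypothetical toy data: 3 samples, 4 labels; a sample may have several labels
logits = torch.tensor([[ 2.0, -1.0,  0.5, -3.0],
                       [-0.5,  1.5,  2.0, -1.0],
                       [ 0.3, -2.0, -0.1,  1.0]])
# Multi-hot targets: 1 if the label applies to the sample, else 0
target = torch.tensor([[1, 0, 1, 0],
                       [0, 1, 0, 0],
                       [1, 0, 0, 1]])

# Independent probability per label via sigmoid (not softmax), then an
# assumed threshold of 0.5 turns each probability into a yes/no decision
probs = torch.sigmoid(logits)
preds = (probs > 0.5).long()

# Exact-match (subset) accuracy: every label of the sample must be correct
exact_match = (preds == target).all(dim=1).float().mean()

# Per-label accuracy (1 - Hamming loss): fraction of label decisions correct
per_label = (preds == target).float().mean()

print(exact_match.item())  # about 0.667 (2 of 3 samples fully correct)
print(per_label.item())    # about 0.917 (11 of 12 labels correct)
```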

Sorry, I answered a completely different question. My answer is for multiclass classification, while you asked for multilabel classification. My mistake!

A brute-force way to approach your original problem, multilabel classification, is to transform it into multiclass classification by creating all possible combinations of the labels and treating each combination as a single class.
In this way you can use the method I explained above, but you have to consider the powerset of the set of all labels, which can be huge: with n labels it has 2^n elements (over a million for 20 labels).
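A sketch of one way to build that mapping, assuming multi-hot targets: read each row as a binary number, so label i contributes 2**i and every label combination gets a unique class index in [0, 2**num_labels). The encoding scheme is an illustrative choice, not the only possible one. The resulting indices can be used as LongTensor targets for a model with 2**num_labels outputs, and the multiclass accuracy above then measures exact-match accuracy on the original labels.

```python
import torch

num_labels = 3  # hypothetical; the powerset then has 2**3 = 8 classes

# Multi-hot targets for 4 samples
multi_hot = torch.tensor([[0, 0, 0],
                          [1, 0, 1],
                          [0, 1, 0],
                          [1, 1, 1]])

# Weight of each label position: tensor([1, 2, 4])
weights = 2 ** torch.arange(num_labels)

# Encode: one class index per label combination
class_index = (multi_hot * weights).sum(dim=1)

# Decode: recover the multi-hot rows from the class indices
decoded = (class_index.unsqueeze(1) // weights) % 2

print(class_index)  # tensor([0, 5, 2, 7])
```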

Although your answer is for multiclass classification, which differs from what was asked in the question, it really helped a lot, as I was searching for how to compute the accuracy for multiclass classification.


Why apply the softmax before the argmax instead of taking the argmax directly on the logits? Softmax preserves the ordering of the values within each row, so the answer shouldn't change…
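A small check of that point with made-up logits: because softmax is strictly increasing within a row, the position of the maximum is the same before and after it. The softmax is only needed if you want the probabilities themselves.

```python
import torch

# Hypothetical logits for 3 samples and 3 classes
out = torch.tensor([[ 1.0,  3.0, -2.0],
                    [ 0.5,  0.2,  0.1],
                    [-1.0, -3.0, -0.5]])

from_logits = out.argmax(dim=1)
from_probs = torch.softmax(out, dim=1).argmax(dim=1)

print(from_logits)                           # tensor([1, 0, 2])
print(torch.equal(from_logits, from_probs))  # True
```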