I have a multi-class problem; the classes are all encoded as integers 0-72.
I have a preds tensor of shape [256, 72].
Passing it through
probs = torch.nn.functional(input, dim = 1)
results in a tensor with the same shape.
Here, each row of probs holds the probability of each class being the correct prediction.
I would like to analyse the predictions my model is making, but I cannot find anywhere how to link the probabilities to their respective classes. How can I do that?
Could you expand on the issue you’re having? I’m assuming the functional operation you’re doing is softmax (you don’t have this in your code snippet). Are you trying to get an output to see which class your model thinks is the answer for each sample?
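A minimal sketch of that assumed softmax step. Random logits of shape [256, 72] stand in for the real preds tensor. Softmax is taken over dim=1, the class dimension, so the shape is unchanged and each row becomes a distribution over the 72 classes; argmax over the same dimension gives the index of the most likely class per sample.

```python
import torch
import torch.nn.functional as F

# Random stand-in logits: 256 samples x 72 classes (illustration only)
torch.manual_seed(0)
preds = torch.randn(256, 72)

# Softmax over dim=1 (the class dimension) normalizes each row separately
probs = F.softmax(preds, dim=1)

# Index of the highest-probability class for each sample
predicted = probs.argmax(dim=1)

print(probs.shape)      # torch.Size([256, 72])
print(predicted.shape)  # torch.Size([256])
```

Because softmax is monotonic within a row, argmax on probs and on the raw logits picks the same class.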
My bad for not including it; I can’t edit the post anymore.
I would like to analyse which classes it confuses most often and then try to find a solution for that.
The softmax outputs a tensor of size (batch_size, num_classes).
But I cannot work out how these probabilities relate to each of the classes!
It isn’t as simple as taking argmax(), as the model often makes wrong predictions.
The classes are not passed explicitly anywhere in the model (a pre-trained ResNet-50 with the FC layer and last conv block fine-tuned), and I just cannot connect each probability to its respective class.
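To see which classes get confused most often, one common approach is a confusion matrix built from the true labels and the argmax predictions. A sketch using only torch, assuming targets and predicted are integer tensors of class indices in 0..num_classes-1; the helper names confusion_matrix and most_confused_pairs are hypothetical, and the 3-class data is made up for illustration:

```python
import torch

def confusion_matrix(targets: torch.Tensor, predicted: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Rows = true class, columns = predicted class."""
    # Flatten each (true, pred) pair into one index: true * C + pred
    idx = targets * num_classes + predicted
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def most_confused_pairs(cm: torch.Tensor, k: int = 5):
    """Return the k largest off-diagonal entries as (true, predicted, count)."""
    off_diag = cm.clone()
    off_diag.fill_diagonal_(0)  # ignore correct predictions
    values, flat_idx = off_diag.flatten().topk(k)
    n = cm.shape[0]
    return [(int(i) // n, int(i) % n, int(v)) for v, i in zip(values, flat_idx)]

# Tiny made-up example with 3 classes
targets = torch.tensor([0, 0, 1, 1, 2, 2, 2])
predicted = torch.tensor([0, 1, 1, 2, 2, 1, 1])
cm = confusion_matrix(targets, predicted, num_classes=3)
print(cm)
print(most_confused_pairs(cm, k=2))
```

For a real run, the per-batch targets and probs.argmax(dim=1) would be collected over the whole validation set (e.g. with torch.cat) before building the matrix. sklearn.metrics.confusion_matrix is an alternative if scikit-learn is available.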
Is CrossEntropyLoss what you’re looking for? That’s the go-to loss function for this type of problem. Or am I misinterpreting what you mean?
I am using CrossEntropyLoss (I will be trying out https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py tomorrow); however, the output of CrossEntropyLoss is a single scalar.
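The scalar comes from the loss's default reduction="mean", which averages over the batch. If a per-sample view helps the error analysis, a sketch passing reduction="none" (a real CrossEntropyLoss argument); the logits and targets here are random stand-ins:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 72)              # raw model outputs (no softmax applied)
targets = torch.tensor([3, 10, 0, 71])   # true class indices

# Default reduction="mean" averages over the batch -> a single scalar
mean_loss = nn.CrossEntropyLoss()(logits, targets)

# reduction="none" keeps one loss value per sample
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, targets)

print(mean_loss.shape)   # torch.Size([])
print(per_sample.shape)  # torch.Size([4])
```

Samples with the largest per-sample loss (e.g. per_sample.topk(k)) are the ones where the model assigned the least probability to the correct class.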
By applying softmax (which you shouldn’t do before CrossEntropyLoss, as it applies log_softmax internally) we get a probability distribution over all the existing classes for each image.
Using that, I can inspect which class is being predicted and, if it’s not the correct one, how far off it is (is the correct label the second most likely, or is the model not even considering it?).
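One way to measure "how far off" is the rank of the true label in each row: 0 means it was the top prediction, 1 means second most likely, and so on. A sketch assuming probs has shape (N, C) and targets holds the true class indices; the helper name true_label_rank and the toy probabilities are hypothetical:

```python
import torch

def true_label_rank(probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """0 = true class ranked first, 1 = second most likely, etc."""
    # Probability assigned to the correct class, shape (N, 1)
    true_prob = probs.gather(1, targets.unsqueeze(1))
    # Count how many classes received a strictly higher probability
    return (probs > true_prob).sum(dim=1)

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.5, 0.3, 0.2],
                      [0.1, 0.3, 0.6]])
targets = torch.tensor([0, 1, 0])
print(true_label_rank(probs, targets))  # tensor([0, 1, 2])
```

From the ranks, top-k accuracy is simply (ranks < k).float().mean().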
For example, this is the softmax output on the first batch of epoch 0.
They all sum to 1 (as probabilities should), and they represent the probability of each class being the correct answer.
But how do I link each of the probabilities to its class?
Does prob correspond to class_1, or maybe to class_10?
It’s impossible to tell, and that is what I am looking for!
Actually, I just double-checked, and they do not always sum to 1, which is very, very weird in my opinion.
for i in range(7):
    print(sum(float(x) for x in data["epoch_0"]["probs"][i]))
They are not far off from 1; nonetheless, they are not exactly 1.
@ptrblck how come?
Is this some error on my side, or is it down to how Python handles small numbers?
(They’re not that small, considering that Python floats go down to about 2.2e-308.)
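The deviation is consistent with float32 rounding: PyTorch tensors default to float32, which carries only about 7 significant decimal digits (torch.finfo(torch.float32).eps is roughly 1.2e-7), so the exponent range of Python floats is not the limiting factor. A sketch with random stand-in logits that checks the row sums with a tolerance instead of ==, and repeats the computation in float64 for comparison:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(7, 72)
probs = F.softmax(logits, dim=1)          # float32 by default

row_sums = probs.sum(dim=1)
# Deviation from exactly 1.0, typically on the order of float32 epsilon
print((row_sums - 1).abs().max())

# Compare with a tolerance instead of ==
print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # True

# The same computation in float64 gets much closer to 1
probs64 = F.softmax(logits.double(), dim=1)
print((probs64.sum(dim=1) - 1).abs().max())
```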
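It’s a one-to-one mapping. So, assuming your output has shape batch_size x num_classes, the entry in row i, column j of probs is the probability of sample i being class j: the first entry of the first row is the probability of your first sample being class 0, and so on.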
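That mapping can be made explicit with a list of class names whose position j matches the label index j used during training. A sketch with a hypothetical class_names list, toy probabilities, and a hypothetical helper top_k_classes that reports the k most likely classes per sample:

```python
import torch

# Hypothetical names; class_names[j] must match label index j used in training
class_names = [f"class_{j}" for j in range(5)]

probs = torch.tensor([[0.10, 0.60, 0.05, 0.20, 0.05],
                      [0.30, 0.10, 0.40, 0.10, 0.10]])

def top_k_classes(probs: torch.Tensor, class_names: list, k: int = 3):
    """For each row, return [(class_name, probability), ...] sorted by probability."""
    values, indices = probs.topk(k, dim=1)
    return [[(class_names[j], round(v, 3)) for v, j in zip(vals.tolist(), idxs.tolist())]
            for vals, idxs in zip(values, indices)]

for i, row in enumerate(top_k_classes(probs, class_names)):
    print(f"sample {i}: {row}")
```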
As for the values not summing to exactly 1.0, I believe this is a numerical (floating-point rounding) error. Softmax can be numerically unstable because of the exponential operations, which is one reason CrossEntropyLoss uses log_softmax internally.
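The instability of the exponentials is easiest to see with large logits: computed naively, exp(1000) overflows float32 to inf and inf / inf gives nan. Subtracting the row maximum before exponentiating leaves the result mathematically unchanged and avoids the overflow; PyTorch's F.softmax handles this case as well. A sketch with made-up logits (the small deviations from 1.0 discussed in this thread are ordinary rounding, whereas overflow is the more dramatic failure mode):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows float32 to inf, and inf / inf is nan
naive = torch.exp(logits) / torch.exp(logits).sum()

# Shifting by the max first gives the same result mathematically, without overflow
shifted = logits - logits.max()
stable = torch.exp(shifted) / torch.exp(shifted).sum()

print(naive)                      # tensor([nan, nan, nan])
print(stable)
print(F.softmax(logits, dim=0))   # matches the shifted version
```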
Would using log_softmax change that?
Thank you so much! It’s kind of obvious that they are, but I just wasn’t certain and didn’t want to make some wrong assumptions!
Assuming the labels weren’t encoded and I used strings, how would they be ordered in that case?
Log-softmax is more of a mathematical trick: we perform an exponential operation followed by a log operation (cross-entropy loss uses the negative log likelihood). Since we know a log operation is coming up (the inverse of an exponential), we can rewrite our math in a way that is more numerically stable, rather than exponentiating and then taking the log of the result. It’s not a direct drop-in replacement for softmax; the output of log_softmax won’t sum to 1.
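A sketch of both points with random stand-in logits: log_softmax returns log-probabilities (all at most 0) that do not sum to 1, exponentiating them recovers probabilities that do, and F.cross_entropy on raw logits matches F.nll_loss applied to the log_softmax output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 72)
targets = torch.tensor([3, 10, 0, 71])

log_probs = F.log_softmax(logits, dim=1)   # log-probabilities, all <= 0

print(log_probs.sum(dim=1))        # does not sum to 1
print(log_probs.exp().sum(dim=1))  # exponentiating recovers probabilities summing to ~1

# cross_entropy on raw logits equals nll_loss on log_softmax output
ce = F.cross_entropy(logits, targets)
nll = F.nll_loss(log_probs, targets)
print(torch.allclose(ce, nll))     # True
```

For analysis, the argmax or ranking of log_probs is the same as for the softmax probabilities, since log is monotonic.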
For your second question: do you mean that instead of having the labels [0, 1, 2], we have something like ['apple', 'orange', 'banana']? We would need to encode these before using cross-entropy loss anyway, so I’m not sure how you would even get it working with pure strings.
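With string labels, the order is whatever mapping is chosen when encoding them, so the practical rule is to fix the mapping once and reuse it everywhere. A sketch with hypothetical labels that uses sorted order. If the data is loaded with torchvision.datasets.ImageFolder, the classes are the sorted folder names, and the mapping is available as dataset.classes and dataset.class_to_idx.

```python
# Hypothetical string labels
labels = ["orange", "apple", "banana", "apple", "orange"]

# Fix an order once (sorted here) and reuse it for training and analysis
classes = sorted(set(labels))                      # ['apple', 'banana', 'orange']
class_to_idx = {name: i for i, name in enumerate(classes)}

encoded = [class_to_idx[name] for name in labels]  # [2, 0, 1, 0, 2]

# Column j of the softmax output then corresponds to classes[j]
print(classes, encoded)
```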
I guess that answers my question.
Thanks for all the help and patience!
Could you please share your solution to this problem?