@MariosOreo Thanks for the catch! I’ve fixed it in my post.
@Deb_Prakash_Chatterj You could count it manually or create a confusion matrix first.
Based on the confusion matrix you could then calculate the stats.
Here is a small example. I tried to validate the results, but you should definitely have another look at them:
import torch

nb_samples = 20
nb_classes = 4

# random predictions and targets just for demonstration
output = torch.randn(nb_samples, nb_classes)
pred = torch.argmax(output, 1)
target = torch.randint(0, nb_classes, (nb_samples,))

# conf_matrix[t, p] counts samples of true class t predicted as class p
conf_matrix = torch.zeros(nb_classes, nb_classes)
for t, p in zip(target, pred):
    conf_matrix[t, p] += 1

print('Confusion matrix\n', conf_matrix)
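As a side note, the Python loop over samples could also be replaced by a single `bincount` call, which is usually faster for large sample counts. A quick sketch with hard-coded toy targets and predictions (not the random ones from above, so the result is reproducible):

```python
import torch

nb_classes = 4
target = torch.tensor([0, 1, 2, 3, 0, 1])
pred   = torch.tensor([0, 2, 2, 3, 0, 1])

# each (target, pred) pair maps to the unique bin index target * nb_classes + pred
conf_matrix = torch.bincount(
    target * nb_classes + pred, minlength=nb_classes ** 2
).reshape(nb_classes, nb_classes).float()
print(conf_matrix)
```

I validated it against the loop version for a few random inputs, but as always, please double check.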
TP = conf_matrix.diag()
for c in range(nb_classes):
    idx = torch.ones(nb_classes).bool()
    idx[c] = 0
    # all non-class samples classified as non-class
    TN = conf_matrix[idx][:, idx].sum()
    # all non-class samples classified as class
    FP = conf_matrix[idx, c].sum()
    # all class samples not classified as class
    FN = conf_matrix[c, idx].sum()
    print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(
        c, TP[c], TN, FP, FN))
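From these counts you could then derive the usual per-class stats such as precision, recall, and F1. A small sketch (`per_class_stats` is just a name I made up; the `clamp` guards against division by zero for classes that never occur):

```python
import torch

def per_class_stats(conf_matrix):
    # conf_matrix[t, p] counts samples of true class t predicted as class p
    TP = conf_matrix.diag()
    FP = conf_matrix.sum(0) - TP  # column sum minus diagonal
    FN = conf_matrix.sum(1) - TP  # row sum minus diagonal
    precision = TP / (TP + FP).clamp(min=1e-12)
    recall = TP / (TP + FN).clamp(min=1e-12)
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)
    return precision, recall, f1

# tiny 2-class example to sanity check the formulas
conf_matrix = torch.tensor([[3., 1.],
                            [2., 4.]])
precision, recall, f1 = per_class_stats(conf_matrix)
print(precision, recall, f1)
```

For the toy matrix above, class 0 gets precision 3/5 and recall 3/4, class 1 gets precision 4/5 and recall 4/6, which you can verify by hand from the rows and columns.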