# How can I plot confusion matrix for a multiclass multilabel problem in a better way than this?

Also, how can I modify the accuracy calculation, so it make more sense?
Here is my code:

my_metrics = get_metrics(pred, label, nb_classes=label.shape[1])
cm = my_metrics.pop(‘c_mat’)
pyplot.figure(figsize=(9,9))
sn.heatmap(cm, cmap=‘Blues’, annot=True, fmt=".2f", center=None, vmin=0)

def get_metrics(pred, target, nb_classes=14):

``````## 0. Pre-process

tanh = nn.Sigmoid()
preds_bin = (tanh(pred) > 0.5).long()          # one-hot-enc matrix [S DOAs]

## 1. Metrics
preds_bin_noSilence = preds_bin[:,1:]           # one-hot-enc matrix [DOAs]
labels_noSilence = target[:,1:]                      # one-hot-enc matrix [DOAs]

preds_indx  = torch.flatten( preds_bin * torch.arange(0, nb_classes, 1) ).long()      # index very long vector
labels_indx = torch.flatten( target    * torch.arange(0, nb_classes, 1) ).long()      # index very long vector

conf_matrix = torch.zeros(nb_classes, nb_classes)
for t, p in zip(labels_indx, preds_indx):
conf_matrix[t, p] += 1

TP = conf_matrix.diag()

# accuracy = []
# precision = []
# recall = []

for c in range(nb_classes):
idx = torch.ones(nb_classes).byte()
idx[c] = 0
# all non-class samples classified as non-class
TN = conf_matrix[idx.nonzero()[:, None], idx.nonzero()].sum() #conf_matrix[idx[:, None], idx].sum() - conf_matrix[idx, c].sum()
# all non-class samples classified as class
FP = conf_matrix[idx, c].sum()
# all class samples not classified as class
FN = conf_matrix[c, idx].sum()

# print(TN.shape, TN)
# input('dfgdg')

accuracy = (TP + TN) / (TP + FP + TN + FN)
precision = TP / (TP + FP)
recall = TP / (TP + FN)

## 3. Return
return {"acc": accuracy.numpy(),       # 80
"prec": precision.numpy(),     # 0/0
"recall": recall.numpy(),      # 0
"c_mat": conf_matrix.numpy()}``````

Based on your code it seems that you are creating a “standard” confusion matrix, which shows the confusion between every two classes. For a multi-label classification you might want to check e.g. `sklearn.metrics.multilabel_confusion_matrix`.