How to Create a Multilabel Confusion Matrix for 14 Disease Classes in PyTorch?

0

I’m working on a multilabel classification task with 14 different disease classes. I’ve trained my model, and I want to generate a single multilabel confusion matrix where both the x-axis and y-axis represent the 14 classes.

However, when I try to generate the confusion matrix using my current code, it creates a separate confusion matrix for each class. Instead, I would like a unified confusion matrix where the true labels and predicted labels are across the same 14 classes on both axes.

Here are the key details of my setup:

  • I’m using PyTorch.
  • I have a trained model and access to my data loaders (train_loader and test_loader; the code below evaluates on test_loader).
  • I’ve successfully used my model for prediction, but I’m stuck on how to aggregate the results into a single confusion matrix for multilabel classification.

Here’s the code I’m using to generate the confusion matrix:

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import multilabel_confusion_matrix

# Collect per-batch true labels and model outputs; they are stacked into
# (num_samples, num_classes) arrays after the loop.
all_labels = []
all_predictions = []

# Switch dropout / batch-norm layers to inference behaviour. Calling eval()
# is harmless if the model was already in eval mode.
best_model.eval()

# Disable gradient calculation for inference
with torch.no_grad():
    for images, labels in test_loader:
        # Only the inputs need to be on the model's device; the labels are
        # only used on the CPU side below.
        images = images.to(device)

        # Forward pass to get model outputs
        outputs = best_model(images)

        # Keep whole batches; concatenating once at the end is cheaper than
        # extending a Python list row by row.
        all_labels.append(labels.cpu().numpy())
        all_predictions.append(outputs.cpu().numpy())

if not all_labels:
    raise ValueError("test_loader yielded no batches; nothing to evaluate")

# Stack the batches along the sample axis
all_labels = np.concatenate(all_labels, axis=0)
all_predictions = np.concatenate(all_predictions, axis=0)

# NOTE(review): a 0.5 threshold only makes sense if `outputs` are
# probabilities (i.e. the model ends in a sigmoid). If the model returns raw
# logits (typical when training with BCEWithLogitsLoss), apply
# torch.sigmoid to the outputs first, or equivalently threshold the logits
# at 0 — confirm against the model definition.
THRESHOLD = 0.5
binary_predictions = (all_predictions > THRESHOLD).astype(int)

# Per-class 2x2 matrices, shape (num_classes, 2, 2), each laid out by sklearn
# as [[TN, FP], [FN, TP]].
confusion_mtx = multilabel_confusion_matrix(all_labels, binary_predictions)

# Function to plot the multilabel confusion matrix
def plot_multilabel_confusion_matrix(confusion_mtx, class_names, ncols=3):
    """Plot one 2x2 confusion matrix per class on a grid of subplots.

    Args:
        confusion_mtx: Array of shape (num_classes, 2, 2) in the layout
            returned by ``sklearn.metrics.multilabel_confusion_matrix``,
            i.e. ``[[TN, FP], [FN, TP]]`` per class (rows = true label,
            columns = predicted label, index 0 = negative, 1 = positive).
        class_names: Subplot titles; needs one entry per class.
        ncols: Number of subplot columns (default 3).

    Each matrix is reversed along both axes before drawing so the data
    matches the "Positive"-first tick labels. The displayed layout is
    therefore ``[[TP, FN], [FP, TN]]``.
    """
    num_classes = confusion_mtx.shape[0]
    nrows = (num_classes + ncols - 1) // ncols  # ceil(num_classes / ncols)

    # squeeze=False always returns a 2D array of axes, so flatten() is valid
    # even when the grid has a single row or a single column.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(ncols * 4, nrows * 4), squeeze=False
    )
    axes = axes.flatten()

    tick_labels = ['Positive', 'Negative']
    for i in range(num_classes):
        ax = axes[i]
        # sklearn orders each class matrix negative-first ([[TN, FP], [FN, TP]]).
        # Drawing it raw under Positive-first tick labels would put TN in the
        # "Positive/Positive" cell, so reverse both axes -> [[TP, FN], [FP, TN]].
        cm = np.asarray(confusion_mtx[i])[::-1, ::-1]

        ax.matshow(cm, cmap=plt.cm.Blues, alpha=0.5)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(class_names[i])

        ax.set_xticks([0, 1])
        ax.set_xticklabels(tick_labels)
        ax.set_yticks([0, 1])
        ax.set_yticklabels(tick_labels)

        # Annotate every cell with its count (text x = column, y = row).
        for (row, col), count in np.ndenumerate(cm):
            ax.text(col, row, count, ha='center', va='center')

    # Hide any unused subplots in the last row
    for ax in axes[num_classes:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Draw the per-class 2x2 matrices.
# NOTE: `class_names` is not defined in this snippet; it must be a list with
# one display name per class.
plot_multilabel_confusion_matrix(confusion_mtx=confusion_mtx, class_names=class_names)

What I get is 14 separate confusion matrices, but I need a single confusion matrix with all 14 classes represented on both axes.

Here’s an image of what I’m getting: [image not included]

And here’s what I want: [image not included]

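Try the code below. I used random data, so replace all_labels and all_predictions with your own arrays. Both should have shape (num_samples, num_classes) and contain 0/1 values; for example, use your thresholded binary_predictions. Note that this matrix counts label co-occurrences: for every sample, each (true class, predicted class) pair adds 1 to one cell. Off-diagonal cells are therefore not pure errors — a sample with true labels {A, B} that is predicted exactly as {A, B} still adds 1 to both A→B and B→A.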

import numpy as np
import matplotlib.pyplot as plt

num_samples = 100
num_classes = 14

np.random.seed(42)

# Stand-in data: random 0/1 indicator matrices of shape (num_samples, num_classes).
# Swap these for your own true labels and thresholded predictions.
all_labels = np.random.randint(0, 2, size=(num_samples, num_classes))
all_predictions = np.random.randint(0, 2, size=(num_samples, num_classes))
class_names = [f'Class {k}' for k in range(1, num_classes + 1)]

# Cell [t, p] = number of samples in which class t is a true label AND class p
# is a predicted label. Summing the outer product of each sample's true and
# predicted indicator rows over all samples is exactly true_hits.T @ pred_hits.
# Because every (true, predicted) pair in a sample is counted, off-diagonal
# cells record label co-occurrences, not only misclassifications.
true_hits = (all_labels == 1).astype(int)
pred_hits = (all_predictions == 1).astype(int)
confusion_mtx = (true_hits.T @ pred_hits).astype(int)

def plot_confusion_matrix(cm, class_names):
    """Draw a single square confusion matrix as an annotated heat map.

    Rows are true labels and columns are predicted labels. Every cell is
    annotated with its integer count. The text is white on cells whose count
    exceeds half of the matrix maximum and black elsewhere, for contrast.

    Args:
        cm: 2-D array of integer counts indexed as ``cm[true, predicted]``.
        class_names: Tick labels, used for both axes.
    """
    _fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(image, ax=ax)

    n_rows, n_cols = cm.shape
    axis_settings = dict(
        xticks=np.arange(n_cols),
        yticks=np.arange(n_rows),
        xticklabels=class_names,
        yticklabels=class_names,
        xlabel='Predicted Label',
        ylabel='True Label',
        title='Confusion Matrix',
    )
    ax.set(**axis_settings)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    half_max = cm.max() / 2
    for (row, col), value in np.ndenumerate(cm):
        text_color = 'black'
        if value > half_max:
            text_color = 'white'
        ax.text(
            col, row, format(value, 'd'),
            ha='center', va='center',
            color=text_color,
        )

    plt.tight_layout()
    plt.show()

# Render the aggregated class-by-class matrix built above.
plot_confusion_matrix(cm=confusion_mtx, class_names=class_names)