Discrepancy Between AUPRC in torch and auc over sklearn's precision_recall_curve

Hello Everyone,
I am writing a script that extends the torch metrics to give me some additional ease in using AUPRC and AUROC in torch. Essentially, I want them wrapped with a task parameter that lets me select the task on init and an average parameter that lets me select the micro average as in sklearn’s implementation (roc_auc_score — scikit-learn 1.5.0 documentation).
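
To spell out what I mean by micro average, here is a small sketch of my assumption, namely that micro averaging is equivalent to ravelling the label indicator matrix and the scores into one big binary problem:

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1]])
y_score = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.3, 0.6], [0.7, 0.2, 0.4]])

# Assumption: micro average == flatten everything into a single binary problem
assert np.isclose(
    roc_auc_score(y_true, y_score, average="micro"),
    roc_auc_score(y_true.ravel(), y_score.ravel()),
)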

So while I did achieve that, the problem is the following:

  • The values computed for PR AUC are different in torcheval and sklearn (see the minimal check sketched right below)
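
As a minimal first check (just a sketch, separate from the script below), one can compare sklearn's two PR AUC conventions, auc(recall, precision) with the trapezoidal rule and average_precision_score with the step-wise sum, against torcheval's functional binary_auprc on the same toy data; I do not know which convention torcheval follows, which is exactly what this would show:

import torch
from sklearn.metrics import auc, average_precision_score, precision_recall_curve
from torcheval.metrics.functional import binary_auprc

y_true = torch.tensor([0, 0, 1, 1, 0, 1])
y_score = torch.tensor([0.10, 0.40, 0.35, 0.80, 0.60, 0.55])

precision, recall, _ = precision_recall_curve(y_true.numpy(), y_score.numpy())
print("auc(recall, precision)  :", auc(recall, precision))  # trapezoidal rule
print("average_precision_score :", average_precision_score(y_true.numpy(), y_score.numpy()))  # step-wise sum
print("binary_auprc (torcheval):", binary_auprc(y_score, y_true).item())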

Beyond that, I can’t really explain why that is; since the torch metric does not take thresholds as a parameter, I can’t increase them to a value where they wouldn’t matter. You can find the script here:

from torch import Tensor
from typing import Type, Literal, Optional, Union, List, Any
from torcheval.metrics import BinaryAUPRC, MulticlassAUPRC, MultilabelAUPRC, metric
from torchmetrics import AUROC as _AUROC
from copy import deepcopy

# TODO! This absolutely needs testing


class AUPRC(metric.Metric[Tensor]):

    def __new__(cls, task: str, num_labels: int = 1, average="macro"):
        if average not in ["macro", "micro", "none"]:
            raise ValueError("Average must be one of 'macro', 'micro', or 'none'"
                             f" but is {average}")

        if task == "binary" or average == "micro":
            metric = BinaryAUPRC()
        elif task == "multiclass":
            # Some debate online, but in torcheval this is one-vs-all
            metric = MulticlassAUPRC(num_classes=num_labels, average=average)
        elif task == "multilabel":
            # Multilabel: multiple positive labels per sample are allowed
            metric = MultilabelAUPRC(num_labels=num_labels, average=average)
        else:
            raise ValueError("Unsupported task type or activation function")
        metric._task = task
        metric._average = average

        return metric

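    # NOTE: __new__ above returns the torcheval metric itself (e.g. BinaryAUPRC,
    # which is not a subclass of this AUPRC class), so the update()/to() overrides
    # below are never bound to the returned object and `self.metric` is never set
    # on it; in the demo at the bottom, torcheval's own update() is what runs.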
    def update(self, predictions, labels):
        # Reshape predictions and labels to handle the batch dimension
        if self._task == "binary" or self._average == "micro":
            predictions = predictions.view(-1)
            labels = labels.view(-1)
        elif self._task == "multiclass":
            labels = labels.view(-1)

        self.metric.update(predictions, labels)

    def to(self, device):
        # Move the metric to the specified device
        self.metric = self.metric.to(device)
        return self


class AUROC(_AUROC):

    def __new__(
        cls: Type["_AUROC"],
        task: Literal["binary", "multiclass", "multilabel"],
        thresholds: Optional[Union[int, List[float], Tensor]] = None,
        num_classes: Optional[int] = None,
        average: Optional[Literal["macro", "weighted", "none", "micro"]] = "macro",
        max_fpr: Optional[float] = None,
        ignore_index: Optional[int] = None,
        validate_args: bool = True,
        **kwargs: Any,
    ):
        if average == "micro" and task == "multilabel":
            task = "binary"

        metric = super().__new__(cls,
                                 task=task,
                                 thresholds=thresholds,
                                 num_classes=num_classes,
                                 num_labels=num_classes,
                                 average="none" if average == "micro" else average,
                                 max_fpr=max_fpr,
                                 ignore_index=ignore_index,
                                 validate_args=validate_args,
                                 **kwargs)
        metric._average = average
        return metric

    # Flatten inputs for the micro average, then defer to the parent implementation
    def update(self, input: Tensor, target: Tensor, *args: Any, **kwargs: Any) -> None:
        if self._average == "micro":
            target = target.view(-1)
            input = input.view(-1)
        return super().update(input, target, *args, **kwargs)

    def compute(self) -> Tensor:
        return super().compute()
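
    # NOTE: I am not sure whether torchmetrics' AUROC.__new__ preserves this subclass;
    # if it returns the task-specific metric (BinaryAUROC / MultilabelAUROC) directly,
    # these overrides are never used and the micro flattening above would not happen.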


if __name__ == "__main__":
    # Multi-class classification data
    import torch
    # Compute precision-recall curve
    import numpy as np
    from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

    #
    y_true_multi = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1],
                                 [1, 0, 0], [0, 0, 1], [1, 0, 0], [1, 0, 1]]).int()
    y_pred_multi = torch.Tensor([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.1, 0.3, 0.6], [0.7, 0.2, 0.1],
                                 [0.1, 0.6, 0.3], [0.2, 0.1, 0.7], [0.6, 0.3, 0.1], [0.3, 0.5, 0.2],
                                 [0.2, 0.1, 0.7], [0.7, 0.2, 0.1]])

    # ---------------- Comparing Micro-Macro ROC AUC using torch with sklearn --------------------
    print("--- Comparing Micro-Macro ROC AUC ---")

    # Compute ours: micro
    micro_rocauc = AUROC(task="multilabel", average="micro", num_classes=3)
    micro_rocauc.update(y_pred_multi, y_true_multi)
    print("Micro AUCROC (torch):", micro_rocauc.compute())

    # Compute ours: macro
    macro_rocauc = AUROC(task="multilabel", average="macro", num_classes=3)
    macro_rocauc.update(y_pred_multi, y_true_multi)
    print("Micro AUCPRC (torch):", macro_rocauc.compute())

    # Compute theirs
    # Flatten y_true_multi and y_pred_multi as numpy arrays
    y_true_multi_flat = y_true_multi.numpy().flatten()
    y_pred_multi_flat = y_pred_multi.numpy().flatten()

    # Compute micro-average ROC AUC using sklearn
    micro_rocauc_sklearn = roc_auc_score(y_true_multi,
                                         y_pred_multi,
                                         average='micro',
                                         multi_class='ovr')
    print(f'Micro-average auc-roc (sklearn): {micro_rocauc_sklearn:.4f}')

    # Compute macro-average ROC AUC using sklearn
    macro_rocauc_sklearn = roc_auc_score(y_true_multi,
                                         y_pred_multi,
                                         average='macro',
                                         multi_class='ovr')
    print(f'Macro-average auc-roc (sklearn): {macro_rocauc_sklearn:.4f}')
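
    # Sanity check of my assumption that micro averaging is equivalent to
    # flattening everything into one binary problem; this should reproduce
    # the micro-average number printed above.
    print(f'Flattened binary auc-roc (sklearn): {roc_auc_score(y_true_multi_flat, y_pred_multi_flat):.4f}')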

    # ---------------- Comparing Micro-Macro PR AUC using torch with sklearn --------------------
    print("--- Comparing Micro-Macro PR AUC ---")
    micro_prauc = AUPRC(task="multilabel", num_labels=3, average="micro")
    macro_prauc = AUPRC(task="multilabel", num_labels=3, average="macro")

    # Compute ours
    for idx in range(len(y_true_multi)):
        yt = y_true_multi[idx, :].unsqueeze(0)
        yp = y_pred_multi[idx, :].unsqueeze(0)
        micro_prauc.update(yp, yt)
        macro_prauc.update(yp, yt)

    print("Micro AUCPR (torch):", micro_prauc.compute())
    print("Macro AUCPR (torch):", macro_prauc.compute())

    # Compute theirs
    roc_pr_list = []
    roc_auc_list = []

    # Iterate over each class
    for i in range(y_true_multi.shape[1]):
        y_true = y_true_multi[:, i]
        y_pred = y_pred_multi[:, i]

        # Compute precision-recall curve
        precision, recall, _ = precision_recall_curve(y_true, y_pred)
        roc_pr_list.append(auc(recall, precision))

        # Compute ROC AUC score
        roc_auc = roc_auc_score(y_true, y_pred)
        roc_auc_list.append(roc_auc)

    print(f"PR AUC macro Score (sklearn): {np.mean(roc_pr_list)}")

    precision, recall, _ = precision_recall_curve(y_true_multi_flat, y_pred_multi_flat)
    pr_auc = auc(recall, precision)

    # Print results
    print(f"PR AUC micro Score (sklearn): {pr_auc}")
    print()
    # ---------------- Comparing Binary PR AUC using torch with sklearn --------------------
    prauc = AUPRC(task="binary")
    prauc.update(y_pred_multi.flatten(), y_true_multi.flatten())

    print("Binary AUCPR (torch):", prauc.compute())
    from torcheval.metrics.functional import binary_auprc
    functional_prauc = binary_auprc(y_pred_multi.flatten(), y_true_multi.flatten())
    print("Binary AUCPR functional (torch):", functional_prauc)
    precision, recall, _ = precision_recall_curve(y_true_multi_flat, y_pred_multi_flat)
    pr_auc = auc(recall, precision)

    # Print results
    print(f"Binary PRAUC Score (sklearn): {pr_auc}")