Combining metrics in ignite.metrics

In code block (1) below, I create a separate metric object for each metric (accuracy, precision, recall, F1); each one records (y_pred, y) and computes its score at the end. My question is: can I create a new metric class that combines all four metrics, so that I only need to update it once per loop iteration? See code block (2) for an example.

(1) What I use now.

from ignite.metrics import Accuracy, Precision, Recall, Fbeta

# One metric object per score. F1 is derived from the same precision/recall
# instances, so it is never updated directly in the loop below.
precision = Precision()
recall = Recall()
accuracy = Accuracy()
f1 = Fbeta(beta=1.0, average=False, precision=precision, recall=recall)

for X, y in dataloader:
    y_pred = model(X)

    # calculate loss, backward, and update weights

    # Each base metric must be fed the same batch output separately.
    batch_output = (y_pred, y)
    accuracy.update(batch_output)
    precision.update(batch_output)
    recall.update(batch_output)

for label, tracked in (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1", f1),
):
    print(f"{label}: {tracked.compute()}")

(2) What I want.

import CustomMetric  # metric combining Accuracy, Precision, Recall, and F1

metric = CustomMetric()  # hypothetical combined metric the question asks for

for X, y in dataloader:
    y_pred = model(X)

    # calculate loss, backward, and update weights

    # A single update call should feed all four underlying metrics.
    metric.update((y_pred, y))

# The desired compute() returns a dict keyed by metric name.
scores = metric.compute()
print(f"Accuracy: {scores['accuracy']}")
print(f"Precision: {scores['precision']}")
print(f"Recall: {scores['recall']}")
print(f"F1: {scores['f1']}")

@stvhuang you can do it like this: wrap the metrics in a small class that forwards each call to all of them:

class MetricsGroup:
    """Drive several ignite-style metrics as a single unit.

    Each call is forwarded to every wrapped metric, so one ``update`` per
    batch feeds all of them and one ``compute`` collects all results.

    Note: the group only calls methods on its own entries. A composite
    metric built from other metric instances (e.g. ``Fbeta`` constructed
    from shared ``Precision``/``Recall`` objects) should have those
    instances in the group as well. ignite documents that ``MetricsLambda``
    (which ``Fbeta`` returns) does not recursively update its dependencies
    -- confirm for your ignite version.
    """

    def __init__(self, metrics_dict):
        """
        Args:
            metrics_dict: mapping of result name -> metric object exposing
                ``update(output)``, ``compute()`` and ``reset()``.
                The mapping is stored as-is (not copied).
        """
        self.metrics = metrics_dict

    def update(self, output):
        """Forward one batch ``output`` (e.g. ``(y_pred, y)``) to every metric."""
        for metric in self.metrics.values():
            metric.update(output)

    def compute(self):
        """Return ``{name: metric.compute()}`` in the mapping's iteration order."""
        return {name: metric.compute() for name, metric in self.metrics.items()}

    def reset(self):
        """Reset every metric so the group can be reused, e.g. for a new epoch."""
        for metric in self.metrics.values():
            metric.reset()


import torch
from ignite.metrics import Accuracy, Precision, Recall, Fbeta

precision = Precision()
recall = Recall()
# F1 is built from the same precision/recall instances that the group
# also updates through their own entries.
m_group = MetricsGroup(
    {
        "accuracy": Accuracy(),
        "precision": precision,
        "recall": recall,
        "f1": Fbeta(beta=1.0, average=False, precision=precision, recall=recall),
    }
)

num_batches, batch_size, num_classes = 10, 32, 4
for _batch in range(num_batches):
    # Fake batch: random integer targets and random per-class scores.
    y = torch.randint(0, num_classes, size=(batch_size,))
    y_pred = torch.rand(batch_size, num_classes)
    m_group.update((y_pred, y))

scores = m_group.compute()
for label, key in (
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1", "f1"),
):
    print(f"{label}: {scores[key]}")

HTH