How can I use Ignite as a metric class in my training loop?

Dear friends,

How can I use Ignite just as a metric class in my training loop to calculate accuracy, precision, and recall?


If you are using Ignite for training, have a look at the Quickstart guide, which shows how some of the metrics are used.
Or would you like to use some Ignite snippets in isolation?


Dear ptrblck,

Thank you for your reply. The Accuracy and Loss classes work fine for me, but I can't calculate precision, recall, and F1 with the standalone classes in my training loop.

I tried the code below:

import torch
from ignite.metrics import Precision

def get_precision(output, trg):
    # output: (N, C) scores per token, trg: (N,) class indices.
    # output_transform is only applied when the metric is attached to an
    # Engine, not when update() is called directly, so it is left out here.
    precision = Precision(average=False)  # one precision value per class
    precision.update((output.detach(), trg))
    return precision.compute()

Then I called it in the training loop like this:

precision = get_precision(output, trg)

The result is a tensor like this:

tensor([0.4250, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],

I don't think this value can be used directly.

The training function looks like this:

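import torch

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_precision = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        optimizer.zero_grad()
        # Decoder input is the target sequence without its last token
        output = model(src, trg[:, :-1])

        # Flatten to (batch * seq_len, vocab_size) and (batch * seq_len,)
        output = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output, trg)
        loss.backward()  # gradients must exist before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        # Per-batch metrics (get_accuracy is the poster's own helper, not shown)
        accuracy = get_accuracy(output, trg)
        precision = get_precision(output, trg)
        epoch_loss += loss.item()
        epoch_accuracy += accuracy
        epoch_precision += precision
    return epoch_loss / len(iterator), epoch_accuracy / len(iterator), epoch_precision / len(iterator)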

I found the solution: I just set average=True.