Dear friends,
How can I use Ignite just as a metric class in my training loop to calculate accuracy, precision, and recall?
Regards,
Aiman
If you are using Ignite for training, have a look at the Quickstart guide showing an example usage of some metrics.
Or would you like to use some Ignite snippets in isolation?
Dear ptrblck,
Thank you for your reply. Actually, the Accuracy and Loss classes work fine for me, but I can't calculate precision, recall, and F1 as isolated classes in my training loop.
I tried the code below:
from ignite.metrics import Precision

def get_precision(output, trg):
    # output: logits of shape (N, num_classes); trg: class indices of shape (N,)
    # output_transform is only applied when the metric is attached to an Engine,
    # so it is omitted here and (y_pred, y) is passed to update() directly.
    precision = Precision(average=False)  # average=False -> one value per class
    precision.update((output.detach(), trg))
    return precision.compute()
and I call it in the training loop like this:
precision = get_precision(output, trg)
I get a tensor like this:
tensor([0.4250, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
       dtype=torch.float64)
I don't think this value can be used directly.
My training function is:
import torch

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_precision = 0

    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()

        # Decoder input: the target sequence without its last token
        output = model(src, trg[:, :-1])

        # Flatten to (batch * seq_len, vocab_size) and (batch * seq_len,),
        # using the target shifted by one position
        output = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output, trg)
        accuracy = get_accuracy(output, trg)    # own helper (not shown)
        precision = get_precision(output, trg)  # helper defined above

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        epoch_accuracy += accuracy
        epoch_precision += precision

    # Mean of the per-batch values
    return (epoch_loss / len(iterator),
            epoch_accuracy / len(iterator),
            epoch_precision / len(iterator))
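I found the solution: I just set average=True. With average=True, Precision returns the unweighted mean of the per-class precisions as a single number instead of one value per class.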