Help me with the reasoning, because at the moment I am a bit lost: I have seen many different approaches.
I need to calculate accuracy in the best possible way. Let's use some of the predefined datasets in PyTorch, like CIFAR-10, CIFAR-100, …, so this is a single-label classification problem.
Which approach is best?
Approach 1 (keep track of correct and total labels on the CPU)
I think this approach is not optimal, because the number of correct examples has to be moved to the CPU (possibly as a NumPy array). Why not keep it on the GPU?
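A minimal sketch of the GPU-side variant asked about above, assuming a standard `model`/`DataLoader` setup: the correct count is accumulated as a tensor on `device`, and `.item()` (which makes the host wait for the GPU) is called once per epoch instead of once per batch. The helper name `epoch_accuracy` and the random toy data are hypothetical; whether the saved synchronizations are measurable depends on the model and hardware.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def epoch_accuracy(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Hypothetical helper: accuracy over a whole loader.

    The correct-prediction count stays on `device` as a tensor, so there is no
    host-device sync per batch; `.item()` is called once at the end.
    """
    model.eval()
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == targets).sum()  # stays on the device
        total += targets.size(0)  # a Python int, known on the host without a sync
    return correct.item() / total


# Usage on random toy data (stands in for a real dataset such as CIFAR-10).
torch.manual_seed(0)
x = torch.randn(100, 8)
y = torch.randint(0, 3, (100,))
loader = DataLoader(TensorDataset(x, y), batch_size=32)  # last batch has 4 samples
model = nn.Linear(8, 3)
acc = epoch_accuracy(model, loader, torch.device("cpu"))
```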
Approach 2 (using the formula for a single batch)
Maybe the number of correct labels is all we need; the total can be derived from the batch index and the batch size.
However, this is a formula for single-batch accuracy:
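def accuracy(preds, true):
    # preds: (N, C) logits, true: (N,) integer class labels
    preds = preds.argmax(dim=-1)  # predicted class index per sample
    return (preds == true).float().mean()  # fraction of correct predictions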
This also works for the full dataset if we concatenate the preds and true tensors.
However, I think this is the wrong approach for CIFAR-10 or CIFAR-100, because keeping a single count of correct predictions is enough.
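For comparison, a minimal sketch of the concatenation variant: every batch's logits and labels are collected, concatenated with `torch.cat`, and the single-batch formula is applied once. The `return` line of `accuracy` is an assumption about how the formula continues, and `accuracy_from_batches` is a hypothetical helper. Note that it keeps N×C logits in memory for the whole dataset, whereas a running count needs one number.

```python
import torch


def accuracy(preds: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    # preds: (N, C) logits, true: (N,) integer class labels
    return (preds.argmax(dim=-1) == true).float().mean()


def accuracy_from_batches(batches) -> float:
    """Hypothetical helper: concatenate every batch, then apply `accuracy` once."""
    all_preds, all_true = [], []
    for logits, labels in batches:
        all_preds.append(logits)
        all_true.append(labels)
    return accuracy(torch.cat(all_preds), torch.cat(all_true)).item()


# Three samples split into two batches; the third prediction (class 0) is wrong.
logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.0]])
labels = torch.tensor([0, 1, 1])
acc = accuracy_from_batches([(logits[:2], labels[:2]), (logits[2:], labels[2:])])
```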
You are welcome to comment.
I can't see any benefit in keeping a single number on the GPU; there's no heavy computation involved. So approach 1 is already good.
Thanks @KaiHoo. Keeping just the correct and total label counts on the CPU is OK, but you have to do that for every batch. On the other hand, the formula in approach 2 is one I saw here and slightly modified. It is good if you need per-batch accuracy, but per-epoch accuracy is more important. If someone could share a nice CIFAR-10 example that follows good practice for accuracy metrics, I would be happy.
For the epoch accuracy, do you just average the per-batch accuracies (with drop_last=False), or do you keep the count of correct predictions?
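A small worked example, with made-up batch counts, of why the two options can disagree when the last batch is smaller (as happens with drop_last=False): the plain mean of per-batch accuracies gives the short last batch the same weight as a full one, while the pooled count weights every sample equally.

```python
# Made-up counts for illustration: 10 full batches of 32 samples and a last batch of 4.
batch_sizes = [32] * 10 + [4]
batch_correct = [24] * 10 + [1]

# Option 1: plain average of per-batch accuracies (each batch weighs the same).
batch_accs = [c / n for c, n in zip(batch_correct, batch_sizes)]
mean_of_batch_accs = sum(batch_accs) / len(batch_accs)

# Option 2: keep running counts and divide once (each sample weighs the same).
pooled_acc = sum(batch_correct) / sum(batch_sizes)

# Weighting each batch accuracy by its batch size recovers the pooled value.
weighted_acc = sum(a * n for a, n in zip(batch_accs, batch_sizes)) / sum(batch_sizes)

print(f"{mean_of_batch_accs:.4f} {pooled_acc:.4f} {weighted_acc:.4f}")  # prints 0.7045 0.7438 0.7438
```

The two only coincide when all batches have the same size, so keeping the counts (or a size-weighted average) is the safer default for epoch accuracy.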
What about this one:
train_loss = 0.0
correct = total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    train_loss += loss.item()
    _, predicted = outputs.max(1)  # predicted class = index of the largest logit

    # only needed if per-batch accuracy is required
    total_batch = targets.size(0)
    correct_batch = predicted.eq(targets).sum().item()
    batch_acc = correct_batch / total_batch  # fraction of correct predictions in this batch

    # running counts for the epoch accuracy
    correct += predicted.eq(targets).sum().item()
    total += targets.size(0)
epoch_acc = correct / total  # accuracy over the whole epoch
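A self-contained sketch of the same loop, assuming an optimizer step (which the snippet above omits) and random toy data in place of CIFAR-10. It records per-batch accuracy and derives epoch accuracy from running counts; the mean batch loss is weighted by batch size so the epoch loss is a per-sample mean. The helper name `train_one_epoch` is hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Hypothetical helper: one training epoch with per-batch and per-epoch accuracy."""
    model.train()
    running_loss = 0.0
    correct = total = 0
    batch_accs = []
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        predicted = outputs.argmax(dim=1)
        correct_batch = (predicted == targets).sum().item()
        total_batch = targets.size(0)
        batch_accs.append(correct_batch / total_batch)  # per-batch accuracy

        # criterion returns the batch mean, so weight it by the batch size
        running_loss += loss.item() * total_batch
        correct += correct_batch
        total += total_batch
    return running_loss / total, correct / total, batch_accs


torch.manual_seed(0)
x = torch.randn(70, 5)
y = torch.randint(0, 4, (70,))
loader = DataLoader(TensorDataset(x, y), batch_size=16)  # batches of 16, 16, 16, 16, 6
model = nn.Linear(5, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
epoch_loss, epoch_acc, batch_accs = train_one_epoch(
    model, loader, nn.CrossEntropyLoss(), optimizer, torch.device("cpu")
)
```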