Top k error calculation

from typing import List

import torch
import torchvision
import torchvision.transforms as transforms
#from torch.autograd import Variable

# Preprocessing: resize (smaller edge to 256 px), convert to a [0, 1] tensor, then
# normalize each channel with mean 0.5 / std 0.5, which maps values to [-1, 1].
# NOTE(review): the pretrained AlexNet was trained with ImageNet mean/std
# normalization; using those statistics here may fine-tune better — confirm.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training split (downloaded to ./data on first run), shuffled each epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True,
                                          num_workers=2)

# CIFAR-10 test split, iterated in a fixed order for evaluation.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False,
                                         num_workers=2)

# Human-readable class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

import torch.nn as nn
import torch.nn.functional as F

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import torchvision.models as models

# Start from pretrained AlexNet weights and replace the last classifier layer
# with a 4096 -> 10 linear layer: one output logit per CIFAR-10 class.
# NOTE: `pretrained=True` is deprecated in newer torchvision (>= 0.13) in favour
# of the `weights=` argument; kept here for compatibility with older versions.
alexnet = models.alexnet(pretrained=True)
alexnet.classifier[6] = nn.Linear(4096, 10)
print(alexnet)
print('Model Downloaded')

class CifarAlexnet(nn.Module):
    """Module wrapper that forwards its input unchanged to a wrapped network.

    The wrapped network is registered as the ``alexnet`` submodule, so its
    parameters are exposed through this module's ``parameters()``.
    """

    def __init__(self, alexnet):
        super().__init__()
        self.alexnet = alexnet

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        network = self.alexnet
        return network(x)

# Wrap the modified AlexNet; the wrapper is what gets trained and evaluated.
model = CifarAlexnet(alexnet)
print(model)

import time
import torch.optim as optim

# NOTE(review): batch_size appears unused in the visible code — the train/test
# DataLoaders are built with batch_size=4. Pass this value to the loaders if
# 128 is the intended batch size.
batch_size = 128
epochs = 5
model.to(device)
criterion = nn.CrossEntropyLoss()
# Optimize the wrapper's parameters: they are exactly the wrapped AlexNet's
# parameters, and this keeps the optimizer tied to the module being trained.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

for epoch in range(epochs):
    # ---- training pass over the whole training set ----
    running_loss = 0.0
    total = 0
    correct_classified = 0
    start_time = time.time()
    model.train()
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        predicted = torch.argmax(outputs, 1)
        total += labels.size(0)
        correct_classified += (predicted == labels).sum().item()
        running_loss += loss.item()
        # Report the mean loss over each window of 200 mini-batches.
        if i % 200 == 199:
            avg_loss = running_loss / 200
            print('Epoch:[%d, %5d] Batch: %5d loss: %.3f' % (epoch + 1, i + 1, i + 1, avg_loss))
            running_loss = 0.0
    train_acc = 100 * correct_classified / total
    print('Train Accuracy: %.3f  Time/epoch: %.1f sec' % (train_acc, time.time() - start_time))

    # ---- evaluation on the test set: accuracy plus top-1 / top-5 error ----
    correct_1 = 0
    correct_5 = 0
    c = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in testloader:
            inputs, labels = images.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = torch.argmax(outputs, 1)
            total += labels.size(0)
            c += (predicted == labels).sum().item()

            # topk returns a (values, indices) named tuple; the predicted classes
            # are the INDICES. Passing the tuple itself to expand_as raised
            # "argument 'other' must be Tensor, not torch.return_types.topk".
            _, pred = outputs.topk(5, 1, largest=True, sorted=True)  # [B, 5]
            # Labels [B] -> [B, 1] -> [B, 5] so each top-5 guess is compared
            # with the true label; a row can contain at most one match.
            lab = labels.view(labels.size(0), -1).expand_as(pred)
            correct = pred.eq(lab)
            # sorted=True puts the best guess in column 0, so [:, :1] is top-1.
            correct_5 += correct[:, :5].sum().item()
            correct_1 += correct[:, :1].sum().item()
    test_acc = 100 * c / total

    print("c=", c, " total=", total, "correct_1=", correct_1, "correct_5=", correct_5)
    print('Accuracy of the network on test images:%.3f' % (test_acc))
    print('Top 1 error:%2.2f' % (1 - correct_1 / total))
    print('Top 5 error:%2.2f' % (1 - correct_5 / total))

print('trained')

Error: Traceback (most recent call last):
  File "AlexCIFAR10.py", line 107, in <module>
    lab=labels.view(labels.size(0),-1).expand_as(pred)
TypeError: expand_as(): argument 'other' (position 1) must be Tensor, not torch.return_types.topk

I want to calculate the top-1 and top-5 error, but I'm getting this error. How do I fix it?

@ptrblck help me please

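That line returns both the values and the indices, as a named tuple. The predicted classes are the indices, so use `_, pred = outputs.topk(5, 1, largest=True, sorted=True)` (equivalently `...[1]` or `.indices`). Index `[0]` gives the scores, which should not be compared with the labels. (I haven't looked at the rest of your code, though.)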

The answer is here: compute top1, top5 error using pytorch · GitHub

Here it is with a detailed explanation:


def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)) -> List[torch.FloatTensor]:
    """
    Computes the accuracy over the k top predictions for the specified values of k.
    In top-5 accuracy you give yourself credit for having the right answer
    if the right answer appears in your top five guesses.

    ref:
    - https://pytorch.org/docs/stable/generated/torch.topk.html
    - https://discuss.pytorch.org/t/imagenet-example-accuracy-calculation/7840
    - https://gist.github.com/weiaicunzai/2a5ae6eac6712c70bde0630f3e76b77b
    - https://discuss.pytorch.org/t/top-k-error-calculation/48815/2
    - https://stackoverflow.com/questions/59474987/how-to-get-top-k-accuracy-in-semantic-segmentation-using-pytorch

    :param output: model scores of shape [B, n_classes], e.g. logits or raw y_pred before
        normalization; only the ranking of scores within each row matters.
    :param target: ground-truth class indices of shape [B].
    :param topk: tuple of k's to compute, e.g. (1, 2, 5) computes top-1, top-2 and top-5.
        In top-2, the model gets +1 for an example if either of its top 2 predictions is the
        right label. So if the model predicts cat, dog (0, 1) and the true label is bird (3),
        it gets zero; if the label were cat or dog, it would get +1 for that example.
    :return: list of 1-element float tensors, one per entry of ``topk`` and in the same order.
        Each holds the fraction in [0, 1] of the batch that was correct within the top k
        (multiply by 100 for a percentage; top-k error is ``1 - acc``).
    Note: torch.topk raises an error if max(topk) exceeds n_classes.
    """
    with torch.no_grad():
        # Largest k requested: we only need to rank this many classes per example.
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk highest scores per example (the scores themselves are unused).
        # sorted=True is required: row j of y_pred must be the (j+1)-th best guess so that
        # correct[:k] below selects exactly the top-k guesses.
        _, y_pred = output.topk(k=maxk, dim=1, largest=True, sorted=True)  # [B, n_classes] -> [B, maxk]
        y_pred = y_pred.t()  # [B, maxk] -> [maxk, B]

        # Broadcast the labels across all maxk guesses and compare element-wise.
        # Each column (example) contains at most one True because the top-k indices are
        # distinct, so summing never counts an example twice.
        target_reshaped = target.view(1, -1).expand_as(y_pred)  # [B] -> [1, B] -> [maxk, B]
        correct = y_pred == target_reshaped  # [maxk, B] bool: which guess matched the truth

        list_topk_accs = []  # entry i is the accuracy for topk[i]
        for k in topk:
            # Count examples whose true label is among their first k guesses.
            tot_correct_topk = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True)  # [k, B] -> [kB] -> [1]
            topk_acc = tot_correct_topk / batch_size  # fraction of the batch, shape [1]
            list_topk_accs.append(topk_acc)
        return list_topk_accs

ref:
- torch.topk — PyTorch 1.8.0 documentation
- ImageNet Example Accuracy Calculation
- compute top1, top5 error using pytorch · GitHub
- Top k error calculation - #2 by Oli
- python - how to get top k accuracy in semantic segmentation using pytorch - Stack Overflow