I spent a bit of time trying to understand this function because of a wrong assumption, so here is my line-by-line explanation for future reference:
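This function calculates top-k accuracy (precision@k): a sample counts as correct if its true class is among the k classes with the highest scores. For example, if you have 10 classes in your classification and k=6, then your classifier has 50% precision@6 if the true class is among the 6 highest-scoring classes for half of the samples in the batch.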
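To make the metric concrete, here is a minimal pure-Python sketch of top-k accuracy on a toy batch. The helper name `topk_accuracy` and the scores are made up for illustration; it is not the function discussed in this post, just the same idea without tensors.

```python
def topk_accuracy(scores, targets, k):
    """Percentage of samples whose true class is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        # class indices sorted by score, highest first
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(targets)


# hypothetical scores: 4 samples, 3 classes
scores = [
    [0.1, 0.7, 0.2],  # ranking: 1, 2, 0
    [0.5, 0.3, 0.2],  # ranking: 0, 1, 2
    [0.2, 0.3, 0.5],  # ranking: 2, 1, 0
    [0.6, 0.1, 0.3],  # ranking: 0, 2, 1
]
targets = [1, 1, 0, 2]

top1 = topk_accuracy(scores, targets, k=1)  # only sample 0 is a hit
top2 = topk_accuracy(scores, targets, k=2)  # samples 0, 1 and 3 are hits
print(top1, top2)  # 25.0 75.0
```

Note that the result is a percentage over samples, not over classes: increasing k can only keep the accuracy the same or raise it.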
import torch

# INPUTS: output has shape [batch_size, category_count]
# and target has shape [batch_size] (there is only one true class for each sample)
# topk is a tuple of the k values to compute the accuracy for
# topk has to be a tuple, so if you are passing a single number, do not forget the comma
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    # we do not need gradient calculation for any of this
    with torch.no_grad():
        # we take the biggest k, so a single topk call covers every requested k
        maxk = max(topk)
        batch_size = target.size(0)
        # topk returns the maxk biggest values along dimension dim of output
        # output is [batch_size, category_count] and dim=1, so we select the highest category scores for each sample
        # k=maxk, so we select maxk classes per sample
        # so the result is [batch_size, maxk]
        # topk returns a tuple (values, indices), sorted from highest to lowest score
        # we only need the indices (pred)
        # (the keyword is k, not input; passing input=maxk raises a TypeError)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        # then we transpose pred to be in shape of [maxk, batch_size]
        pred = pred.t()
        # we reshape target and then expand it to match pred
        # target [batch_size] becomes [1, batch_size]
        # target [1, batch_size] expands to [maxk, batch_size] by repeating the same correct class answer maxk times
        # when you compare pred (indices) with the expanded target, you get the 'correct' matrix in the shape of
        # [maxk, batch_size], filled with 1 and 0 for correct and wrong class assignments
        # each column has at most one 1, because a class appears only once in a sample's top maxk
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        # correct=([[0, 0, 1, ..., 0, 0, 0],
        #           [1, 0, 0, ..., 0, 0, 0],
        #           [0, 0, 0, ..., 1, 0, 0],
        #           [0, 0, 0, ..., 0, 0, 0],
        #           [0, 1, 0, ..., 0, 0, 0]], device='cuda:0', dtype=torch.uint8)
        # (newer PyTorch versions return dtype=torch.bool here)
        res = []
        # then, for each k, we sum the 1s in the first k rows of the correct matrix;
        # that sum is the number of samples whose true class is among the top k
        for k in topk:
            # reshape instead of view: correct[:k] can be non-contiguous in recent PyTorch versions, where view(-1) fails
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            # convert the count into a percentage of the batch
            res.append(correct_k.mul_(100.0 / batch_size))
        # one single-element tensor per requested k, in the same order as topk
        return res
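As a sanity check, the same steps can be traced on a tiny batch. This is only a sketch with made-up scores (3 samples, 4 classes); it inlines the steps instead of calling the function, and the variable names simply mirror the ones used in the explanation.

```python
import torch

# hypothetical scores: 3 samples, 4 classes (no ties, to keep the ranking unambiguous)
output = torch.tensor([
    [0.10, 0.60, 0.20, 0.05],  # top-2 classes: 1, 2
    [0.50, 0.10, 0.30, 0.05],  # top-2 classes: 0, 2
    [0.30, 0.15, 0.05, 0.50],  # top-2 classes: 3, 0
])
target = torch.tensor([1, 2, 1])
batch_size = target.size(0)
maxk = 2

_, pred = output.topk(maxk, dim=1, largest=True, sorted=True)  # [3, 2]
pred = pred.t()                                                # [2, 3]
correct = pred.eq(target.view(1, -1).expand_as(pred))          # [2, 3]

# row 0 of correct holds the top-1 hits, rows 0..1 together hold the top-2 hits
top1 = correct[:1].reshape(-1).float().sum().item() * 100.0 / batch_size
top2 = correct[:2].reshape(-1).float().sum().item() * 100.0 / batch_size
print(pred.tolist())  # [[1, 0, 3], [2, 2, 0]]
print(top1, top2)     # roughly 33.3 and 66.7
```

Sample 0 is a top-1 hit, sample 1 only becomes a hit at k=2, and sample 2 is missed for both values of k, which is why the two results are one third and two thirds of the batch.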
Correct me if you see any mistake.