Top-k accuracy for CNN with batches

Is there a way to efficiently calculate top-k accuracy in PyTorch when training a CNN with batches?

Currently I use scikit-learn's accuracy_score, which takes an array of argmax values (each element is the argmax of the CNN's prediction for one image) and compares it with an array of targets (each element is the target class of one image).

Getting the prediction and target:

import numpy as np

prediction_by_CNN = model(batch)
pred_numpy = prediction_by_CNN.detach().cpu().numpy()
target_numpy = target.detach().cpu().numpy()
# argmax already returns integer class indices, so no rounding is needed
prediction = np.argmax(pred_numpy, axis=1).reshape(-1)
# int64 instead of uint8, which would overflow with more than 256 classes
target = target_numpy.astype(np.int64).reshape(-1)

Calculating the accuracy:

from sklearn.metrics import accuracy_score

accuracies.append(accuracy_score(target, prediction) * 100)
mean_batch_accuracy = sum(accuracies) / len(accuracies)
epoch_accuracy.append(mean_batch_accuracy)
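One thing worth noting about averaging per-batch percentages: every batch gets equal weight, so a smaller final batch counts as much as a full one. A minimal sketch that instead sums correct and total counts across batches; the helper `epoch_accuracy_from_batches` and its `(num_correct, batch_size)` input are hypothetical, for illustration only:

```python
def epoch_accuracy_from_batches(batch_counts):
    """Return epoch accuracy in percent from (num_correct, batch_size) pairs.

    Hypothetical helper: summing counts weights every image equally,
    unlike averaging per-batch percentages.
    """
    total_correct = sum(correct for correct, _ in batch_counts)
    total_seen = sum(size for _, size in batch_counts)
    return 100.0 * total_correct / total_seen


# Two full batches of 4 images and a final partial batch of 1 image.
counts = [(3, 4), (4, 4), (0, 1)]
print(epoch_accuracy_from_batches(counts))  # 7 of 9 images correct, about 77.8

# Averaging the per-batch percentages gives (75 + 100 + 0) / 3, about 58.3.
mean_of_batch_means = sum(100.0 * c / n for c, n in counts) / len(counts)
print(mean_of_batch_means)
```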

This gives the percentage of correct predictions, but it is only the top-1 accuracy.

Is there a way to adapt the scikit-learn method, or to write a function that efficiently calculates the top-3 accuracy?
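For reference, scikit-learn 0.24 and later includes `sklearn.metrics.top_k_accuracy_score`, which takes the score matrix from before argmax rather than argmax values. A minimal sketch with made-up scores for 4 images over 5 classes (the data are illustrative, not from a real model):

```python
import numpy as np
from sklearn.metrics import top_k_accuracy_score

# Scores for 4 images over 5 classes. Only the ranking within each row
# matters, so logits or softmax outputs both work.
scores = np.array([
    [0.1, 0.5, 0.2, 0.15, 0.05],
    [0.3, 0.1, 0.15, 0.4, 0.05],
    [0.2, 0.1, 0.5, 0.15, 0.05],
    [0.5, 0.2, 0.15, 0.1, 0.05],
])
targets = np.array([2, 0, 2, 4])

# labels= lists every class column, so the call still works when a batch's
# targets don't cover all classes (here only 0, 2 and 4 appear).
n_classes = scores.shape[1]
top3 = top_k_accuracy_score(targets, scores, k=3, labels=np.arange(n_classes))
print(top3 * 100)  # 75.0
```

Passing `normalize=False` returns the number of correct images instead of a fraction, which is convenient for accumulating over an epoch.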

My possible solution:

  1. Keep the prediction array from before the argmax() call.
  2. Use something like: prediction_numpy_top_k = np.argpartition(pred_numpy, -3, axis=1)[:, -3:]
  3. For each image, check whether its target is in that image's row of prediction_numpy_top_k: if it is, accuracies.append(1), otherwise accuracies.append(0).
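Steps 2 and 3 can be vectorised over the whole batch instead of appending 1 or 0 per image. A minimal sketch, assuming `pred_numpy` has shape (batch_size, num_classes) and `target_numpy` is a 1-D array of integer class indices; the function name `topk_accuracy_numpy` and the sample scores are hypothetical:

```python
import numpy as np


def topk_accuracy_numpy(pred_numpy, target_numpy, k=3):
    """Percentage of images whose target is among their k highest scores.

    pred_numpy:   (batch_size, num_classes) scores, taken before argmax.
    target_numpy: (batch_size,) integer class indices.
    """
    # argpartition moves the indices of the k largest scores into the last
    # k columns of each row (in no particular order) without a full sort.
    top_k = np.argpartition(pred_numpy, -k, axis=1)[:, -k:]
    # Compare each row's k indices with that row's target; a row counts as
    # correct if any of them matches.
    hits = (top_k == target_numpy[:, None]).any(axis=1)
    return hits.mean() * 100


pred = np.array([
    [0.1, 0.5, 0.2, 0.15, 0.05],
    [0.3, 0.1, 0.15, 0.4, 0.05],
    [0.2, 0.1, 0.5, 0.15, 0.05],
    [0.5, 0.2, 0.15, 0.1, 0.05],
])
target = np.array([2, 0, 2, 4])
print(topk_accuracy_numpy(pred, target, k=3))  # 75.0
print(topk_accuracy_numpy(pred, target, k=1))  # 25.0
```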

Or could this be done in PyTorch, before converting to NumPy at all?
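The same check can be done directly on tensors with `torch.Tensor.topk`, so nothing has to be moved to the CPU or converted to NumPy. A minimal sketch, assuming `output` is the model's (batch_size, num_classes) output and `target` holds int64 class indices; `topk_accuracy_torch` is a hypothetical name:

```python
import torch


def topk_accuracy_torch(output, target, k=3):
    """Top-k accuracy in percent, computed on tensors without NumPy.

    output: (batch_size, num_classes) scores or logits.
    target: (batch_size,) int64 class indices.
    """
    with torch.no_grad():
        # Indices of the k largest scores per row, shape (batch_size, k).
        _, top_k = output.topk(k, dim=1)
        # Reshape target to (batch_size, 1) so it broadcasts against top_k,
        # then mark a row as correct if any of its k indices matches.
        hits = (top_k == target.view(-1, 1)).any(dim=1)
        return hits.float().mean().item() * 100


output = torch.tensor([
    [0.1, 0.5, 0.2, 0.15, 0.05],
    [0.3, 0.1, 0.15, 0.4, 0.05],
    [0.2, 0.1, 0.5, 0.15, 0.05],
    [0.5, 0.2, 0.15, 0.1, 0.05],
])
target = torch.tensor([2, 0, 2, 4])
print(topk_accuracy_torch(output, target, k=3))  # 75.0
```

For a per-image weighted epoch figure, the function could return `hits.sum().item()` (the number of correct images) instead, to be divided by the total image count at the end of the epoch.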

Many thanks in advance!