These two lines of code come from the test loop of a CNN model. I know of alternative ways to get the predictions and count the correct ones, but I have struggled to make sense of the two lines below:
```python
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).sum()
```
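So at a high level, your code first gets the predicted class indices `pred`. `output.data.max(1, keepdim=True)` returns both the maximum score along dimension 1 (the class dimension) and the index where it occurs, and `[1]` keeps only the indices, so `pred` holds the predicted class for each sample with shape `(N, 1)`. It then compares `pred` element-wise with the values in `target`, which `view_as(pred)` reshapes from `(N,)` to `(N, 1)` so the two line up, producing `True` where the elements match and `False` where they don't. When you then call `sum()`, each `True` counts as 1 and each `False` as 0, so the result is the number of correct predictions in the batch. Note that `sum()` returns a 0-dimensional tensor, so `correct` ends up as a tensor; call `.item()` on it if you want a plain Python number.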
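To make the shapes concrete, here is a small sketch with a made-up batch. The tensors `output` and `target` are hypothetical stand-ins for a model's scores and the true labels, and `.data` is dropped because these tensors don't track gradients:

```python
import torch

# Hypothetical batch: 4 samples, 3 classes (raw scores from a model).
output = torch.tensor([
    [0.1, 2.0, 0.3],  # highest score at index 1
    [1.5, 0.2, 0.1],  # highest score at index 0
    [0.0, 0.1, 3.0],  # highest score at index 2
    [0.4, 0.9, 0.2],  # highest score at index 1
])
target = torch.tensor([1, 0, 1, 1])  # true class labels, shape (4,)

# max along dim 1 returns (values, indices); [1] keeps the indices,
# i.e. the predicted class per sample. keepdim=True keeps shape (4, 1).
pred = output.max(1, keepdim=True)[1]
print(pred.shape)       # torch.Size([4, 1])
print(pred.squeeze(1))  # tensor([1, 0, 2, 1])

# view_as reshapes target from (4,) to (4, 1) so eq compares row by row.
matches = pred.eq(target.view_as(pred))
print(matches.squeeze(1).tolist())  # [True, True, False, True]

# Summing a bool tensor counts the True entries.
correct = matches.sum()
print(correct)         # tensor(3)
print(correct.item())  # 3

# An equivalent formulation using argmax.
correct_alt = (output.argmax(dim=1) == target).sum().item()
print(correct_alt)     # 3
```

The last line shows a common alternative: `argmax(dim=1)` returns the predicted indices directly with shape `(N,)`, which already matches `target`, so no reshaping is needed.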