Question about use of torch.max function to calculate accuracy

The CNN sample code in the tutorial “Writing CNNs from Scratch in PyTorch”
has a snippet where they calculate the accuracy of the model:

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    print('Accuracy of the network on the {} train images: {} %'.format(50000, 100 * correct / total))

I don’t understand the ‘_, predicted = torch.max(outputs.data, 1)’ line.
I know outputs.data has the output data for the batch, but why is it
necessary to call torch.max? Also, what does the second argument, “1”, in torch.max do?

Thanks!

Chris

Hi @seberino,

The docs for torch.max can be found here: torch.max — PyTorch 2.0 documentation

In short, quoting the docs, torch.max:

“Returns a namedtuple (values, indices) where values is the maximum value of each row of the input tensor in the given dimension dim. And indices is the index location of each maximum value found (argmax).”

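As you’re probably doing a classification problem (judging from your code snippet), you want the class id (the index of the maximum output) of your model’s prediction rather than the maximum output value itself. The second argument, 1, is dim: assuming outputs has shape (batch, classes), the max is taken over dimension 1 (the classes) separately for each sample in the batch. The values are not needed for accuracy, so they are assigned to _ and discarded.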
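
To make the two return values and the role of the second argument concrete, here is a minimal sketch with made-up scores for a batch of 3 images and 4 classes. The numbers, the labels, and the variable names are illustrative assumptions, not taken from the tutorial.

```python
import torch

# Hypothetical model outputs: 3 images (rows) x 4 classes (columns).
outputs = torch.tensor([
    [0.1, 2.5, 0.3, -1.0],
    [1.7, 0.2, 0.4, 0.9],
    [-0.5, 0.0, 0.8, 3.2],
])
labels = torch.tensor([1, 0, 2])  # made-up ground-truth class ids

# dim=1 reduces over the class dimension: one (value, index) pair per image.
values, predicted = torch.max(outputs, 1)
print(values)     # largest score in each row
print(predicted)  # column index of that score, i.e. the predicted class id

# The accuracy code only needs the indices, which is why the values go to `_`.
correct = (predicted == labels).sum().item()
total = labels.size(0)
print(correct, total)
```

If the values are never used, torch.argmax(outputs, dim=1) returns the same indices directly.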

Thanks!
I’m guessing PyTorch CNN classification models do not just return the single most
likely class, but rather a probability for each class?

And that is why the user must use torch.max to drill down to the model’s one best guess?

Chris
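
A small sketch of the point raised in this follow-up. It assumes the network’s last layer emits one raw score (logit) per class rather than normalized probabilities, which is common for classifiers trained with nn.CrossEntropyLoss but is not confirmed here for the tutorial’s model. Softmax turns the scores into probabilities without changing which class is largest, so taking the max index of the raw outputs gives the same prediction. The tensor values are made up.

```python
import torch

# Hypothetical raw scores (logits) for 2 images and 3 classes.
logits = torch.tensor([
    [2.0, -1.0, 0.5],
    [0.1, 0.2, 3.0],
])

# Softmax over the class dimension gives probabilities that sum to 1 per row.
probs = torch.softmax(logits, dim=1)
print(probs.sum(dim=1))

# The winning class is the same before and after softmax.
pred_from_logits = logits.argmax(dim=1)
pred_from_probs = probs.argmax(dim=1)
print(pred_from_logits, pred_from_probs)
```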