The CNN sample code in "Writing CNNs from Scratch in PyTorch" has a snippet that calculates the accuracy of the model:
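```python
import torch

# model, train_loader and device are defined elsewhere in the sample code
with torch.no_grad():  # no gradients are needed just to measure accuracy
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)  # number of samples in this batch
        correct += (predicted == labels).sum().item()

# report the actual number of samples seen instead of a hardcoded 50000
print('Accuracy of the network on the {} train images: {} %'.format(total, 100 * correct / total))
```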
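For reference, here is a minimal, self-contained sketch of the same accuracy loop, wrapped in a hypothetical `evaluate_accuracy` helper and run on toy data. The one-hot inputs and the `nn.Identity` "model" are stand-ins chosen only so the result is predictable; they are not part of the tutorial. The sketch also calls `model.eval()`, which the quoted snippet does not show but which is usually wanted before evaluation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model, loader, device):
    """Hypothetical helper: percentage of samples whose top-scoring class matches the label."""
    model.eval()  # put layers such as dropout/batchnorm into inference mode
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only, so skip gradient tracking
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # index of the highest score along the class dimension = predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


# Toy data (assumption): each "image" is a one-hot vector and nn.Identity passes it
# through unchanged, so the predicted class is just the position of the 1.
inputs = torch.eye(4)                # 4 samples, 4 classes
labels = torch.tensor([0, 1, 2, 0])  # the last label deliberately disagrees
loader = DataLoader(TensorDataset(inputs, labels), batch_size=2)
acc = evaluate_accuracy(nn.Identity(), loader, torch.device("cpu"))
print(f"Accuracy: {acc} %")  # Accuracy: 75.0 %
```

Three of the four predictions match their labels, so the helper reports 75.0 %.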
I don't understand the line `_, predicted = torch.max(outputs.data, 1)`.
I know `outputs.data` holds the model's outputs for the batch, but why is it
necessary to call `torch.max`? Also, what does the second argument, `1`, do in `torch.max`?
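A small sketch of what that line does, using made-up scores for a batch of 2 images and 3 classes (the `logits` values are hypothetical). When `torch.max` is given a dimension, it returns a pair `(values, indices)`: the largest entry along that dimension and its position. With `dim=1` the reduction runs across the class scores of each row, so `indices` holds one predicted class per image; the `_` simply discards the maximum scores, which the accuracy count does not need.

```python
import torch

# Hypothetical model outputs: shape [batch_size=2, num_classes=3]
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.3, 0.2]])

# dim=1: take the max across the class dimension, one result per image (row)
values, indices = torch.max(logits, 1)
print(values.tolist())   # [2.0, 1.5]  highest score in each row
print(indices.tolist())  # [1, 0]      column of that score = predicted class

# The indices are exactly what torch.argmax returns along the same dimension
assert torch.equal(indices, torch.argmax(logits, dim=1))

# dim=0 would instead reduce down each column (per class, across the batch),
# which is not what a per-image prediction needs
col_values, col_indices = torch.max(logits, 0)
print(col_indices.tolist())  # [1, 0, 1]
```

Because the highest score marks the class the network considers most likely, comparing `predicted` with `labels` counts how many images were classified correctly.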
Thanks!
Chris