Argmax with PyTorch

As of version 0.4.0 you can use torch.argmax(preds, dim=1) directly.
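A minimal sketch of the call, assuming preds is a 2-D tensor of scores with one row per example (the values below are made up for illustration). It also shows the older idiom, torch.max with a dim argument, which returns both the maximum values and their indices; torch.argmax returns only the indices.

```python
import torch

# Hypothetical model output: 2 examples (rows), 3 classes (columns)
preds = torch.tensor([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])

# Since 0.4.0: indices of the maximum along dim 1, one per row
labels = torch.argmax(preds, dim=1)
print(labels)  # tensor([1, 0])

# Older idiom: torch.max with a dim returns (values, indices); keep the indices
_, labels_old = torch.max(preds, dim=1)
assert torch.equal(labels, labels_old)
```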

@BlakeWest Dimension 0 is the batch and dimension 1 holds the class probabilities (assuming you apply softmax to your final output). Therefore you want the argmax along dimension 1, i.e. the index of the class with the highest probability for each example.
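To make the axis choice concrete, here is a small sketch assuming a (batch_size, num_classes) output; the logits are invented for illustration. Taking argmax along dim 1 gives one class index per example, while dim 0 would instead give, for each class, the index of the example that scores highest. Because softmax preserves the ordering within a row, the argmax is the same whether you apply it to the softmax probabilities or to the raw logits.

```python
import torch

# Hypothetical raw scores (logits): dim 0 = batch (4 examples), dim 1 = classes (3)
logits = torch.tensor([[ 2.0,  0.5, -1.0],
                       [ 0.1,  1.5,  0.3],
                       [-0.2, -0.1,  3.0],
                       [ 1.0,  0.0,  0.9]])
probs = torch.softmax(logits, dim=1)  # each row now sums to 1

pred_per_example = torch.argmax(probs, dim=1)  # shape (4,): predicted class per example
wrong_axis = torch.argmax(probs, dim=0)        # shape (3,): best-scoring example per class

# softmax is monotonic within each row, so the predicted classes are unchanged
assert torch.equal(pred_per_example, torch.argmax(logits, dim=1))
print(pred_per_example.shape, wrong_axis.shape)  # torch.Size([4]) torch.Size([3])
```

In a training or evaluation loop you can therefore skip the softmax entirely when all you need is the predicted label.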
