As of version 0.4.0 you can do `torch.argmax(preds, dim=1)` directly.
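A minimal sketch, assuming `preds` is a 2-D tensor of shape (batch, classes); the tensor values below are made up for illustration. It compares `torch.argmax` with the alternative of calling `torch.max` with a `dim`, which returns a `(values, indices)` pair, so the indices sit in the second element:

```python
import torch

# Hypothetical scores for a batch of 2 samples over 3 classes.
preds = torch.tensor([[0.1, 2.0, -1.0],
                      [3.0, 0.5, 0.2]])

# torch.argmax returns only the indices of the maxima along dim.
idx = torch.argmax(preds, dim=1)

# Alternative: torch.max with a dim returns both the max values and their indices.
values, idx_max = torch.max(preds, dim=1)

print(idx)                        # tensor([1, 0])
print(torch.equal(idx, idx_max))  # True
```

`torch.argmax` is simply the more direct spelling when you only need the indices.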
@BlakeWest Dimension 0 is the batch and dimension 1 holds the class probabilities (assuming you apply softmax to your final output). Therefore you want to take the argmax along dimension 1, i.e. pick the class with the highest probability for each sample.
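To see why `dim=1` is the right axis, here is a small sketch with a hypothetical batch of 3 samples over 4 classes (values invented for illustration). Reducing over dim 1 yields one class index per sample. Because softmax is monotonic within each row, the argmax of the raw logits matches the argmax of the probabilities, so the softmax step is optional for prediction. Reducing over dim 0 instead compares samples against each other:

```python
import torch

# Hypothetical raw scores (logits): 3 samples x 4 classes.
# Dimension 0 indexes the batch, dimension 1 indexes the classes.
logits = torch.tensor([
    [1.0, 3.0, 0.5, -2.0],
    [0.0, -1.0, 4.0, 2.0],
    [5.0, 1.0, 1.0, 0.0],
])

probs = torch.softmax(logits, dim=1)  # normalize each row over the classes

# argmax over dim=1 picks the highest-scoring class for each sample.
pred_classes = torch.argmax(probs, dim=1)
print(pred_classes)  # tensor([1, 2, 0])

# softmax preserves the ordering within a row, so the logits give the same answer.
print(torch.equal(pred_classes, torch.argmax(logits, dim=1)))  # True

# argmax over dim=0 compares samples per class instead,
# giving one index per class (shape [4]) rather than one per sample.
print(torch.argmax(logits, dim=0).shape)  # torch.Size([4])
```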