For some reason, torch.argmax is slower for me than transferring the tensor to the CPU and then calling np.argmax. Any ideas why? Should I file a bug report?
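For anyone wanting to reproduce a comparison like this, here is a minimal sketch of one way to time the two paths. It is not the code from the original report: the helper names `time_fn` and `compare_argmax`, the tensor shape, and the warmup and repeat counts are all assumptions made for illustration. One thing worth controlling for is that CUDA kernels launch asynchronously, so the sketch calls `torch.cuda.synchronize()` before reading the clock.

```python
import time


def time_fn(fn, sync=None, warmup=3, repeats=20):
    """Return the mean wall-clock seconds per call of fn().

    `sync` is an optional callable run before starting and before stopping
    the clock, e.g. torch.cuda.synchronize, so that pending asynchronous
    GPU work is included in the measurement.
    """
    # Warm-up calls absorb one-time costs (CUDA context creation, caching).
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / repeats


def compare_argmax(shape=(1000, 1000), dim=1):
    """Time torch.argmax against a CPU transfer followed by np.argmax.

    The shape and dim defaults are arbitrary, hypothetical values.
    """
    # Imported here so that time_fn stays usable without these packages.
    import numpy as np
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(*shape, device=device)
    sync = torch.cuda.synchronize if device == "cuda" else None

    t_torch = time_fn(lambda: torch.argmax(x, dim=dim), sync)
    # This path includes the device-to-host copy on every call.
    t_numpy = time_fn(lambda: np.argmax(x.cpu().numpy(), axis=dim), sync)
    return {"device": device, "torch.argmax": t_torch, "cpu + np.argmax": t_numpy}
```

Calling `compare_argmax()` returns the mean seconds per call for each path. The numbers depend heavily on the tensor shape, the reduction dimension, the GPU, and the PyTorch version, so those details are worth including in any bug report.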
That sounds bad. Please file an issue.
Thanks, filed: https://github.com/pytorch/pytorch/issues/8817