Any reason why the following PyTorch (3s/epoch) code is so much slower than MXNet's version (~0.6s/epoch)?

MNIST is incredibly small and not very representative of other datasets, so the things that make your MNIST training as fast as possible won't be the same things that make, say, ImageNet training as fast as possible.

Here are some recommendations:

  1. Increase the number of workers. Right now you're heavily data-loading bound.
  2. Set torch.backends.cudnn.benchmark = True. (See What does torch.backends.cudnn.benchmark do?) A sketch of both changes is below this list.
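Roughly, something like this (a minimal sketch; the batch size, `num_workers=4`, and the usual MNIST mean/std values are assumptions, so tune them for your machine):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Let cuDNN benchmark convolution algorithms for the fixed 28x28 input size.
# The first iteration is a bit slower while it autotunes.
torch.backends.cudnn.benchmark = True

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_set = datasets.MNIST(root="data", train=True, download=True,
                           transform=transform)

# num_workers > 0 moves decoding/transforms off the main process;
# pin_memory speeds up host-to-GPU copies.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True,
                          num_workers=4, pin_memory=True)
```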

Really, for MNIST, you should just put the entire dataset on the GPU to start and do your data normalization once. But that doesn’t transfer to other, larger datasets.
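Something along these lines (again just a sketch; it assumes the torchvision MNIST dataset exposes `.data` / `.targets` as it does in recent versions, and the mean/std and batch size are illustrative):

```python
import torch
from torchvision import datasets

device = torch.device("cuda")
train_set = datasets.MNIST(root="data", train=True, download=True)

# Normalize once and keep the whole dataset resident on the GPU.
images = train_set.data.to(device).float().div_(255)          # [60000, 28, 28]
images = images.sub_(0.1307).div_(0.3081).unsqueeze(1)        # [60000, 1, 28, 28]
labels = train_set.targets.to(device)

# Then skip the DataLoader entirely and slice mini-batches directly:
perm = torch.randperm(images.size(0), device=device)
for i in range(0, images.size(0), 128):
    idx = perm[i:i + 128]
    batch_x, batch_y = images[idx], labels[idx]
    # forward / backward / optimizer step here
```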