Hello everyone,
We, the Algorithms Working Group of MLCommons, are developing a large-scale competitive benchmark for training algorithms (e.g. Adam, SGD, Lion, VeLO, etc.) with a $50,000 prize pool. We hope to support submissions in both the PyTorch and JAX frameworks. However, we’re finding it extremely challenging to get our PyTorch code to be as fast as our JAX code.
Specifically, our benchmark involves running submitted training algorithms on eight different deep learning workloads on a fixed 8xV100 system. A workload is essentially “a deep learning training problem”, consisting of a dataset, a model, a loss function, and a validation/test error goal. For three of our eight workloads (ImageNet ResNet, ImageNet ViT, and LibriSpeech Conformer), we have achieved comparable training wall-clock runtimes between JAX and PyTorch, with a difference within 5%. Another three workloads (OGBG GNN, fastMRI U-Net, and LibriSpeech DeepSpeech) have somewhat comparable runtimes (a difference below 12%), but PyTorch is still consistently slower. Most problematic, however, are the remaining two workloads (Criteo DLRMsmall and WMT Transformer), where our PyTorch implementations are 25% or more slower than our JAX implementations:
- Criteo DLRMsmall: We are currently struggling with an OOM issue in our PyTorch 2 implementation that is not present in our JAX version. With PyTorch 1.13, the PyTorch implementation was >60% slower than JAX.
- fastMRI U-Net: 10% slower than JAX (10% slower with PyTorch 1.13)
- LibriSpeech DeepSpeech: 12% slower than JAX (17% slower with PyTorch 1.13)
  - Timing was done without torch.compile due to a torchdynamo error (see the fallback sketch after this list).
- OGBG GNN: 10% slower than JAX (30% slower with PyTorch 1.13)
  - Timing was done without torch.compile due to a torchdynamo error (from using GraphsTuple as inputs).
- WMT Transformer: 25% slower than JAX (21% slower with PyTorch 1.13)
  - Timing was done without torch.compile due to a torchdynamo error (on boolean masking).
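For context, here is a minimal sketch of the kind of fallback we mean (the toy model is purely illustrative, not one of our workloads): setting torch._dynamo.config.suppress_errors makes torchdynamo fall back to eager execution on the frames it cannot trace, instead of raising, so a single timing harness can still run every workload. For the numbers above, the affected workloads were simply timed fully in eager mode.

```python
import torch
import torch._dynamo

# Purely illustrative stand-in for a workload model.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

# Fall back to eager execution wherever torchdynamo fails to trace
# (e.g. GraphsTuple inputs, boolean masking) instead of raising.
torch._dynamo.config.suppress_errors = True

compiled = torch.compile(model)
out = compiled(torch.randn(8, 16))  # compiled where possible, eager otherwise
```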
All timings use PyTorch 2 (PyTorch 1.13 results are shown in parentheses) and were run on the same 8xV100 machine, using the AdamW optimizer with identical hyperparameters and CUDA version 11.8. The code for the individual workloads can be found here in our codebase.
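To illustrate what we mean by identical hyperparameters (the values below are placeholders, not the settings used for these timings): both torch.optim.AdamW and optax.adamw implement decoupled weight decay, so the two optimizers can be configured one-to-one.

```python
import torch
import optax

# Placeholder hyperparameter values; not the tuned benchmark settings.
lr, b1, b2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 1e-2

# Hypothetical model, only so the PyTorch optimizer has parameters.
model = torch.nn.Linear(16, 16)

# PyTorch: AdamW with decoupled weight decay.
opt_torch = torch.optim.AdamW(model.parameters(), lr=lr, betas=(b1, b2),
                              eps=eps, weight_decay=wd)

# JAX: optax.adamw is the matching decoupled-weight-decay variant.
opt_jax = optax.adamw(learning_rate=lr, b1=b1, b2=b2, eps=eps,
                      weight_decay=wd)
```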
Another requirement is that our workload implementations produce exactly the same results in both frameworks. We have therefore tried to share data loading and preprocessing code as much as possible. Still, some challenges remain, because it is more idiomatic for our JAX implementations to run in a single Python process, whereas our PyTorch implementations use DistributedDataParallel training with one Python process for each of the 8 GPUs our benchmark requires.
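To make that difference concrete, here is a bare-bones sketch of the PyTorch side (the model is a placeholder); on the JAX side, by contrast, a single process drives all eight devices, e.g. via pmap.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# One Python process per GPU, launched e.g. with:
#   torchrun --nproc_per_node=8 train.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Placeholder model; each of the 8 processes holds its own replica.
model = torch.nn.Linear(16, 16).cuda()
model = DDP(model, device_ids=[local_rank])
```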
We’re planning to issue a Call for Submissions for our competition around September 1st. With the current speed differences between PyTorch and JAX, however, PyTorch submissions will not be competitive: the slowdown significantly reduces their chances at the $50,000 prize pool, so people will most likely not submit in PyTorch.
We could really use any help the PyTorch community can offer, since we are a small team with only a handful of PyTorch users doing the engineering work. We are looking for suggestions on how to speed up our PyTorch implementations and make them competitive with our JAX versions (even better if those suggestions are feasible to implement before we issue the call for submissions). We would also welcome more extensive engineering contributions, or new members for our working group if anyone would like to join.
Thank you.