Struggling to get PyTorch fast enough to use in public competition

Hello everyone,

We, the Algorithms Working Group of MLCommons, are developing a large-scale competitive benchmark for training algorithms (e.g. Adam, SGD, Lion, VeLO, etc.) with a $50,000 prize pool. We hope to support submissions in both the PyTorch and JAX frameworks. However, we’re finding it extremely challenging to get our PyTorch code to be as fast as our JAX code.

Specifically, our benchmark involves running submitted training algorithms on eight different deep learning workloads on a fixed 8xV100 system. A workload is essentially "a deep learning training problem", consisting of a dataset, a model, a loss function, and a validation/test error goal. For 3 of our 8 workloads (ImageNet ResNet, ImageNet ViT, and LibriSpeech Conformer), we have managed to achieve comparable training wall-clock runtimes between JAX and PyTorch (the runtime difference is within 5%). Another three workloads (OGBG GNN, fastMRI U-Net, and LibriSpeech DeepSpeech) have somewhat comparable training wall-clock runtimes (a runtime difference of <12%), but we still see PyTorch consistently slower. Most problematic, however, are the remaining two workloads (Criteo DLRMsmall and WMT Transformer), where our PyTorch implementations are significantly slower than our JAX implementations (>25%):

  • Criteo DLRMsmall: We are currently struggling with an OOM issue in our PyTorch 2 implementation that is not present in our JAX version. With PyTorch 1.13, PyTorch was >60% slower than JAX.
  • fastMRI U-Net: 10% slower than JAX (10% slower with PyTorch 1.13)
  • LibriSpeech DeepSpeech: 12% slower than JAX (17% slower with PyTorch 1.13)
    • Timing was done without torch.compile due to a torchdynamo error.
  • OGBG GNN: 10% slower than JAX (30% slower with PyTorch 1.13)
    • Timing was done without torch.compile due to a torchdynamo error (when using GraphsTuple as inputs).
  • WMT Transformer: 25% slower than JAX (21% slower with PyTorch 1.13)
    • Timing was done without torch.compile due to a torchdynamo error (on boolean masking).

All timings use PyTorch 2 (timing results for PyTorch 1.13 are shown in parentheses) and were run on the same 8xV100 machine using the AdamW optimizer with identical hyperparameters and CUDA version 11.8. The code for the individual workloads can be found in our codebase.

Another requirement of our workload implementations is that they produce exactly the same results between the two frameworks. Therefore, we have tried to use the same data loading and preprocessing code as much as possible. Still, there are some challenges because it is more idiomatic for our JAX implementations to run only a single Python process. In contrast, our PyTorch implementations use DistributedDataParallel training with a Python process for each of the 8 GPUs that our benchmark requires.
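To make the setup concrete, the per-process DDP pattern described above can be sketched as follows. This is a minimal illustration, not our actual benchmark code: the model, data, and `setup_and_step` helper are placeholders, and the defaults let the sketch run as a single CPU process instead of the 8-process `torchrun` launch used on the benchmark machine.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_and_step():
    # Under torchrun, RANK/WORLD_SIZE/MASTER_ADDR are set per process;
    # the defaults below let this sketch run standalone as one process.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group(backend="gloo")  # "nccl" on the 8xV100 box

    model = torch.nn.Linear(8, 2)  # placeholder model
    ddp_model = DDP(model)         # pass device_ids=[local_rank] on GPU

    opt = torch.optim.AdamW(ddp_model.parameters(), lr=1e-3)
    x, y = torch.randn(4, 8), torch.randn(4, 2)
    loss = torch.nn.functional.mse_loss(ddp_model(x), y)
    loss.backward()  # DDP all-reduces gradients across processes here
    opt.step()
    dist.destroy_process_group()
    return loss.item()
```

With one process per GPU, each rank runs this same script; only the environment variables differ per process.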

We’re planning to issue a Call for Submissions for our competition around September 1st. However, with the current speed differences between PyTorch and JAX, submissions in PyTorch will not be competitive: the slowdown will significantly reduce their chances at the $50,000 prize pool, so people will most likely not submit in PyTorch.

We could really use any help possible from the PyTorch community since we are a small team with only a handful of PyTorch users performing engineering work. We are looking for suggestions on how to speed up our PyTorch implementations and make them competitive with our JAX versions (even better if these suggestions are feasible to implement before we issue the call for submissions). We can also use more extensive engineering contributions or new additions to our working group if anyone wants to join.

Thank you.


Hi Frank, thanks for the detailed outline. We will try to profile and dig into whether you are using PyTorch optimally. Might send some PRs your way.

This is not a good answer for the competition itself, but the very simplest thing you should try first is using a PyTorch nightly instead of the PT 2.0 public release. There have been a lot of improvements that haven’t made it out to a public release yet.

Hi Frank, have you tried to run PyTorch models with torch.compile()? Or the comparison is between JAX and pure PyTorch eager mode?

Oh, I saw the call for torch.compile() in

That is amazing! Thanks a lot for digging into it.
Please let us know if we can help you in any way, providing you with more details or logging results.

Yes, we are using torch.compile() whenever we can. As mentioned above, some of the models are currently incompatible with torch.compile(), e.g. due to our use of boolean masking or GraphsTuple as inputs.
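Our usage pattern is roughly the following (a hedged sketch, not our actual workload code: the model is a placeholder, and the try/except fallback is just an illustration of "compile where possible, eager otherwise"):

```python
import torch

# Placeholder model; the real workloads are ResNets, Transformers, etc.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

try:
    step = torch.compile(model)     # compilation is lazy...
    out = step(torch.randn(2, 16))  # ...and actually happens on first call
except Exception:
    step = model                    # fall back to eager mode
    out = step(torch.randn(2, 16))
```

For workloads where dynamo errors out (boolean masking, GraphsTuple inputs), the fallback branch is effectively what we run today.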

We have been running into a number of issues downloading and preparing the data sets. @lessw2020 created a PR for fastMRI => Fastmri dataset setup fixes by lessw2020 · Pull Request #473 · mlcommons/algorithmic-efficiency · GitHub. For speech and wmt, we see sentencepiece segfaulting during the data prep setup step, as per this issue => Data Downloading segfaulting at sentencepiece · Issue #470 · mlcommons/algorithmic-efficiency · GitHub - consequently, we have not been able to repro the torch.compile issue you flagged for wmt.

Beyond the problems at hand, I’m a bit concerned about the inefficiencies that TF data pipelines may introduce for PyTorch, as per your description of the hack you used to enable them for PyTorch. Is this putting PyTorch at an unfair disadvantage? Is there a more framework-neutral data loading ecosystem, such as Ray Data or Databricks, that might be used for both frameworks, rather than one developed for a particular framework flavor?

It would be helpful to know the absolute wall times for the benchmarks you ran, so we can cross-reference our measurements with yours. (I didn’t have a V100 ready to go, but on an A100x8 system I was able to torch.compile and get a score of 7703.403180360794 in 70 evals; no idea if this is good or not (still building TF haha). EDIT: Hmm, this is in the paper, I guess I need to read it more carefully. EDIT 2: I can’t compare with the paper because I’m on A100x8, but I also ran the JAX baseline and got 7703.041719198227.)

I have a question: if I’m comparing the PyTorch baseline vs the JAX baseline, do I expect the number of evals to be the same? (In my case, criteo1tb, they’re not: PyTorch does 70 evals, JAX does 92, but the score ends up being really close in the end anyway?)

Not sure if it is expected for criteo1tb, but it is definitely possible. The score only takes the timing of the training into account and the eval frequency is based on time and not steps, so if the evals are slower in one framework than the other, it is possible that the number of evals differs while the final score is similar.
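The time-based eval logic can be sketched as follows (a hypothetical illustration, not the actual benchmark harness; `run`, `eval_every_s`, and `budget_s` are made-up names):

```python
import time

def run(train_step, evaluate, eval_every_s=5.0, budget_s=30.0):
    """Evals fire on a wall-clock schedule, so a framework whose evals
    (or steps) are slower simply fits fewer evals into the same budget,
    while the score itself only counts the training time."""
    start = time.monotonic()
    last_eval = start
    n_evals = 0
    while time.monotonic() - start < budget_s:
        train_step()
        if time.monotonic() - last_eval >= eval_every_s:
            evaluate()
            n_evals += 1
            last_eval = time.monotonic()
    return n_evals
```

Under this scheme, 70 evals in PyTorch vs 92 in JAX with near-identical scores is plausible: the eval counts diverge while the trained-to-target time does not.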


We could consider using a nightly PyTorch version for the competition, as long as we are able to pin to a specific version so that all results are reproducible.
Thanks for the suggestion, we will investigate how the timings change with the nightly version.

Hi everyone, just wanted to provide a macro update on the state of things so far and (not so) briefly give a mental model for torch.compile() that’ll help in squeezing out as much performance as possible.

How to fix torch.compile bugs

So you tried to torch.compile() something, but it broke. What now?

  1. Try the nightlies instead; the project is moving fast. I was using torch==2.1.0.dev20230814+cu118.
  2. Look at the error message and open up an issue on GitHub - pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration
  3. While you’re waiting, look at the offending line: can you rewrite it to sidestep the issue?
  4. If your error only shows up with DDP, then follow the instructions here: [Inductor] Run compiled model failed on 2023_08_17 nightly · Issue #107362 · pytorch/pytorch · GitHub
  5. Try out backend="aot_eager"; it’ll help isolate Inductor-specific bugs, but to get perf you essentially need to get Inductor working, which brings us to:
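Step 5 can be sketched as follows (toy function standing in for a model forward; the point is the backend swap, not the math):

```python
import torch

def f(x):
    return torch.sin(x) * 2 + 1  # stand-in for a model forward

x = torch.randn(8)

# aot_eager traces and partitions the graph but runs it with plain eager
# kernels, skipping Inductor codegen entirely. If this works while
# backend="inductor" fails, the bug is Inductor-specific.
f_aot = torch.compile(f, backend="aot_eager")
y = f_aot(x)
```

Once `aot_eager` passes, retry the default `backend="inductor"` to localize where the failure lives.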

Ok, so now your model compiles, but it’s slow. How come?

  1. Are you using the default backend="inductor"?
  2. Does your model have graph breaks, i.e. does it run with fullgraph=True? torch.compile() is an API that progressively unloads complexity onto users, but as far as the competition goes, getting to 0 graph breaks will be necessary.
  3. Do compilation times matter? If not then try mode="max-autotune"

I wish I had more performance numbers to actually share but unless we at the very least get inductor working those numbers will not be particularly meaningful.

Sometimes graph breaks make sense (for example, you notice a print statement, and it’s your responsibility to remove it to get good perf), and other times there are real bugs. Two noticeable ones I’d like to highlight are:

  1. Optimizer support which I’ve tagged Jane on torch.compile default settings · Issue #487 · mlcommons/algorithmic-efficiency · GitHub which popped up in the WMT model
  2. Incorrect tracing which popped up in the OGBG model Dynamo not handling a NamedTuple · Issue #107040 · pytorch/pytorch · GitHub

In both of those cases the fastest way to solve the problem is to open a GitHub issue while you try out simple one-liners. I will also attend your meetings regularly to help unblock or triage open issues if there are a lot of them.

Minimal repros

So when you open an issue, ideally we want a minimal reproduction, and one challenge many of us faced was getting the preprocessing pipelines to work. I hope you’ll set up some simple smoke tests (Add dataset setup tests · Issue #486 · mlcommons/algorithmic-efficiency · GitHub) since that was the main reason I couldn’t get even more people involved in this effort.

  1. The sentencepiece segfault affected many models but was fixed: Fix segfault in dataset_setup by msaroufim · Pull Request #477 · mlcommons/algorithmic-efficiency · GitHub
  2. Less fixed the fastmri dataset pipeline Fastmri dataset setup fixes by lessw2020 · Pull Request #473 · mlcommons/algorithmic-efficiency · GitHub
  3. Ed found that the criteo pipeline also failed criteo preproc doesn't match input_pipeline · Issue #479 · mlcommons/algorithmic-efficiency · GitHub and could benefit from having more options to run things locally for quick spot checks RuntimeError: shape '[213568]' is invalid for input of size 262144 on criteo1tb · Issue #475 · mlcommons/algorithmic-efficiency · GitHub
  4. I could never get the librispeech pipeline to work, and even then it would only fail after 1-2h. I’d like more ways to debug things quickly and locally; I felt like I couldn’t get clever, I had to run the benchmark as is or it would probably fail: RuntimeError: shape '[213568]' is invalid for input of size 262144 on criteo1tb · Issue #475 · mlcommons/algorithmic-efficiency · GitHub and OGBG OOM'in in eager mode pytorch · Issue #471 · mlcommons/algorithmic-efficiency · GitHub
  5. Getting an E2E perf number seems like it takes about 24h (at least for criteo), I wish we could have some way of spot checking expected performance as we iterate on the models

One thing we did learn from you is that we should have more end-to-end data loading examples in our CI as well to catch more errors. In the coming weeks I’ll be looking at what we can upstream from you, both in terms of specific models so they don’t regress, and also your methodology, which others internally have found appealing.

Ok, that was not so brief. If I were to spend more time on this, here’s how I would track the work per model: make sure that torch.compile() works, then make sure that I’ve removed graph breaks and added mode="max-autotune". This is the strategy we’re taking for models that people care about, and it’s a bit time-consuming because PyTorch will allow you to write non-performant code instead of erroring out.

| Task | Owner | Preprocessing pipeline | Made torch.compile work | Removed all graph breaks + cudagraph | Perf relative to JAX |
| --- | --- | --- | --- | --- | --- |
| WMT | Mark/Michael | Fixed | Yes | | |
| Librispeech | Mark | Not fixed | No | | |
| OGBG | Mark | Fixed | NamedTuple dynamo bug | | |
| Criteo | Colin/Ed | Fixed | Yes | | Matches |
| Fast-MRI | Less | Fixed | Yes | | |

Per Model updates

None of the models were tested with cuda graphs or fullgraph yet, so this will make us look worse vs JAX, where things have to be fullgraph.


@marksaroufim thank you for the comprehensive update!

On a high level if I understand correctly:

  • OGBG: not fully torch compiled yet so we may see further speed improvements.
  • Criteo: further speed ups possible by refactoring the embedding and not using DDP over embedding.
  • FastMRI: already torch compilable but not checked with fullgraph so we may or may not see further speed improvements.
  • Librispeech deepspeech: works with eager and nightly.
  • WMT transformer: full compilation blocked on an AOT autograd issue, but fixed in nightlies.

The workloads at highest risk from our perspective are Criteo, which OOMs, and WMT, which is 22% slower than JAX with the current checked-in fixes. Regarding further action items:

  • Criteo1tb: It sounds like your team was blocked on running the code due to the data setup issues, so didn’t get to run the workload, correct? I’ll work with @janeyx99 to set them up with a VM on our GCP project to unblock debugging of the OOM issue. Once that is fixed we can work on the embedding refactoring.
  • WMT: Possibly fully torch compilable with nightly, is that correct? In that case I will re-benchmark the timing to get the latest update.
  • OGBG: full torch compile blocked on NamedTuple tracing issue.
  • Other workloads: (AI for us) rerun with torch nightly and fullgraph to get up-to-date timing comparisons.

Let me know if that sounds right or if you have any other suggestions.

  • Criteo: Yes, if you can unblock Jane with a machine, that will go a long way
  • WMT: Yes, in fact I did try compiling a similar HF model and that worked just fine, but most importantly here, try to make sure your implementation compiles, compiles with inductor, and then finally has no graph breaks
  • OGBG: Yes, if you like, comment on Dynamo not handling a NamedTuple · Issue #107040 · pytorch/pytorch · GitHub to +1 it. If you can also change the code to not use a NamedTuple there, it might unblock you
  • Other workloads: use fullgraph, change the mode to either reduce-overhead or max-autotune, and most importantly make sure inductor is being used

Is there an index url we can use to pip install the dev package versions?

Having some trouble with setuptools installing nightly.

You can do something like this: pip3 install torch==2.1.0.dev20230814+cu118 --index-url

This is the index url

And the selector on the main webpage can show you the right way to install nightlies for different architectures or cuda versions here


For people following: @Priya_Kasimbeg invited me to their regular engineering syncs. I’ll attend those and make sure we close out any loose ends.


Hi Mark,
I upgraded to torch nightly version dev20230820 and reran all of the workloads. There are a couple of new issues with torch.compile in the nightly version on some of the other workloads.

We’d like to pin the PyTorch version in approximately 2 weeks, and I am not sure which version to use, because I am not sure how long it will take to resolve the new issues on nightly.

Current state on stable:

| Workload | Preprocessing pipeline | Made torch.compile work | Backend | Perf PyTorch relative to JAX, (t_pytorch - t_jax)/t_jax in % |
| --- | --- | --- | --- | --- |
| criteo1tb | Not Fixed | No: OOM | aot_eager | |
| fastmri | Fixed | Yes | inductor | 9 |
| imagenet_resnet | | Yes | inductor | 4 |
| imagenet_vit | | Yes | inductor | -10 |
| librispeech_conformer | Fixed | Yes | eager | -1 |
| librispeech_deepspeech | Fixed | Yes | eager | 13 |
| ogbg | Fixed | No: NamedTuple dynamo bug | inductor | 11 |
| wmt | Fixed | Yes | inductor | 22 |

Current state on nightly dev20230820:

| Workload | Preprocessing pipeline | Made torch.compile work | Backend | Perf PyTorch relative to JAX, (t_pytorch - t_jax)/t_jax in % |
| --- | --- | --- | --- | --- |
| criteo1tb | Not Fixed | No: OOM | aot_eager | blocked |
| fastmri | Fixed | Yes | inductor | todo |
| imagenet_resnet | | Yes | inductor | todo |
| imagenet_vit | | No | inductor | blocked |
| librispeech_conformer | Fixed | No: OOM | eager | blocked |
| librispeech_deepspeech | Fixed | No | eager | blocked |
| ogbg | Fixed | No: NamedTuple dynamo bug | | blocked |
| wmt | Fixed | Yes | inductor | todo |

Our highest priority among the workloads at the moment is to fix the Criteo PyTorch workload. On our end we’re working on replacing the Embedding layer, but we could definitely still use your help there, because it’s strange that it was working on PyTorch 1.13 and we are not sure if changing the Embedding is going to fix it. We’re also working on fixing the data pipeline to unblock you all.
The second-highest priority workload is the WMT workload, which is still 22% slower on PyTorch. It compiles with inductor, but we have to use torch stable 2.0.1; we still need to check whether it compiles with fullgraph (pytorch/issues/107362).
