torch._dynamo hit config.cache_size_limit when adding LR scheduler

I’m currently looking into using torch.compile(). Everything works great. However, when I add a scheduler.step() at the end of a compiled training step (I update the LR after every batch training step), I get the following warnings (the same on each rank):

After the first 12 steps:

torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
torch._dynamo.convert_frame: [WARNING] function: 'step' (/root/.virtualenvs/mlpug/lib/python3.10/site-packages/torch/optim/adamw.py:151)
torch._dynamo.convert_frame: [WARNING] last reason: L['self'].param_groups[0]['lr'] == 4.8361308593749987e-05 # has_complex = self._init_group( # optim/adamw.py:176 in step
torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see PyTorch 2.0 Troubleshooting — PyTorch master documentation.

After step 16:


torch._dynamo.convert_frame: [WARNING] function: 'adamw' (/root/.virtualenvs/mlpug/lib/python3.10/site-packages/torch/optim/adamw.py:279)
torch._dynamo.convert_frame: [WARNING] last reason: L['lr'] == 3.208388020496905e-05 # if foreach and isinstance(lr, Tensor) and not capturable: # optim/adamw.py:320 in adamw

After step 20:


torch._dynamo.convert_frame: [WARNING] function: '_multi_tensor_adamw' (/root/.virtualenvs/mlpug/lib/python3.10/site-packages/torch/optim/adamw.py:479)
torch._dynamo.convert_frame: [WARNING] last reason: L['lr'] == 2.6132520951198568e-05 # if isinstance(lr, Tensor) and not capturable: # optim/adamw.py:503 in _multi_tensor_adamw

After that no more warnings … but the training step is 3-4 times as slow.

Has this been observed before when calling scheduler.step() per (compiled) batch training step? Is there a specific way to resolve it?

Thanks,
– Freddy

PS1: It seems to me that the training step is recompiled every time the LR changes. However, that doesn’t make much sense to me, because changing the LR doesn’t alter the graph or the size of any tensor. It would be great if someone could explain this.
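
On PS1, the "last reason" lines in the warnings above hint at what may be going on: the guard that failed compares the LR against one exact float (for example L['self'].param_groups[0]['lr'] == 4.8361308593749987e-05), so the compiled code appears to be specialized on the LR value rather than only on the graph and tensor shapes. Below is a toy model of that behavior in plain Python. It is not Dynamo's implementation; ToyGuardedCache, linear_decay and the schedule values are made up for illustration, and only the limit of 8 is taken from the warning.

```python
class ToyGuardedCache:
    """Toy compile cache whose entries are guarded on the exact LR value."""

    def __init__(self, cache_size_limit=8):
        self.cache_size_limit = cache_size_limit
        self.entries = {}   # guard value -> pretend compiled artifact
        self.compiles = 0   # number of (re)compilations
        self.fallbacks = 0  # calls made after the limit was reached

    def call(self, lr):
        # Guard check: an entry is only reusable if the float matches exactly.
        if lr in self.entries:
            return "cache hit"
        # No matching entry and no room left: stop compiling for this function.
        if len(self.entries) >= self.cache_size_limit:
            self.fallbacks += 1
            return "limit hit"
        # Otherwise "compile" a new entry specialized on this LR value.
        self.entries[lr] = f"specialized for lr={lr}"
        self.compiles += 1
        return "compiled"


def linear_decay(step, base_lr=5e-5, total_steps=100):
    """Hypothetical per-batch schedule: a new float on every step."""
    return base_lr * (1 - step / total_steps)


# LR changes every step: every call fails the guard of all earlier entries.
scheduled = ToyGuardedCache(cache_size_limit=8)
scheduled_results = [scheduled.call(linear_decay(step)) for step in range(20)]

# Constant LR: one compilation, then only cache hits.
constant = ToyGuardedCache(cache_size_limit=8)
constant_results = [constant.call(5e-5) for _ in range(20)]

print("scheduled:", scheduled.compiles, "compiles,", scheduled.fallbacks, "calls past the limit")
print("constant: ", constant.compiles, "compile,", constant_results.count("cache hit"), "cache hits")
```

In this toy, a changing LR fills all 8 slots and then never hits the cache again, while a constant LR compiles once. It is not meant to reproduce the exact step numbers at which the real warnings appeared.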

PS2: When setting:

torch._dynamo.config.cache_size_limit = 64

The training steps are slow, implying that the step function is recompiled with every call.
Further, the warnings now start after step 69, which makes sense now that the cache size is much bigger.

I gathered batch training times at a few (semi-random) steps with and without LR scheduling per batch; see the table below. In both cases the training step time converges to the same duration, but with LR scheduling it takes much longer to converge (a lot of recompilation going on, I guess).

--------------------------------------------------------------------------
| Epoch | Batch step | Batch training time     | Batch training time     |
|       |            | (NO LR scheduler)       | (With scheduler.step()) |
--------------------------------------------------------------------------
|   0   |    10/976  |        12937ms          |     21789ms             |
|   0   |    20/976  |         7011ms          |     26957ms             |
|   0   |   300/976  |         1718ms          |      3271ms             |
|   0   |   900/976  |         1474ms          |      1994ms             |
|   2   |   430/976  |         1338ms          |      1339ms             |
--------------------------------------------------------------------------
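
To make the gap easier to read, here is a small sketch that computes the slowdown factor per row. The timings are copied from the table above; the rows and slowdown names are only for this example.

```python
# (epoch, batch step, ms without LR scheduler, ms with scheduler.step()),
# copied from the table above.
rows = [
    (0, 10, 12937, 21789),
    (0, 20, 7011, 26957),
    (0, 300, 1718, 3271),
    (0, 900, 1474, 1994),
    (2, 430, 1338, 1339),
]


def slowdown(no_sched_ms, with_sched_ms):
    """How many times slower the step is with scheduler.step()."""
    return with_sched_ms / no_sched_ms


ratios = [round(slowdown(no_sched, with_sched), 2) for _, _, no_sched, with_sched in rows]

for (epoch, step, _, _), ratio in zip(rows, ratios):
    print(f"epoch {epoch}, step {step:>3}: {ratio:.2f}x")
```

In these samples the gap is largest at step 20 (about 3.84x) and has closed by epoch 2 (about 1.00x).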

@ptrblck @marksaroufim do you have any idea why a torch.optim.lr_scheduler scheduler could cause so much havoc for torch.compile? It is only changing the lr per parameter group …

The next step would probably be to log all recompilation reasons (TORCH_LOGS="recompiles"), but I imagine you have already come across this issue; LR scheduling at the batch training level is normal for LLM training, where we don’t train for many epochs.
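
One way to turn that logging on without touching the training code is to set the environment variable named in the warning when launching the script. This is a minimal sketch: recompile_logging_env and run_with_recompile_logs are hypothetical helper names, and the script path is whatever entry point you normally run.

```python
import os
import subprocess
import sys


def recompile_logging_env(base_env=None):
    """Return a copy of the environment with TORCH_LOGS set to "recompiles"."""
    env = dict(os.environ if base_env is None else base_env)
    env["TORCH_LOGS"] = "recompiles"  # value suggested by the warning message
    return env


def run_with_recompile_logs(script_path):
    """Run a training script (path supplied by the caller) with recompile logging on."""
    return subprocess.run(
        [sys.executable, script_path],
        env=recompile_logging_env(),
        check=False,
    )
```

Setting the variable in the launching shell works just as well; the helper only makes the setting explicit and repeatable.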

Which PyTorch version are you using? IIRC the latest release fixed some issues with LR schedulers. Could you please try the nightlies? If that doesn’t work, it’d be worth opening an issue on GitHub and tagging me, since this is important.

Hi @marksaroufim, thanks for your reply.

In my previous message I used the following PyTorch version:

$ pip show torch
Name: torch
Version: 2.2.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /root/.virtualenvs/mlpug/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by:

I tried the same experiment with the nightly build and found that performance is even worse: it takes even longer for the step duration to converge. However, the final batch training step duration is slightly better than before.

$ pip show torch
Name: torch
Version: 2.3.0.dev20240228+cu121
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: /root/.virtualenvs/mlpug/lib/python3.10/site-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-nccl-cu12, nvidia-nvtx-cu12, pytorch-triton, sympy, typing-extensions
Required-by: torchaudio, torchvision

----------------------------------------------------------------------------------------------------
| Epoch | Batch step | Batch training time     | Batch training time     | Batch training time     |
|       |            | (NO LR scheduler)       | (With scheduler.step()) | (With scheduler.step()) |
|       |            | (PyTorch 2.2.1, Cu12)   | (PyTorch 2.2.1, Cu12)   | (PyTorch Nightly, Cu121)|
----------------------------------------------------------------------------------------------------
|   0   |    10/976  |         12937ms         |      21789ms            |      23120ms            |
|   0   |    20/976  |          7011ms         |      26957ms            |      28752ms            |
|   0   |   300/976  |          1718ms         |       3271ms            |       3419ms            |
|   0   |   900/976  |          1474ms         |       1994ms            |       2036ms            |
|   2   |   430/976  |          1338ms         |       1339ms            |       1328ms            |
----------------------------------------------------------------------------------------------------

I will open an issue on GitHub and tag you there.

PyTorch GitHub issue filed here.