I'm currently looking into using torch.compile(). Everything works great; however, when I add a scheduler.step() at the end of a compiled training step (I update the LR per batch training step), I get the following warnings (the same on each rank):
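Roughly, my setup looks like this (a minimal sketch only; the model, loss, and LambdaLR schedule are illustrative stand-ins for my actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(128, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.99 ** step)

@torch.compile
def train_step(batch, target):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(batch), target)
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR is updated per batch training step
    return loss

for step in range(32):
    loss = train_step(torch.randn(64, 128), torch.randint(0, 10, (64,)))
```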
After the first 12 steps:
torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
torch._dynamo.convert_frame: [WARNING] function: 'step' (/root/.virtualenvs/mlpug/lib/python3.10/site-packages/torch/optim/adamw.py:151)
torch._dynamo.convert_frame: [WARNING] last reason: L['self'].param_groups[0]['lr'] == 4.8361308593749987e-05 # has_complex = self._init_group( # optim/adamw.py:176 in step
torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see PyTorch 2.0 Troubleshooting — PyTorch master documentation.
After step 16:
…
torch._dynamo.convert_frame: [WARNING] function: 'adamw' (/root/.virtualenvs/mlpug/lib/python3.10/site-packages/torch/optim/adamw.py:279)
torch._dynamo.convert_frame: [WARNING] last reason: L['lr'] == 3.208388020496905e-05 # if foreach and isinstance(lr, Tensor) and not capturable: # optim/adamw.py:320 in adamw
…
After step 20:
…
torch._dynamo.convert_frame: [WARNING] function: '_multi_tensor_adamw' (/root/.virtualenvs/mlpug/lib/python3.10/site-packages/torch/optim/adamw.py:479)
torch._dynamo.convert_frame: [WARNING] last reason: L['lr'] == 2.6132520951198568e-05 # if isinstance(lr, Tensor) and not capturable: # optim/adamw.py:503 in _multi_tensor_adamw
…
After that, no more warnings appear, but each training step is 3-4 times slower.
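For completeness: I know I could raise the cache size limit (a sketch below, assuming the torch._dynamo.config knob), but I expect that would only postpone the recompiles, since the lr float keeps changing every batch.

```python
import torch._dynamo

# Raising the per-frame recompile cache limit (default 8) delays the warning,
# but every new lr value still fails the guard and triggers another recompile.
torch._dynamo.config.cache_size_limit = 64
```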
Has this been observed before when using scheduler.step() per (compiled) batch training step? Is there a specific way this can be resolved?
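For instance, the guard reasons above mention isinstance(lr, Tensor) checks, which makes me wonder whether the intended approach is to pass the learning rate as a tensor so Dynamo no longer guards on the raw float value. A sketch of what I mean (untested, inferred only from the warning text; model is the same illustrative module as in the sketch above):

```python
import torch

# Hypothetical workaround inferred from the guard message
# "if foreach and isinstance(lr, Tensor) and not capturable":
# give the optimizer its lr as a Tensor instead of a Python float.
optimizer = torch.optim.AdamW(model.parameters(), lr=torch.tensor(1e-4))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.99 ** step)
```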
Thanks,
– Freddy