Hi, has anyone seen this error in the last 3 months or so when running a compiled (torch.compile) PyTorch model with multiple GPUs?
[rank0]: File "/usr/lib/python3/dist-packages/torch/_dynamo/variables/builder.py", line 529, in _wrap
[rank0]: if has_triton():
[rank0]: File "/usr/lib/python3/dist-packages/torch/utils/_triton.py", line 37, in has_triton
[rank0]: return is_device_compatible_with_triton() and has_triton_package()
[rank0]: File "/usr/lib/python3/dist-packages/torch/utils/_triton.py", line 33, in is_device_compatible_with_triton
[rank0]: if device_interface.is_available() and extra_check(device_interface):
[rank0]: File "/usr/lib/python3/dist-packages/torch/utils/_triton.py", line 23, in cuda_extra_check
[rank0]: return device_interface.Worker.get_device_properties().major >= 7
[rank0]: File "/usr/lib/python3/dist-packages/torch/_dynamo/device_interface.py", line 191, in get_device_properties
[rank0]: return caching_worker_device_properties["cuda"][device]
[rank0]: torch._dynamo.exc.InternalTorchDynamoError: IndexError: list index out of range
[rank0]: from user code:
[rank0]: File "/usr/lib/python3/dist-packages/torch/_dynamo/external_utils.py", line 40, in inner
[rank0]: return fn(*args, **kwargs)
[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
The same code ran fine last October-November on the same type of instance and still ran fine last week on my local machine. I have tried changing the version of PyTorch (2.4.1, 2.5.1, 2.6.0) and Triton (3.2.0, 3.1.0, 3.0.0, and the last 2.x release), but none of that helped. torch.cuda.is_available() and torch.cuda.device_count() look fine, and simple test code (e.g. MHA w/ mask on DDP · GitHub) also works fine. What other libraries/drivers might contribute to this?
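In case it helps narrow things down, this is the kind of per-rank sanity check I would launch the same way with torchrun. It's only a rough sketch: the script name, the launch line in the comment, and the loop over every visible device are my own additions; torch.utils._triton.has_triton() and the compute-capability >= 7 check are the ones from the traceback.

```python
# check_devices.py -- hypothetical per-rank diagnostic, not part of the repro below
# Launch the same way, e.g.: torchrun --nproc_per_node=8 check_devices.py
import os

import torch
from torch.utils._triton import has_triton  # the helper from the traceback

rank = os.environ.get("RANK", "?")
local_rank = os.environ.get("LOCAL_RANK", "?")
visible = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")

print(f"[rank {rank}] LOCAL_RANK={local_rank} CUDA_VISIBLE_DEVICES={visible}")
print(f"[rank {rank}] is_available={torch.cuda.is_available()} "
      f"device_count={torch.cuda.device_count()} "
      f"current_device={torch.cuda.current_device()}")

# The failing check in torch/utils/_triton.py boils down to compute capability >= 7,
# so print it for every device this rank can see.
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"[rank {rank}] device {i}: {props.name} sm_{props.major}{props.minor}")

# Hit the same helper the traceback goes through, but outside of Dynamo.
print(f"[rank {rank}] has_triton() -> {has_triton()}")
```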
The error is probably very environment-specific, but the exact steps to reproduce it on a Lambda 8x A100 (40 GB SXM4) instance are as follows:
pip3 install --upgrade requests
pip3 install wandb
pip3 install schedulefree
git clone https://github.com/EIFY/mup-vit.git
cd mup-vit
NUMEXPR_MAX_THREADS=116 torchrun main.py --fake-data --batch-size 4096 --log-steps 100
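To separate the environment from the mup-vit codebase itself, this is roughly the minimal standalone DDP + torch.compile script I would launch the same way with torchrun. Again just a sketch under assumptions: the toy model, the nccl backend, and the file name minimal_ddp_compile.py are hypothetical and not part of the repro above.

```python
# minimal_ddp_compile.py -- hypothetical minimal repro attempt, not the actual failing script
# Launch: torchrun --nproc_per_node=8 minimal_ddp_compile.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)

    # Tiny model, just enough for Dynamo to trace and hit the has_triton() check.
    model = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
    model = model.cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    compiled = torch.compile(model)

    opt = torch.optim.SGD(model.parameters(), lr=1e-3)
    x = torch.randn(64, 512, device=f"cuda:{local_rank}")
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss = compiled(x).square().mean()
        loss.backward()
        opt.step()

    if dist.get_rank() == 0:
        print("compiled DDP step ran without the IndexError")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

If something this small also dies in has_triton(), that would point at the driver/CUDA stack on the instance rather than anything in main.py.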