Knowing whether training is running on multiple GPUs

Hi. I have a model that I sometimes train on a single GPU and sometimes on multiple GPUs. Its architecture contains BatchNorm layers, so I decided to use model = SyncBatchNorm.convert_sync_batchnorm(model) to avoid a performance loss when training on multiple GPUs.

However, now when I try to train on a single GPU I get the following error:

...
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/nn/modules/batchnorm.py", line 731, in forward
    world_size = torch.distributed.get_world_size(process_group)
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 867, in get_world_size
    return _get_group_size(group)
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 325, in _get_group_size
    default_pg = _get_default_group()
  File "/home/alain/miniconda3/envs/ts/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 429, in _get_default_group
    raise RuntimeError(
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

I tried to do something conditional like:

import torch.distributed as dist
model = SyncBatchNorm.convert_sync_batchnorm(model) if dist.get_world_size() > 1 else model

but I got a similar error when trying to evaluate dist.get_world_size().

Also, I use PyTorch Lightning; I don’t know whether that helps or is part of the issue.

Does anyone have an idea?
Thanks a lot!

Try using torch.distributed.is_initialized() as the condition: only convert the model to SyncBatchNorm when the default process group has actually been initialized, since get_world_size() cannot be called before that.
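
A minimal sketch of that check, assuming the conversion happens after the process group would have been set up for a multi-GPU run:

import torch.distributed as dist
from torch.nn import SyncBatchNorm

# Only convert when a distributed process group is actually running:
# is_available() guards builds without distributed support,
# is_initialized() checks that init_process_group() was called,
# and only then is get_world_size() safe to query.
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
    model = SyncBatchNorm.convert_sync_batchnorm(model)

On a single-GPU run the process group is never initialized, so the condition is False and the model keeps its regular BatchNorm layers; on a multi-GPU run with a world size greater than 1 they get converted to SyncBatchNorm.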