I have a similar problem. I used `DistributedDataParallel` and launched with `python -m torch.distributed.launch --nproc_per_node=8`.
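For reference, a minimal sketch of that setup (the model and training step below are placeholders, not my actual code; the launcher invocation is the one above):

```python
# Run with: python -m torch.distributed.launch --nproc_per_node=8 main.py
import argparse

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
# torch.distributed.launch passes --local_rank to each spawned process
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

# The launcher sets MASTER_ADDR/MASTER_PORT/WORLD_SIZE/RANK env vars
dist.init_process_group(backend="nccl")
torch.cuda.set_device(args.local_rank)

# Placeholder model, wrapped in DistributedDataParallel
model = nn.Linear(10, 10).cuda(args.local_rank)
model = DDP(model, device_ids=[args.local_rank])

# Dummy forward/backward step just to show the shape of the loop
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(4, 10).cuda(args.local_rank)
loss = model(inputs).sum()
loss.backward()
optimizer.step()
```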