Deadlock when using torch.distributed.broadcast

Hey there,

We’re trying to add an early-stopping condition to our distributed training loop, where the rank 0 process monitors the training loss and then broadcasts a KILL value to the other processes to stop training.

Unfortunately our code deadlocks when broadcasting values (and the GPU utilization spikes to 100% on both cards).

I’ve provided a minimal version of the code below, along with the output. It’s unclear why the code freezes or why the GPU utilization spikes to 100% on both cards. The configuration details are included in the output below.

Thanks!

import os
import argparse
import random
import torch
import torch.distributed as dist
from logger import get_logger

log = get_logger(__name__)


def set_seeds(seed: int):
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed)
	torch.manual_seed(seed)
	if torch.cuda.is_available():
		torch.cuda.manual_seed(seed)
		torch.cuda.manual_seed_all(seed)


def get_cuda_info(device):
	return {
		"PyTorch_version": torch.__version__,
		"CUDA_version": torch.version.cuda,
		"cuDNN_version": torch.backends.cudnn.version(),
		"Arch_version": torch._C._cuda_getArchFlags(),
		"device_count": torch.cuda.device_count(),
		"device_name": torch.cuda.get_device_name(device=device),
		"device_id": torch.cuda.current_device(),
		"cuda_capability": torch.cuda.get_device_capability(device=device),
		"device_properties": torch.cuda.get_device_properties(device=device),
		"reserved_memory": torch.cuda.memory_reserved(device=device) / 1024**3,
		"allocated_memory": torch.cuda.memory_allocated(device=device) / 1024**3,
		"max_allocated_memory": torch.cuda.max_memory_allocated(device=device) / 1024**3
	}


def argparser():
	parser = argparse.ArgumentParser(
		description=__doc__,
		formatter_class=argparse.ArgumentDefaultsHelpFormatter
	)
	parser.add_argument(
		"--seed",
		default=42,
		type=int
	)
	parser.add_argument(
		"--use-cuda",
		action='store_true',
		default=torch.cuda.is_available(),
	)
	parser.add_argument(
		"--cudnn-deterministic",
		action='store_true',
		default=True
	)
	parser.add_argument(
		"--cudnn-benchmark",
		action='store_true',
		default=False
	)
	parser.add_argument(
		"--number-gpus",
		default=torch.cuda.device_count(),
		type=int,
	)
	parser.add_argument(
		"--local_rank",
		type=int,
		default=0,
	)
	parser.add_argument(
		"--nodes",
		default=1,
		type=int,
	)

	args = parser.parse_args()
	args.world_size = args.number_gpus * args.nodes
	return args


def build_model(args):
	set_seeds(seed=args.seed)

	if args.world_size > 1:
		os.environ["OMP_NUM_THREADS"] = "1"
		os.environ["PYTHONHASHSEED"] = str(args.seed)
		os.environ['MASTER_ADDR'] = '127.0.0.1'
		os.environ['MASTER_PORT'] = '29500'

		torch.cuda.set_device(args.local_rank)
		dist.init_process_group(
			backend=dist.Backend.NCCL,
			init_method='env://',
			world_size=args.world_size,
			rank=args.local_rank
		)
	else:
		torch.cuda.set_device(args.local_rank)

	count = torch.zeros(1).cuda(args.local_rank)
	for index in range(10):
		log.info(f'GPU:{args.local_rank} running new iteration of the loop @ index:{index} - count:{count}')

		if (args.world_size > 1) and (args.local_rank == 0):
			count.add_(1)
			dist.broadcast(count, src=args.local_rank)
			log.info(f'BROADCASTING count {count} FROM GPU:{args.local_rank}')

		log.info(f'GPU:{args.local_rank} @ barrier')
		dist.barrier()

	log.info(f'GPU:{args.local_rank} finished!')
	dist.destroy_process_group()


def main():
	args = argparser()
	torch.set_printoptions(precision=10)
	if args.use_cuda:
		torch.backends.cudnn.enabled = True
		torch.backends.cudnn.deterministic = args.cudnn_deterministic
		torch.backends.cudnn.benchmark = args.cudnn_benchmark
		torch.cuda.empty_cache()
		log.info(get_cuda_info(device=args.local_rank))
	else:
		raise Exception('GPU(s) required')

	log.info(args)
	build_model(args)


if __name__ == '__main__':
	try:
		__IPYTHON__
		print('\nrunning via ipython -> not running continuously')
	except NameError:
		main()

The program can be run as follows:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 test_ddp.py
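
(As the deprecation warning in the console output below points out, torch.distributed.launch is deprecated; the torchrun equivalent should presumably look like the following, with the script reading os.environ['LOCAL_RANK'] instead of taking a --local_rank argument:)

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 test_ddp.py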

Here is the output from nvidia-smi:


Tue May 31 13:59:43 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.73.05    Driver Version: 510.73.05    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 6000     Off  | 00000000:67:00.0 Off |                  Off |
| 33%   58C    P2   112W / 260W |   1378MiB / 24576MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro RTX 6000     Off  | 00000000:68:00.0 Off |                  Off |
| 34%   64C    P2   110W / 260W |   1378MiB / 24576MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    801664      C   /opt/conda/bin/python            1375MiB |
|    1   N/A  N/A    801665      C   /opt/conda/bin/python            1375MiB |
+-----------------------------------------------------------------------------+

and here is the console output:

root@dd8aa307408a:~/src/tests# python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 test_ddp.py
/opt/conda/lib/python3.9/site-packages/torch/distributed/launch.py:178: FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See
https://pytorch.org/docs/stable/distributed.html#launch-utility for
further instructions

  warnings.warn(
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
887 - 2022-05-31 13:59:32,740 - __main__ - INFO - {'PyTorch_version': '1.11.0+cu113', 'CUDA_version': '11.3', 'cuDNN_version': 8200, 'Arch_version': 'sm_37 sm_50 sm_60 sm_70 sm_75 sm_80 sm_86', 'device_count': 2, 'device_name': 'Quadro RTX 6000', 'device_id': 0, 'cuda_capability': (7, 5), 'device_properties': _CudaDeviceProperties(name='Quadro RTX 6000', major=7
, minor=5, total_memory=24217MB, multi_processor_count=72), 'reserved_memory': 0.0, 'allocated_memory': 0.0, 'max_allocated_memory': 0.0}
887 - 2022-05-31 13:59:32,740 - __main__ - INFO - Namespace(seed=42, use_cuda=True, cudnn_deterministic=True, cudnn_benchmark=False, number_gpus=2, local_rank=1, nodes=1, world_size=2)
887 - 2022-05-31 13:59:32,741 - torch.distributed.distributed_c10d - INFO - Added key: store_based_barrier_key:1 to store for rank: 1
886 - 2022-05-31 13:59:32,743 - __main__ - INFO - {'PyTorch_version': '1.11.0+cu113', 'CUDA_version': '11.3', 'cuDNN_version': 8200, 'Arch_version': 'sm_37 sm_50 sm_60 sm_70 sm_75 sm_80 sm_86', 'device_count': 2, 'device_name': 'Quadro RTX 6000', 'device_id': 0, 'cuda_capability': (7, 5), 'device_properties': _CudaDeviceProperties(name='Quadro RTX 6000', major=7
, minor=5, total_memory=24220MB, multi_processor_count=72), 'reserved_memory': 0.0, 'allocated_memory': 0.0, 'max_allocated_memory': 0.0}
886 - 2022-05-31 13:59:32,743 - __main__ - INFO - Namespace(seed=42, use_cuda=True, cudnn_deterministic=True, cudnn_benchmark=False, number_gpus=2, local_rank=0, nodes=1, world_size=2)
886 - 2022-05-31 13:59:32,744 - torch.distributed.distributed_c10d - INFO - Added key: store_based_barrier_key:1 to store for rank: 0
886 - 2022-05-31 13:59:32,745 - torch.distributed.distributed_c10d - INFO - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
887 - 2022-05-31 13:59:32,752 - torch.distributed.distributed_c10d - INFO - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
887 - 2022-05-31 13:59:35,586 - __main__ - INFO - GPU:1 running new iteration of the loop @ index:0 - count:tensor([0.], device='cuda:1')
887 - 2022-05-31 13:59:35,586 - __main__ - INFO - GPU:1 @ barrier
886 - 2022-05-31 13:59:35,592 - __main__ - INFO - GPU:0 running new iteration of the loop @ index:0 - count:tensor([0.], device='cuda:0')
886 - 2022-05-31 13:59:35,639 - __main__ - INFO - BROADCASTING count tensor([1.], device='cuda:0') FROM GPU:0
886 - 2022-05-31 13:59:35,639 - __main__ - INFO - GPU:0 @ barrier
886 - 2022-05-31 13:59:35,639 - __main__ - INFO - GPU:0 running new iteration of the loop @ index:1 - count:tensor([1.], device='cuda:0')
886 - 2022-05-31 13:59:35,640 - __main__ - INFO - BROADCASTING count tensor([2.], device='cuda:0') FROM GPU:0
886 - 2022-05-31 13:59:35,640 - __main__ - INFO - GPU:0 @ barrier

The program runs to completion if you comment out the dist.broadcast() call. It’s also unclear why GPU 1 gets stuck at the barrier.

Thanks for posting the question @vgoklani. Looking at your script, it seems like you are calling broadcast only on rank 0. Rank 1 is expected to call the same broadcast collective, but it never does, which is why the program hangs.

In general, distributed collectives like broadcast should be called on all ranks in an SPMD (single program, multiple data) manner, so that every rank knows it has reached a certain point and can proceed to the next instruction; point-to-point communication like send/recv is the exception.
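
For your early-stopping case, a minimal sketch of the symmetric pattern could look like the following. This assumes the same count, log, and args.local_rank from your script, and is only meant to illustrate the SPMD structure, not to be a drop-in replacement:

# Sketch: every rank calls the same broadcast, so the collective matches on all ranks.
count = torch.zeros(1).cuda(args.local_rank)
for index in range(10):
	if args.local_rank == 0:
		count.add_(1)  # only the source rank updates the value

	# called on ALL ranks; non-source ranks receive the updated value from rank 0
	dist.broadcast(count, src=0)
	log.info(f'GPU:{args.local_rank} sees count {count} after broadcast')

	# every rank has already matched the broadcast, so the barrier matches too
	dist.barrier()

Each non-zero rank can then inspect count (or a dedicated stop-flag tensor) after the broadcast and break out of its training loop once the KILL condition is set.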
