Thread deadlock in DataParallel

I am hitting what looks like a thread deadlock when I train on multiple GPUs with DataParallel(). The model is training on a medium-sized dataset with 240K training samples. The first epoch completes successfully. In the second epoch, training progresses smoothly until it reaches about 50%, then hangs with no further progress. When I kill the process with Ctrl+C or kill -s SIGKILL, it becomes a zombie process!

Here is the traceback I get when I press Ctrl+C:

  File "run_E2E_EL_RE.py", line 962, in <module>
    main()
  File "run_E2E_EL_RE.py", line 913, in main
    global_step, tr_loss = train(args, model, tokenizer)
 File "run_E2E_EL_RE.py", line 249, in train
    el_loss, re_loss, _, _ = model.forward(**ned_inputs)
  File "/dresden/users/rb897/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/dresden/users/rb897/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/dresden/users/rb897/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 77, in parallel_apply
    thread.join()
  File "/dresden/users/rb897/anaconda3/lib/python3.7/threading.py", line 1044, in join
    self._wait_for_tstate_lock()
  File "/dresden/users/rb897/anaconda3/lib/python3.7/threading.py", line 1060, in _wait_for_tstate_lock
    elif lock.acquire(block, timeout):
KeyboardInterrupt

Code Snippet:

model.zero_grad()
train_iterator = trange(
    epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
)
set_seed(args)  # Added here for reproducibility
for epoch_num in train_iterator:
    epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
    for step, batch in enumerate(epoch_iterator):
        model.train()
        batch = tuple(t.to(args.device) for t in batch)
        inputs_1 = {...}
        inputs_2 = {...}
        loss_1, last_hidden_states = model.forward(**inputs_1)
        inputs_2["last_hidden_states"] = last_hidden_states
        loss_2, loss_3 = model.forward(**inputs_2)
        if args.n_gpu > 1:  # mean() to average on multi-gpu parallel training
            loss_1 = loss_1.mean()
            loss_2 = loss_2.mean()
            loss_3 = loss_3.mean()
        loss = loss_1 + loss_2 + loss_3
        loss.backward()
        optimizer.step()
        scheduler.step()
        model.zero_grad()

Per-GPU batch size: 4

OS: Ubuntu 18.04.5
CPU: Intel Xeon® - 64 cores
CUDA Version: 11.2
GPU: Quadro RTX 6000
GPU memory: 24G
PyTorch: 1.4.0
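
The Ctrl+C traceback above only shows the main thread waiting in thread.join(), not what the DataParallel worker threads are doing. One way to get more information from a hang like this is Python's standard faulthandler module, which can dump the Python stack of every thread after a timeout. The following is only a minimal sketch: the helper names arm_hang_watchdog and disarm_hang_watchdog and the 300-second default are assumptions for illustration, not part of the original script, and the dump shows Python frames only, so it may not reveal where a thread is blocked inside native CUDA code.

```python
import faulthandler
import sys


def arm_hang_watchdog(timeout_s=300, file=None):
    """(Re)start a timer that dumps every thread's Python stack after timeout_s seconds.

    Hypothetical helper: call it at the top of each training step. Each call
    replaces the previous timer, so a dump is written only when a single step
    runs longer than timeout_s.
    """
    out = sys.stderr if file is None else file
    # `out` must stay open until the timer fires or is cancelled.
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=out)


def disarm_hang_watchdog():
    """Cancel the pending dump, for example after the training loop finishes."""
    faulthandler.cancel_dump_traceback_later()
```

In the training loop above, arm_hang_watchdog() would go at the start of the body of the inner for loop and disarm_hang_watchdog() after the outer loop, so the dump appears only if one step stalls for longer than the timeout.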

Do you consistently see it stuck in the second epoch? Does your job exclusively occupy all the GPUs?

Besides what @mrshenli asked: are you also seeing this issue using the latest PyTorch stable or nightly release?

Yes, I performed three runs and in each case it got stuck at the same point. I am using 4 out of 8 available GPUs and no other jobs are running on those GPUs.

I reran the code using the latest PyTorch and now see the following error message instead. This time it happens on the first batch of the first epoch.

Traceback (most recent call last):
  File "run_E2E_EL_RE.py", line 962, in <module>
    main()
  File "run_E2E_EL_RE.py", line 913, in main
    global_step, tr_loss = train(args, model, tokenizer)
  File "run_E2E_EL_RE.py", line 246, in train
    ner_loss, last_hidden_states = model.forward(**ner_inputs)
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/rajarshi_kingsaint_bhowmik/E2E-EL-RE/modeling_E2E_EL_RE.py", line 65, in forward
    mention_outputs = self.bert_mention.bert(
  File "/home/rajarshi_kingsaint_bhowmik/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/rajarshi_kingsaint_bhowmik/E2E-EL-RE/modeling_bert.py", line 758, in forward
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
StopIteration
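
The last frame of this traceback calls next(self.parameters()) without a default, and Python's next() raises StopIteration whenever the iterator it is given yields nothing. One possible reading, which this thread does not confirm, is that parameters() is empty inside the DataParallel replica. The sketch below reproduces just that Python behaviour and shows one way to avoid the exception. FakeModule and first_param_dtype are hypothetical stand-ins written for illustration so the example runs without PyTorch, and the dtypes are plain strings; this is not the code in modeling_bert.py.

```python
from types import SimpleNamespace


class FakeModule:
    """Hypothetical stand-in for a module; parameters() returns an iterator."""

    def __init__(self, params):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


def first_param_dtype(module, fallback):
    """Return the dtype of the module's first parameter, or `fallback` if it has none."""
    # Passing a default to next() avoids StopIteration on an empty iterator.
    first = next(module.parameters(), None)
    return fallback if first is None else first.dtype


empty = FakeModule([])
try:
    next(empty.parameters())  # same call pattern as in the traceback
    raised = False
except StopIteration:
    raised = True

half = FakeModule([SimpleNamespace(dtype="float16")])
```

In a real model, a natural fallback would be the dtype of a tensor that is already available in forward(), such as the attention mask itself, but whether that matches the intended fp16 handling here is an assumption.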