Weird behavior when dealing with uneven inputs using the join context manager

Background

Hi, I was trying to reproduce the tutorial at https://tutorials.pytorch.kr/advanced/generic_join.html#what-is-join when I noticed some weird behavior.
The code is:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
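
    # NOTE: only one of the cases below is active per run; the others are
    # commented out, and the results are collected case by case.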

    # Case 1:
    with model.join():
        for input in inputs:
            num_inputs += 1
            loss = model(input)
            loss.backward()

    # Case 2:
    for input in inputs:
        num_inputs += 1
        loss = model(input)
        loss.backward()

    # Case 3:
    for input in inputs:
        num_inputs += 1
        loss = model(input)
        loss.backward()
        dist.barrier()
    
    # Case 4:
    with model.join():
        for input in inputs:
            num_inputs += 1
            loss = model(input)
            loss.backward()
            dist.barrier()
    
    # Case 5:
    for input in inputs:
        num_inputs += 1
        loss = model(input)
        loss.backward()
    dist.barrier()

    # Case 6:
    with model.join():
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
    dist.barrier()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

The basic idea of the code is that I spawn two processes, Rank0 and Rank1, on one node with two GPUs; they have 5 and 6 batches, respectively. I want to see how DDP behaves when the data is uneven.

When I run those cases separately, these are the results I got:

  • Case 1: Rank0 and Rank1 are both finished.
  • Case 2: Rank0 and Rank1 are both finished.
  • Case 3: Only Rank0 finished and Rank1 is stuck.
  • Case 4: Only Rank1 finished and Rank0 is stuck.
  • Case 5: Only Rank0 finished and Rank1 is stuck.
  • Case 6: Rank0 and Rank1 are both finished.

Some of the results are counterintuitive to me:

  • Case 2: I assumed Rank1 would get stuck: without model.join(), it should not be able to handle the extra data and should just hang there.
  • Case 3: Compared with Case 2, I only added a barrier, and this time it works as I expected. But why do I have to add this barrier? I thought PyTorch already synchronizes at the loss.backward() line, so the barrier should be redundant.
  • Case 4: Same question as in Case 3; I thought adding dist.barrier() right after loss.backward() would have no effect.
  • Case 5: Given that both Rank0 and Rank1 finish properly in Case 2, I assumed this barrier would again have no effect, but it does.

My Environment

I ran this script on Ubuntu 18.04 with PyTorch 1.9.1 and Python 3.7.11.


Thanks for trying out the join context manager!

TL;DR: The DDP join context manager uses its own collective communications, so adding dist.barrier() in the training loop can misalign those communications.

Some Details

  • NCCL barrier() is implemented as an all-reduce of Tensor([X])s, where X represents uninitialized data since it uses empty. X may be 0 or some garbage value — it is runtime dependent.
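
As a tiny aside (my own illustration, not DDP code), you can see the uninitialized-memory point directly:

import torch

# torch.empty() allocates without initializing, so the value is whatever
# happened to be in that memory: often 0, but not guaranteed.
print(torch.empty(1))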

Not under join():

  • DDP backward() uses one all-reduce per gradient bucket to synchronize gradients. For the Linear(1, 1) model, there is a single gradient bucket.

Under join() (assuming not under no_sync()):

  • DDP forward() schedules an all-reduce of Tensor([1]) to indicate that the calling process has not yet joined. Label this forward() 1.
  • The join() context manager schedules a matching all-reduce of Tensor([0]) to count the number of processes that have not yet joined. This is via _get_num_nonjoined_procs().
  • DDP forward() schedules an additional all-reduce of Tensor([1]) to indicate that backward synchronization is required on this iteration. Label this forward() 2.
  • DDP join() schedules a matching all-reduce of Tensor([0]).
  • DDP backward() uses one all-reduce per gradient bucket to synchronize gradients.
  • DDP join() shadows every gradient all-reduce with an all-reduce of zero-valued tensors. This is via the join() hook (specifically, the main hook).
  • The last joining process broadcasts the final model parameters to the joined processes. However, since there may be multiple last joining processes, an authoritative rank is found via _find_common_rank(), which takes the max rank over the candidates using an all-reduce.

Denote rank 0 and rank 1 by R0 and R1, respectively. R0 has N inputs, and R1 has N + 1 inputs. For simplicity, generally assume N = 2.
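
To make the alignment argument concrete, here is a toy, single-process sketch (my own simplification; the labels are informal and are not PyTorch API names). It lists the collectives each rank issues according to the bullets above and pairs them up positionally, which is roughly how NCCL matches collectives by call order. Without a per-iteration barrier the pairing reproduces the Case 1 trace; with one, it shows where the extra barrier first knocks the two ranks out of step. Beyond that point the real join() control flow reacts to the values actually exchanged, so the true Case 4 trace diverges from this naive zip.

from itertools import zip_longest

def collectives(num_inputs, total_iters, barrier_in_loop):
    """Informal list of the collectives one rank issues under join()."""
    calls = []
    for i in range(total_iters):
        if i < num_inputs:
            # Rank still has data: two forward all-reduces, then the
            # gradient all-reduce (a single bucket for Linear(1, 1)).
            calls += [f"fwd-1 (inputs[{i}])", "fwd-2", "bwd grad all-reduce"]
            if barrier_in_loop:
                calls.append("barrier")
        else:
            # Rank has joined: shadow the non-joined ranks' collectives.
            calls += ["join: nonjoined count", "join: check grad sync",
                      "join: shadow grad all-reduce"]
    # Everyone joined: confirm no one is left, pick a rank, sync params.
    calls += ["join: nonjoined count", "find common rank", "broadcast params"]
    return calls

N = 2
for barrier_in_loop in (False, True):  # roughly Case 1 vs. Case 4
    r0 = collectives(N, N + 1, barrier_in_loop)      # R0: N inputs
    r1 = collectives(N + 1, N + 1, barrier_in_loop)  # R1: N + 1 inputs
    print(f"\nbarrier_in_loop={barrier_in_loop}")
    for c0, c1 in zip_longest(r0, r1, fillvalue=""):
        print(f"{c0:30s} | {c1}")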

Case 1

  • R0 and R1 finish.
Trace

            | R0                                              | R1
inputs[0]   | all-reduce (DDP forward() 1)                    | all-reduce (DDP forward() 1)
            | all-reduce (DDP forward() 2)                    | all-reduce (DDP forward() 2)
            | all-reduce (DDP backward())                     | all-reduce (DDP backward())
inputs[1]   | all-reduce (DDP forward() 1)                    | all-reduce (DDP forward() 1)
            | all-reduce (DDP forward() 2)                    | all-reduce (DDP forward() 2)
            | all-reduce (DDP backward())                     | all-reduce (DDP backward())
inputs[2]   | all-reduce (join() _get_num_nonjoined_procs())  | all-reduce (DDP forward() 1)
            | all-reduce (join() check bwd grad sync)         | all-reduce (DDP forward() 2)
            | all-reduce (join() hook)                        | all-reduce (DDP backward())
            | all-reduce (join() _get_num_nonjoined_procs())  | all-reduce (join() _get_num_nonjoined_procs())
            | all-reduce (_find_common_rank())                | all-reduce (_find_common_rank())
            | broadcast (sync final model params)             | broadcast (sync final model params)
            | finish                                          | finish

Case 2

  • R0 and R1 finish.
Trace

When I tried to reproduce the behavior, this is what I saw: for N = 1, R0 hangs and R1 finishes; for N > 1, both finish, but the program errors out with an NCCL error (Process Group destroyed on rank 1).
N = 1:

            | R0                    | R1
inputs[0]   | all-reduce (backward) | all-reduce (backward)
inputs[1]   | finish                | all-reduce (backward)
            | hang                  |

N = 2:

            | R0                    | R1
inputs[0]   | all-reduce (backward) | all-reduce (backward)
inputs[1]   | all-reduce (backward) | all-reduce (backward)
inputs[2]   | finish                | all-reduce (backward)
            |                       | finish with error (NCCL error: Process Group destroyed on rank 1)

This N > 1 case is somewhat surprising. I am not entirely sure why R1 still prints, but I imagine that it can be explained with further digging.

Case 3

  • R0 finishes. R1 hangs.
Trace

            | R0                          | R1
inputs[0]   | all-reduce (DDP backward()) | all-reduce (DDP backward())
            | all-reduce (barrier)        | all-reduce (barrier)
inputs[1]   | all-reduce (DDP backward()) | all-reduce (DDP backward())
            | all-reduce (barrier)        | all-reduce (barrier)
inputs[2]   | finish                      | all-reduce (DDP backward())
            |                             | hang

Case 4

  • R1 finishes. R0 hangs.
Trace

            | R0                                              | R1
inputs[0]   | all-reduce (DDP forward() 1)                    | all-reduce (DDP forward() 1)
            | all-reduce (DDP forward() 2)                    | all-reduce (DDP forward() 2)
            | all-reduce (DDP backward())                     | all-reduce (DDP backward())
            | all-reduce (barrier)                            | all-reduce (barrier)
inputs[1]   | all-reduce (DDP forward() 1)                    | all-reduce (DDP forward() 1)
            | all-reduce (DDP forward() 2)                    | all-reduce (DDP forward() 2)
            | all-reduce (DDP backward())                     | all-reduce (DDP backward())
            | all-reduce (barrier)                            | all-reduce (barrier)
inputs[2]   | all-reduce (join() _get_num_nonjoined_procs())  | all-reduce (DDP forward() 1)
            | all-reduce (join() check bwd grad sync)         | all-reduce (DDP forward() 2)
            | all-reduce (join() hook)                        | all-reduce (DDP backward())
            | all-reduce (join() _get_num_nonjoined_procs())  | all-reduce (barrier)
            | all-reduce (join() check bwd grad sync)         | all-reduce (join() _get_num_nonjoined_procs())
            | all-reduce (join() _get_num_nonjoined_procs())  | all-reduce (_find_common_rank())
            |                                                 | error (resulting from a garbage authoritative rank value)

Consider the step where R0 corresponds to “check bwd grad sync” and R1 corresponds to _get_num_nonjoined_procs() (the third row from the bottom of the trace).

  • R0 contributes a 0-valued tensor to check if it needs a backward gradient synchronization.
  • R1 contributes a 0-valued tensor to check if there are any non-joined processes remaining.
  • R0 interprets the 0-valued result to mean that no backward gradient synchronization is needed.
  • R1 interprets the 0-valued result to mean that there are no remaining non-joined processes, so it is the last joining process.
  • In the next step, R0 calls _get_num_nonjoined_procs(). However, under the current implementation, _get_num_nonjoined_procs() uses the default dtype, float32, while _find_common_rank() uses int64. This results in a cast to floating point in the all-reduce result, and yields a garbage value as the authoritative rank for R1.
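
How the mismatch corrupts the value is an implementation detail of NCCL, but as a rough, dependency-free illustration (my own toy example, not DDP internals) of why type confusion in a collective yields garbage, reading the bytes of an int64 rank value back as float32 gives meaningless numbers:

import struct

# The int64 value 1 (say, a candidate authoritative rank), reinterpreted as
# two float32 values, becomes a denormal and a zero rather than a rank.
raw = struct.pack("<q", 1)           # 8 little-endian bytes of int64 value 1
print(struct.unpack("<2f", raw))     # (1.401298464324817e-45, 0.0)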

Case 5

  • R0 finishes. R1 hangs.
Trace
            | R0                          | R1
inputs[0]   | all-reduce (DDP backward()) | all-reduce (DDP backward())
inputs[1]   | barrier                     | all-reduce (DDP backward())
            | finish                      | barrier
            |                             | hang

Case 6

  • R0 and R1 finish.
Trace

This is the same as Case 1, only with an “all-reduce (barrier)” at the end for both R0 and R1.

As you can see, the join context manager introduces several more synchronization points (using all-reduces), so adding extraneous collective communications can lead to strange behavior. Let me know if you have any more questions.


Thank you soooo much for your detailed reply and illustration; I found them very informative!