import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5


def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    num_inputs = 0

    # Run one case at a time (comment out the others).

    # Case 1: join(), no barrier
    with model.join():
        for input in inputs:
            num_inputs += 1
            loss = model(input)
            loss.backward()

    # Case 2: no join(), no barrier
    for input in inputs:
        num_inputs += 1
        loss = model(input)
        loss.backward()

    # Case 3: no join(), barrier inside the loop
    for input in inputs:
        num_inputs += 1
        loss = model(input)
        loss.backward()
        dist.barrier()

    # Case 4: join(), barrier inside the loop
    with model.join():
        for input in inputs:
            num_inputs += 1
            loss = model(input)
            loss.backward()
            dist.barrier()

    # Case 5: no join(), barrier after the loop
    for input in inputs:
        num_inputs += 1
        loss = model(input)
        loss.backward()
    dist.barrier()

    # Case 6: join(), barrier after the join() context
    with model.join():
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
    dist.barrier()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")


def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)


if __name__ == "__main__":
    main()
The basic idea of the code is that I spawn two processes, Rank0 and Rank1, on one node with two GPUs; they have 5 and 6 batches, respectively.
I want to see how DDP behaves when the ranks have uneven data.
When I run the cases separately (one at a time), these are the results I get:
Case 1: Rank0 and Rank1 are both finished.
Case 2: Rank0 and Rank1 are both finished.
Case 3: Only Rank0 finished and Rank1 is stuck.
Case 4: Only Rank1 finished and Rank0 is stuck.
Case 5: Only Rank0 finished and Rank1 is stuck.
Case 6: Rank0 and Rank1 are both finished.
Some of the results are counterintuitive to me:
Case 2: I assumed Rank1 would get stuck, since without model.join() it should not be able to handle the extra uneven data; I expected it to just hang there.
Case 3: Since I add a barrier compared to Case 2, this time it works as I expected. But why do I have to add this barrier? I thought PyTorch already synchronizes at the loss.backward() line, so the barrier should be redundant.
Case 4: Same question as in Case 3: I thought adding dist.barrier() right after loss.backward() would have no effect.
Case 5: Given that in Case 2 both Rank0 and Rank1 finish properly, I assumed this barrier would again have no effect, but it does.
My Environment
I run this script on Ubuntu 18.04 with PyTorch 1.9.1 and Python 3.7.11.
TL;DR: The DDP join context manager uses its own collective communications, so adding dist.barrier() in the training loop can misalign those communications.
Some Details
NCCL barrier() is implemented as an all-reduce of Tensor([X])s, where X is uninitialized data (the implementation uses torch.empty()), so X may be 0 or some garbage value; it is runtime dependent.
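To make that concrete, here is a minimal conceptual sketch of this behavior. It is not the actual ProcessGroupNCCL source, just an illustration of "barrier as an all-reduce over an uninitialized tensor":

```python
import torch
import torch.distributed as dist

def nccl_barrier_sketch(device):
    # Conceptual stand-in for an NCCL barrier(): all-reduce a tensor whose
    # contents are uninitialized (torch.empty()), then block until the
    # collective completes. The reduced value is runtime dependent and discarded.
    t = torch.empty(1, device=device)
    dist.all_reduce(t)
    torch.cuda.synchronize(device)
```

The important consequence is that, on the NCCL backend, a barrier() call can end up matched against some other rank's pending all-reduce, which is exactly what happens in the traces below.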
Not under join():

- DDP backward() uses one all-reduce per gradient bucket to synchronize gradients. For the Linear(1, 1) model, there is a single gradient bucket.

Under join():

- DDP forward() schedules an all-reduce of Tensor([1]) to indicate that the calling process has not yet joined. Label this forward() 1.
- The join() context manager schedules a matching all-reduce of Tensor([0]) to count the number of processes that have not yet joined. This is via _get_num_nonjoined_procs().
- DDP forward() schedules an additional all-reduce of Tensor([1]) to indicate that backward synchronization is required on this iteration. Label this forward() 2.
- DDP join() schedules a matching all-reduce of Tensor([0]).
- DDP backward() uses one all-reduce per gradient bucket to synchronize gradients.
- DDP join() shadows every gradient all-reduce with an all-reduce of zero-valued tensors. This is via the join() hook (specifically, the main hook).
- The last joining process broadcasts the final model parameters to the joined processes. However, since there may be multiple last joining processes, an authoritative rank is found via _find_common_rank(), which takes the max rank over the candidates using an all-reduce.
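To visualize the join() side of this, here is a rough sketch of the shadow collectives a joined process issues on each "fake" iteration. The function name and tensor shapes are illustrative assumptions, not the verbatim DDP internals:

```python
import torch
import torch.distributed as dist

def joined_iteration_sketch(bucket_shapes, device):
    # Matches forward() 1 on the non-joined ranks: count how many processes
    # have not yet joined (a joined process contributes 0).
    num_nonjoined = torch.zeros(1, device=device)
    dist.all_reduce(num_nonjoined)
    if int(num_nonjoined.item()) == 0:
        # Everyone has joined; proceed to _find_common_rank() and the final
        # parameter broadcast instead of shadowing another iteration.
        return False

    # Matches forward() 2: vote on whether gradient sync happens this iteration.
    requires_sync = torch.zeros(1, device=device)
    dist.all_reduce(requires_sync)

    # Shadow each gradient-bucket all-reduce with zeros so that the non-joined
    # ranks' backward() collectives have a matching participant.
    for shape in bucket_shapes:
        dist.all_reduce(torch.zeros(shape, device=device))
    return True
```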
Denote rank 0 and rank 1 by R0 and R1, respectively. R0 has N inputs, and R1 has N + 1 inputs. For simplicity, generally assume N = 2.
Case 1
R0 and R1 finish.
Trace
| R0 | R1 |
| --- | --- |
| inputs[0] | |
| all-reduce (DDP forward() 1) | all-reduce (DDP forward() 1) |
| all-reduce (DDP forward() 2) | all-reduce (DDP forward() 2) |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| inputs[1] | |
| all-reduce (DDP forward() 1) | all-reduce (DDP forward() 1) |
| all-reduce (DDP forward() 2) | all-reduce (DDP forward() 2) |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| inputs[2] | |
| all-reduce (join() _get_num_nonjoined_procs()) | all-reduce (DDP forward() 1) |
| all-reduce (join() check bwd grad sync) | all-reduce (DDP forward() 2) |
| all-reduce (join() hook) | all-reduce (DDP backward()) |
| all-reduce (join() _get_num_nonjoined_procs()) | all-reduce (join() _get_num_nonjoined_procs()) |
| all-reduce (_find_common_rank()) | all-reduce (_find_common_rank()) |
| broadcast (sync final model params) | broadcast (sync final model params) |
| finish | finish |
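(Aside: if you want to observe which collectives each rank actually launches while reproducing traces like the one above, NCCL's own debug logging can help. These are standard NCCL environment variables, not PyTorch APIs; set them before init_process_group(). With recent NCCL versions this should print a log line per collective call.)

```python
# NCCL's own debug switches (assumes the NCCL backend).
os.environ["NCCL_DEBUG"] = "INFO"         # enable NCCL logging
os.environ["NCCL_DEBUG_SUBSYS"] = "COLL"  # restrict the log to collective calls
```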
Case 2
R0 and R1 finish.
Trace
When I tried to reproduce the behavior, this is what I saw: for N = 1, R0 hangs and R1 finishes; for N > 1, both finish, but the program exits with an NCCL error (Process Group destroyed on rank 1).
N = 1:
| R0 | R1 |
| --- | --- |
| inputs[0] | |
| all-reduce (backward) | all-reduce (backward) |
| inputs[1] | |
| finish | all-reduce (backward) |
| | hang |
N = 2:
| R0 | R1 |
| --- | --- |
| inputs[0] | |
| all-reduce (backward) | all-reduce (backward) |
| inputs[1] | |
| all-reduce (backward) | all-reduce (backward) |
| inputs[2] | |
| finish | all-reduce (backward) |
| | finish with error (NCCL error: Process Group destroyed on rank 1) |
This N > 1 case is somewhat surprising. I am not entirely sure why R1 still prints, but I imagine that it can be explained with further digging.
Case 3
R0 finishes. R1 hangs.
Trace
| R0 | R1 |
| --- | --- |
| inputs[0] | |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| all-reduce (barrier) | all-reduce (barrier) |
| inputs[1] | |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| all-reduce (barrier) | all-reduce (barrier) |
| inputs[2] | |
| finish | all-reduce (DDP backward()) |
| | hang |
Case 4
R1 finishes. R0 hangs.
Trace
| R0 | R1 |
| --- | --- |
| inputs[0] | |
| all-reduce (DDP forward() 1) | all-reduce (DDP forward() 1) |
| all-reduce (DDP forward() 2) | all-reduce (DDP forward() 2) |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| all-reduce (barrier) | all-reduce (barrier) |
| inputs[1] | |
| all-reduce (DDP forward() 1) | all-reduce (DDP forward() 1) |
| all-reduce (DDP forward() 2) | all-reduce (DDP forward() 2) |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| all-reduce (barrier) | all-reduce (barrier) |
| inputs[2] | |
| all-reduce (join() _get_num_nonjoined_procs()) | all-reduce (DDP forward() 1) |
| all-reduce (join() check bwd grad sync) | all-reduce (DDP forward() 2) |
| all-reduce (join() hook) | all-reduce (DDP backward()) |
| all-reduce (join() _get_num_nonjoined_procs()) | all-reduce (barrier) |
| all-reduce (join() check bwd grad sync) | all-reduce (join() _get_num_nonjoined_procs()) |
| all-reduce (join() _get_num_nonjoined_procs()) | all-reduce (_find_common_rank()) |
| | error (resulting from garbage authoritative_rank value) |
Consider the step where R0 is at "check bwd grad sync" and R1 is at _get_num_nonjoined_procs() (third row from the bottom of the table above):
- R0 contributes a 0-valued tensor to check if it needs a backward gradient synchronization.
- R1 contributes a 0-valued tensor to check if there are any non-joined processes remaining.
- R0 interprets the 0-valued result to mean that no backward gradient synchronization is needed.
- R1 interprets the 0-valued result to mean that there are no remaining non-joined processes, so it is the last joining process.
In the next step, R0 calls _get_num_nonjoined_procs() while R1, believing it is the last joining process, calls _find_common_rank(). However, under the current implementation, _get_num_nonjoined_procs() uses the default dtype (float32), while _find_common_rank() uses int64. This results in a cast to floating point in the all-reduce result and yields a garbage value as the authoritative rank for R1.
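For reference, here is a hypothetical sketch of the _find_common_rank() step as described above (each candidate contributes its own rank and the max is taken via an all-reduce); the exact internals may differ:

```python
import torch
import torch.distributed as dist

def find_common_rank_sketch(rank, device):
    # Each candidate contributes its own rank; the authoritative rank is the
    # max over all contributions.
    candidate = torch.tensor([rank], dtype=torch.int64, device=device)
    dist.all_reduce(candidate, op=dist.ReduceOp.MAX)
    return int(candidate.item())
```

In the mismatched pairing above, this int64 all-reduce on R1 ends up matched against R0's float32 _get_num_nonjoined_procs() tensor, so the bytes that come back no longer encode a valid rank.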
Case 5
R0 finishes. R1 hangs.
Trace
| R0 | R1 |
| --- | --- |
| inputs[0] | |
| all-reduce (DDP backward()) | all-reduce (DDP backward()) |
| inputs[1] | |
| barrier | all-reduce (DDP backward()) |
| finish | barrier |
| | hang |
Case 6
R0 and R1 finish.
Trace
This is the same as Case 1, except with an "all-reduce (barrier)" at the end for both R0 and R1.
As you can see, the join context manager introduces several more synchronization points (using all-reduces), so adding extraneous collective communications can lead to strange behavior. Let me know if you have any more questions.
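For what it's worth, and consistent with the Case 6 trace, one way to keep a barrier in this script without misaligning anything is to place it after the join() context has exited on all ranks (this reuses model and inputs from the script at the top):

```python
# Inside worker(): keep dist.barrier() outside of join() so it cannot be
# matched against DDP's or join()'s internal all-reduces.
with model.join():
    for inp in inputs:
        loss = model(inp).sum()
        loss.backward()
# By this point every rank has left join() and synced the final parameters,
# so each rank's barrier pairs with the other ranks' barriers as intended.
dist.barrier()
```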