Multiprocessing - Barrier Blocks all Processes?

I tried to do that, but it doesn't work. You said that the other processes would see signal.item() < world_size, but the first process shows the signal to be 0, so it never blocks on all_reduce. Can you tell me where I'm going wrong?

    if self.id == 0:
        ep = 10
        sl = 0.1
    elif self.id == 1:
        ep = 15
        sl = 0.2

    for i in range(0, ep):

        # if i == 10:
        #     break

        signal = torch.tensor([1]).to(self.device)
        work = dist.all_reduce(signal, async_op=True)

        #self._log.info(self.id, "image.. " + str(i))
        image = torch.zeros((1,2,3,256,384)).cuda(self.id)
        input = {"rgb": image}
        output = self.model(input)
        loss = torch.sum(output["output"][-1])
        
        work.wait()
        self._log.info(self.id, "SIG: " + str(signal.item()))
        if signal.item() < self.cfg.mp.workers:
            self._log.info(self.id, "EXIT: " + str(signal.item()))
            break

        self.optimizer.zero_grad()
        #self._log.info(self.id, "backward.. " + str(i))
        loss.backward()
        self.optimizer.step()
        #print("step: " + str(self.id) + " " + str(i))
        self._log.info(self.id, "done " + str(i))

        time.sleep(sl)

    self._log.info(self.id, "reduce.. ")
    signal = torch.tensor([0]).to(self.device)
    self._log.info(self.id, "FINAL SIGNAL: " + str(signal.item()))
    if signal.item() >= self.cfg.mp.workers:
        dist.all_reduce(signal) # NOT BLOCKING HERE

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

The self-contained code below works for me.

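    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP


    def example(rank, world_size):
        # rendezvous address for the default env:// init method; assumed
        # single-machine values, kept if already set in the environment
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        # create default process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        # create local model
        model = nn.Linear(10, 10).to(rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # rank r runs (r + 1) * 10 iterations, so the inputs are uneven
        for _ in range((rank + 1) * 10):
            # every rank that still has data contributes 1; launched async
            signal = torch.tensor([1])
            work = dist.all_reduce(signal, async_op=True)
            # forward pass (overlaps with the signal all_reduce)
            outputs = ddp_model(torch.randn(20, 10).to(rank))
            labels = torch.randn(20, 10).to(rank)
            work.wait()
            # some rank ran out of data: stop before backward, which would
            # launch DDP's gradient all_reduce
            if signal.item() < world_size:
                break
            # backward pass
            optimizer.zero_grad()
            loss_fn(outputs, labels).backward()
            # update parameters
            optimizer.step()

        # this rank ran out of data without seeing a stop signal: ranks that
        # still have data issue one more signal all_reduce, so contribute a 0
        # to it (if all ranks finish together, their zeros match each other)
        if signal.item() >= world_size:
            dist.all_reduce(torch.tensor([0]))

        dist.barrier()
        print(f"{rank} done")


    def main():
        world_size = 2
        mp.spawn(example,
            args=(world_size,),
            nprocs=world_size,
            join=True)

    if __name__ == "__main__":
        main()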
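Why the final zero matters can be seen in a minimal pure-Python sketch that simulates the protocol with threads instead of processes. FakeGroup, worker, run and send_final_zero are hypothetical names for this illustration only; FakeGroup.all_reduce is a blocking sum built on threading.Barrier, not torch.distributed, and the forward pass, the async launch and the final barrier are not modeled. With send_final_zero=False it models the snippet in the question, where signal is reset to 0 before the check, so the extra all_reduce is never issued and the rank with more data is left waiting (here the barrier times out instead of hanging forever).

```python
import threading


class FakeGroup:
    """Hypothetical thread-based stand-in for a blocking sum all_reduce."""

    def __init__(self, world_size, timeout=1.0):
        self.world_size = world_size
        self._slots = [0] * world_size
        # a rank left waiting alone raises BrokenBarrierError after `timeout`
        self._barrier = threading.Barrier(world_size, timeout=timeout)

    def all_reduce(self, rank, value):
        self._slots[rank] = value
        self._barrier.wait()            # every rank has written its value
        total = sum(self._slots)
        self._barrier.wait()            # every rank has read before the next call
        return total


def worker(rank, num_batches, group, steps, errors, send_final_zero):
    world_size = group.world_size
    signal = world_size
    try:
        for _ in range(num_batches):
            # every rank that still has data contributes 1
            signal = group.all_reduce(rank, 1)
            if signal < world_size:     # another rank has run out of data
                break
            steps[rank] += 1            # stands in for backward + optimizer.step()
        # this rank ran out of data first (or together with everyone):
        # join the extra all_reduce the remaining ranks are waiting on
        if send_final_zero and signal >= world_size:
            group.all_reduce(rank, 0)
    except threading.BrokenBarrierError:
        errors[rank] = "hang"           # a real job would block here forever


def run(batches_per_rank, send_final_zero=True):
    world_size = len(batches_per_rank)
    group = FakeGroup(world_size)
    steps = [0] * world_size
    errors = [None] * world_size
    threads = [
        threading.Thread(target=worker,
                         args=(r, n, group, steps, errors, send_final_zero))
        for r, n in enumerate(batches_per_rank)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return steps, errors


print(run([10, 20]))                         # -> ([10, 10], [None, None])
print(run([10, 20], send_final_zero=False))  # -> ([10, 10], [None, 'hang'])
```

Every rank ends up with the same number of training steps (the minimum across ranks), because each one issues exactly one more signal all_reduce than it trains, no matter which rank runs out of data first.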

Thanks, I finally made it work with your help!


This doesn't work anymore if I use BatchNorm on the GPU. If I disable track_running_stats, it works.

It hangs on these lines:

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))

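Yep, because with BatchNorm, DDP also communicates in the forward pass: by default (broadcast_buffers=True) it broadcasts module buffers such as BatchNorm's running_mean and running_var from rank 0 at the start of every forward. A rank that is about to stop would enter that broadcast, which the finished rank never joins. In that case, the signal check needs to move before the forward pass, but it will be slower. The following code should work.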

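    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim
    from torch.nn.parallel import DistributedDataParallel as DDP

    def example(rank, world_size):
        # rendezvous address for the default env:// init method; assumed
        # single-machine values, kept if already set in the environment
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        # create default process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        # create local model (BatchNorm1d keeps running_mean/running_var buffers)
        model = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.Linear(10, 10),
        ).to(rank)
        # construct DDP model
        ddp_model = DDP(model, device_ids=[rank])
        # define loss function and optimizer
        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        for _ in range((rank + 1) * 10):
            # blocking signal check before forward: DDP's forward broadcasts
            # the buffers, so a rank that must stop has to know before it
            # enters forward
            signal = torch.tensor([1])
            dist.all_reduce(signal)
            if signal.item() < world_size:
                break
            # forward pass
            outputs = ddp_model(torch.randn(20, 10).to(rank))
            labels = torch.randn(20, 10).to(rank)
            # backward pass
            optimizer.zero_grad()
            loss_fn(outputs, labels).backward()
            # update parameters
            optimizer.step()

        # this rank ran out of data without seeing a stop signal: contribute
        # a 0 to the signal all_reduce that ranks with more data still issue
        if signal.item() >= world_size:
            dist.all_reduce(torch.tensor([0]))

        dist.barrier()
        print(f"{rank} done")


    def main():
        world_size = 2
        mp.spawn(example,
            args=(world_size,),
            nprocs=world_size,
            join=True)

    if __name__ == "__main__":
        main()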
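A toy, thread-based model of why the check has to move in front of forward. It assumes, as DDP does by default with broadcast_buffers=True, that a model with buffers (BatchNorm with track_running_stats) issues a buffer broadcast at the start of each forward. FakeGroup, collective, worker and run are hypothetical; each call is only labeled with what it stands for ("signal", "buffers", "grads", "barrier"), and all ranks must issue the same kind at the same step. With the signal checked after forward, the rank that still has data enters one more buffer broadcast while the finished rank has moved on to the barrier; a real job hangs there, while this sketch reports the mismatch instead.

```python
import threading


class FakeGroup:
    """Hypothetical thread-based stand-in for a process group.

    Every rank must issue the same kind of collective at the same step.
    A real job would hang on a mismatch; this toy raises RuntimeError.
    """

    def __init__(self, world_size, timeout=1.0):
        self.world_size = world_size
        self._slots = [None] * world_size
        self._barrier = threading.Barrier(world_size, timeout=timeout)

    def collective(self, rank, kind, value=0):
        self._slots[rank] = (kind, value)
        self._barrier.wait()                    # every rank has posted its call
        kinds = sorted({k for k, _ in self._slots})
        total = sum(v for _, v in self._slots)
        self._barrier.wait()                    # every rank has read the slots
        if len(kinds) > 1:
            raise RuntimeError(f"mismatched collectives: {kinds}")
        return total


def worker(rank, num_batches, group, has_buffers, check_first, results):
    world_size = group.world_size
    signal = world_size
    try:
        for _ in range(num_batches):
            signal = group.collective(rank, "signal", 1)
            if check_first and signal < world_size:
                break
            if has_buffers:
                group.collective(rank, "buffers")   # forward: buffer broadcast
            if not check_first and signal < world_size:
                break
            group.collective(rank, "grads")         # backward: gradient all_reduce
        if signal >= world_size:
            group.collective(rank, "signal", 0)     # the final zero contribution
        group.collective(rank, "barrier")
        results[rank] = "done"
    except RuntimeError as exc:
        results[rank] = str(exc)


def run(has_buffers, check_first, batches_per_rank=(2, 4)):
    group = FakeGroup(len(batches_per_rank))
    results = [None] * len(batches_per_rank)
    threads = [
        threading.Thread(target=worker,
                         args=(r, n, group, has_buffers, check_first, results))
        for r, n in enumerate(batches_per_rank)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run(has_buffers=False, check_first=False))  # -> ['done', 'done']
print(run(has_buffers=True, check_first=False))   # both ranks report a mismatch
print(run(has_buffers=True, check_first=True))    # -> ['done', 'done']
```

Without buffers the late check is harmless, because nothing but the signal is issued before it, which is consistent with disabling track_running_stats avoiding the hang.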


It's much slower. GPU utilization is at 50 percent now. Is there any alternative way to end the processes?

OK, I stand corrected. The slowdown comes from SyncBatchNorm.


We are working on a more elegant and efficient solution. See the tracking issue: [RFC] Join-based API to support uneven inputs in DDP · Issue #38174 · pytorch/pytorch · GitHub

To unblock,

  1. If you know the number of inputs before entering the for loop, you can use an all_reduce with op=dist.ReduceOp.MIN to get the minimum of that number across all processes. Then have all processes loop exactly that minimum number of times.
  2. If this number cannot be acquired or is too expensive to acquire, you can first retrieve one input/batch/sample from the source/dataloader before entering the loop; call it last_input. Then, in each iteration:
    1. retrieve a new input (call it curr_input) from the dataloader
    2. launch an async all_reduce to check whether all processes have a curr_input
    3. run forward, backward and the optimizer step on last_input
    4. wait on the all_reduce for curr_input
    5. if all processes have a curr_input, assign curr_input to last_input and continue; otherwise, break.
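Both options can be sketched in plain Python, simulated with threads rather than real processes. FakeGroup and its all_reduce method are hypothetical stand-ins for dist.all_reduce (a blocking reduction of a plain integer with sum or min), and a second all_reduce call stands in for DDP's gradient synchronization in backward. Option 2 reduces a 0/1 "I have a curr_input" flag with min, which is equivalent to checking that the sum equals the world size; the sketch runs that check synchronously, whereas the real version would launch it with async_op=True so it overlaps with the training step. It also assumes every rank has at least one batch.

```python
import threading


class FakeGroup:
    """Hypothetical thread-based stand-in for a blocking all_reduce."""

    def __init__(self, world_size, timeout=1.0):
        self.world_size = world_size
        self._slots = [0] * world_size
        self._barrier = threading.Barrier(world_size, timeout=timeout)

    def all_reduce(self, rank, value, op=sum):
        self._slots[rank] = value
        self._barrier.wait()            # every rank has written its value
        result = op(self._slots)
        self._barrier.wait()            # every rank has read before the next call
        return result


def train_min_count(rank, batches, group, trained):
    # option 1: agree on the smallest input count, then loop exactly that often
    n = group.all_reduce(rank, len(batches), op=min)
    for batch in batches[:n]:
        group.all_reduce(rank, 1)       # stands in for DDP's gradient all_reduce
        trained[rank].append(batch)


def train_lookahead(rank, batches, group, trained):
    # option 2: hold one batch (last_input) and peek at the next (curr_input)
    it = iter(batches)
    last_input = next(it)               # assumes every rank has at least one batch
    while True:
        curr_input = next(it, None)
        # 1 if this rank has a curr_input; the min over ranks is 1 only if all do.
        # Real code would launch this with async_op=True and wait after training.
        have_all = group.all_reduce(rank, int(curr_input is not None), op=min)
        group.all_reduce(rank, 1)       # forward/backward/step on last_input
        trained[rank].append(last_input)
        if not have_all:                # some rank has no curr_input: stop
            break
        last_input = curr_input


def run(train_fn, batches_per_rank):
    group = FakeGroup(len(batches_per_rank))
    trained = [[] for _ in batches_per_rank]
    threads = [threading.Thread(target=train_fn, args=(r, b, group, trained))
               for r, b in enumerate(batches_per_rank)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return trained


data = [["a0", "a1"], ["b0", "b1", "b2", "b3"]]
print(run(train_min_count, data))   # -> [['a0', 'a1'], ['b0', 'b1']]
print(run(train_lookahead, data))   # -> [['a0', 'a1'], ['b0', 'b1']]
```

Both variants train every rank on the same number of batches, so all ranks issue the same sequence of collectives and none is left waiting. The lookahead version costs one extra small all_reduce per step but does not need the input counts up front.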

Just to confirm: is it SyncBatchNorm or BatchNorm?

Okay, it's slower with SyncBatchNorm, but it's also slightly slower without any BatchNorm at all.

But what if, for some reason, there is no input at all for one GPU device, or we skip all of the inputs for one of the devices?

How can we use last_input for the all_reduce operation in that case?

It doesn't seem to work anymore on my Ampere-series GPUs with PyTorch 1.11 and CUDA 11.5. This chunk of code blocks on signal.item() when I run with two or more GPUs:

        for train_datum in self.train_dataset.iterate():
            if early_stop:
                break

            # If any dataloader in any process stops, we get a signal
            if self.cfg.distributed.enabled:
                signal = torch.tensor([1]).to(self.device)
                torch.distributed.all_reduce(signal)
                if signal.item() < self.cfg.distributed.workers: # BLOCKS ON SIGNAL.ITEM()
                    self.logger.info(self.id, "Termination Signal")
                    early_stop = True
                    continue