Multiprocessing - Barrier Blocks all Processes?

Can I check something? Instead of using this approach:

Can I create a shared-memory torch tensor, and as soon as one process finishes its batches, write to that tensor so the other processes can read it and break out of their loops?

Sure, that will also work. You can use a torch.multiprocessing.Queue to do that, but this solution cannot scale across multiple machines.
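For reference, here is a minimal sketch of the shared-memory flag idea on a single machine. It uses a shared tensor rather than a Queue, and the names worker, run, stop_flag and steps_done are made up for illustration; the sleep stands in for a real training step. Note that a flag like this does not by itself make every worker stop on the same iteration, which matters once the training step contains collectives.

```python
import time

import torch
import torch.multiprocessing as mp


def worker(rank, stop_flag, steps_done, num_batches):
    """Run up to num_batches[rank] steps; stop early once any worker has finished."""
    for _ in range(num_batches[rank]):
        if stop_flag.item() != 0:  # some other worker already ran out of batches
            break
        time.sleep(0.01)           # placeholder for the real training step
        steps_done[rank] += 1
    stop_flag.fill_(1)             # tell the remaining workers to stop


def run(num_batches):
    """Spawn one worker per entry in num_batches and return the steps each one ran."""
    # share_memory_() moves the storage to shared memory so all workers see writes
    stop_flag = torch.zeros(1, dtype=torch.int32).share_memory_()
    steps_done = torch.zeros(len(num_batches), dtype=torch.int32).share_memory_()
    mp.spawn(worker, args=(stop_flag, steps_done, num_batches),
             nprocs=len(num_batches), join=True)
    return steps_done.tolist()
```

Calling run([10, 15]) from under an if __name__ == "__main__": guard should let the first worker run all 10 steps, while the second stops shortly after it sees the flag; the exact count for the second worker depends on timing.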

Final question: it’s not clear why you have another all_reduce at the end.

Hmm… I tried this method and it works. But now the second process hangs at dist.barrier().

Final question: it’s not clear why you have another all_reduce at the end.

That’s to indicate that one process has exited the loop; otherwise the sum would never be smaller than world_size. It needs to be guarded by a flag, because only the process that exits the loop first needs it. Something like:

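for batch in get_batch():
    ...  # in-loop all_reduce on signal; break if signal.item() < world_size

# Only a process that never saw a short sum posts the extra all_reduce.
if signal.item() >= world_size:
    dist.all_reduce(torch.tensor([0]))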

I tried to implement it as you suggested, but my second process gets stuck on the barrier:

    if self.id == 0:
        ep = 10
        sl = 0.1
    elif self.id == 1:
        ep = 15
        sl = 0.2

    for i in range(0, ep):

        # if i == 10:
        #     break

        signal = torch.tensor([1]).to(self.device)
        work = dist.all_reduce(signal, async_op=True)

        self._log.info(self.id, "image.. " + str(i))
        image = torch.zeros((1,2,3,256,384)).cuda(self.id)
        input = {"rgb": image}
        output = self.model(input)
        loss = torch.sum(output["output"][-1])
        
        work.wait()
        if signal.item() < self.cfg.mp.workers:
            break

        self.optimizer.zero_grad()
        self._log.info(self.id, "backward.. " + str(i))
        loss.backward()
        self.optimizer.step()
        print("step: " + str(self.id) + " " + str(i))
        self._log.info(self.id, "done " + str(i))

        time.sleep(sl)

    self._log.info(self.id, "reduce.. ")
    signal = torch.tensor([0]).to(self.device)
    dist.all_reduce(signal)

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

    return

The final all_reduce should only be done by the process that exits first. How do I ensure that?

You can guard that with a flag. For the first process that exits the loop (call it X), its signal tensor always equals world_size. The other processes see signal.item() < world_size, because their in-loop all_reduce pairs up with the after-loop all_reduce on X, which contributes 0.
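To make the pairing easier to follow, here is a toy pure-Python model of it. This is not torch.distributed: it only counts which all_reduce calls line up, and the simulate function and its arguments are made up for illustration. Each round stands for one matched all_reduce(SUM); a rank still in its loop contributes 1, and a rank making the after-loop call contributes 0.

```python
def simulate(iters_per_rank, guarded=True):
    """Model matched all_reduce(SUM) calls in lockstep.

    Returns (signal, hung): signal[r] is the last in-loop sum rank r saw, and
    hung is True if some rank issues an all_reduce that no other rank joins.
    """
    world_size = len(iters_per_rank)
    remaining = list(iters_per_rank)
    signal = [None] * world_size
    # "loop":  next call is the in-loop all_reduce (contributes 1)
    # "final": next call is the after-loop all_reduce (contributes 0)
    # "done":  no more all_reduce calls, waiting at the barrier
    state = ["loop" if n > 0 else "final" for n in iters_per_rank]
    while any(s != "done" for s in state):
        if any(s == "done" for s in state):
            return signal, True  # an unmatched all_reduce: the caller would wait forever
        total = sum(1 for s in state if s == "loop")
        for r in range(world_size):
            if state[r] == "final":
                state[r] = "done"
                continue
            signal[r] = total
            if total < world_size:
                # break out of the loop; without the guard, post one more all_reduce
                state[r] = "done" if guarded else "final"
            else:
                remaining[r] -= 1
                if remaining[r] == 0:
                    state[r] = "final"  # ran out of data without breaking
    return signal, False


print(simulate([10, 15]))                 # ([2, 1], False)
print(simulate([10, 15], guarded=False))  # ([2, 1], True)
```

With the guard, rank 0 keeps signal == world_size and posts the extra call, while rank 1 sees 1 and skips it. Without the guard, rank 1 posts a call that nothing matches in this model.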

I tried to do that, but it doesn’t work. You said that the other processes would see signal.item() < world_size, but the first process shows the signal to be 0, so it never blocks on all_reduce. Can you tell me where I’m going wrong?

    if self.id == 0:
        ep = 10
        sl = 0.1
    elif self.id == 1:
        ep = 15
        sl = 0.2

    for i in range(0, ep):

        # if i == 10:
        #     break

        signal = torch.tensor([1]).to(self.device)
        work = dist.all_reduce(signal, async_op=True)

        #self._log.info(self.id, "image.. " + str(i))
        image = torch.zeros((1,2,3,256,384)).cuda(self.id)
        input = {"rgb": image}
        output = self.model(input)
        loss = torch.sum(output["output"][-1])
        
        work.wait()
        self._log.info(self.id, "SIG: " + str(signal.item()))
        if signal.item() < self.cfg.mp.workers:
            self._log.info(self.id, "EXIT: " + str(signal.item()))
            break

        self.optimizer.zero_grad()
        #self._log.info(self.id, "backward.. " + str(i))
        loss.backward()
        self.optimizer.step()
        #print("step: " + str(self.id) + " " + str(i))
        self._log.info(self.id, "done " + str(i))

        time.sleep(sl)

    self._log.info(self.id, "reduce.. ")
    signal = torch.tensor([0]).to(self.device)
    self._log.info(self.id, "FIANL SIGNAL: " + str(signal.item()))
    if signal.item() >= self.cfg.mp.workers:
        dist.all_reduce(signal) # NOT BLOCKING HERE

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

The self-contained code below works for me.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range((rank + 1) * 10):
        signal = torch.tensor([1])
        work = dist.all_reduce(signal, async_op=True)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        work.wait()
        if signal.item() < world_size:
            break
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))

    dist.barrier()
    print(f"{rank} done")


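def main():
    import os
    # init_process_group defaults to env:// rendezvous, which needs these set.
    # The values are assumed defaults for a single-machine run.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)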

if __name__=="__main__":
    main()

Thanks, I finally made it work with your help!


This doesn’t work anymore if I use BatchNorm on GPU. If I disable track_running_stats, it works.

It hangs on these lines:

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))

Yep, because with BatchNorm, DDP also communicates during the forward pass (it broadcasts the module buffers, such as the running stats). In that case, you need to move the signal check before the forward pass, but it will be slower. The following code should work.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.BatchNorm1d(10),
        nn.Linear(10, 10),
    ).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range((rank + 1) * 10):
        signal = torch.tensor([1])
        dist.all_reduce(signal)
        if signal.item() < world_size:
            break
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))

    dist.barrier()
    print(f"{rank} done")


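def main():
    import os
    # init_process_group defaults to env:// rendezvous, which needs these set.
    # The values are assumed defaults for a single-machine run.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)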

if __name__=="__main__":
    main()


It’s much slower. GPU utilization is at 50 percent now. Is there any alternative solution to end the process?

OK, I stand corrected. The slowdown comes from SyncBatchNorm.

Is there any alternative solution to end the process?

We are working on a more elegant and efficient solution. See the tracking issue: https://github.com/pytorch/pytorch/issues/38174

To unblock:

  1. If you know the number of inputs before entering the for loop, you can use an all_reduce to get the min of that number across all processes. Then let all processes loop just that min number of times.
  2. If this number cannot be acquired or is too expensive to acquire, you can first retrieve one input/batch/sample from the source/dataloader before entering the loop; say it’s called last_input. Then, in each iteration:
    1. retrieve a new input (call it curr_input) from the dataloader
    2. launch an async all_reduce to check whether all processes have a curr_input
    3. run forward-backward-opt on last_input
    4. wait on the all_reduce for curr_input
    5. if all processes have a curr_input, assign curr_input to last_input and proceed; otherwise break.
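Both workarounds could be sketched roughly as below. This is only an illustration under some assumptions: the process group is already initialized, every rank has at least one batch (a rank with no data at all is not handled), and train_step is a hypothetical callable that runs forward-backward-opt on one batch. The function names common_num_batches and train_with_lookahead are made up here.

```python
import torch
import torch.distributed as dist

_END = object()  # sentinel marking an exhausted iterator


def common_num_batches(local_num_batches, device="cpu"):
    """Option 1: the smallest batch count across all ranks."""
    count = torch.tensor([local_num_batches], dtype=torch.int64, device=device)
    dist.all_reduce(count, op=dist.ReduceOp.MIN)
    return int(count.item())


def train_with_lookahead(batches, train_step, device="cpu"):
    """Option 2: train on last_input while checking that every rank has a next one.

    Returns the number of training steps this rank ran.
    """
    world_size = dist.get_world_size()
    it = iter(batches)
    last_input = next(it, _END)
    if last_input is _END:
        raise ValueError("this sketch assumes at least one batch per rank")
    steps = 0
    while True:
        curr_input = next(it, _END)                       # step 1
        have_next = torch.tensor([int(curr_input is not _END)],
                                 dtype=torch.int64, device=device)
        work = dist.all_reduce(have_next, async_op=True)  # step 2
        train_step(last_input)                            # step 3
        steps += 1
        work.wait()                                       # step 4
        if have_next.item() < world_size:                 # step 5
            break
        last_input = curr_input
    return steps
```

With the loop lengths from the earlier posts, rank 0 would call common_num_batches(10) and rank 1 common_num_batches(15), and both should get 10 back. In the lookahead version every rank runs the same number of steps, because a batch is only trained on after all ranks have confirmed they hold it. With the NCCL backend the tensors would need to live on the rank’s GPU, hence the device argument.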

Just want to confirm: is it SyncBatchNorm or BatchNorm?

Okay, it’s slower when using SyncBatchNorm, but it is also slightly slower when using no BatchNorm at all.

But what if, for some reason, one GPU device has no input at all, or we skip all of the inputs for one of the devices?

How can we use last_input for the all_reduce operation?

It doesn’t seem to work on my Ampere-series GPU anymore with PyTorch 1.11 and CUDA 11.5. This chunk of code blocks on signal.item() when I run with 2 or more GPUs:

        for train_datum in self.train_dataset.iterate():
            if early_stop:
                break

            # If any dataloader in any process stops, we get a signal
            if self.cfg.distributed.enabled:
                signal = torch.tensor([1]).to(self.device)
                torch.distributed.all_reduce(signal)
                if signal.item() < self.cfg.distributed.workers: # BLOCKS ON SIGNAL.ITEM()
                    self.logger.info(self.id, "Termination Signal")
                    early_stop = True
                    continue