Multiprocessing - Barrier Blocks all Processes?

Actually, even when I simplified my code down to the basics, just launching two processes for two GPUs, as soon as the first process hits the barrier, it stops the other process.

But if I remove the loss.backward() and the optimizer stuff, it works.

You don’t need to share the private code; a minimal repro would help.

But if I remove the loss.backward() and the optimizer stuff, it works.

In this case I assume you are using DistributedDataParallel (DDP)? DDP calls allreduce internally. If the first process is in the same process group but is not running backward on its DDP model, the other processes will hang on backward, because they need the first process to join the allreduce.

One requirement of collective communications is that all members need to call the same collective API in the same order.

If you need some side channel to do the barrier, you can create a new process group using the new_group API, and then call barrier on that.
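
A minimal sketch of that, assuming the default process group is already initialized on every rank:

import torch.distributed as dist

# every rank must call new_group, even ranks that are not members of it
side_group = dist.new_group(ranks=[0, 1])

# ... training with DDP, which issues its allreduce calls on the default group ...

# side-channel barrier that does not interleave with the default group's collectives
dist.barrier(group=side_group)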

Wait… let me get this straight:

So let's say one process calls dist.barrier(); then the other process will be stuck on loss.backward()?

So let's say one process calls dist.barrier(); then the other process will be stuck on loss.backward()?

Yep, that’s possible if they are in the same process group and you are using DistributedDataParallel. That’s the contract of collective communications. If process X is calling allreduce -> allreduce -> barrier, and process Y is only calling barrier, the first allreduce will block on process X, because process Y never joined that collective.

Yes, they are in the same process group, and yes, I am using DistributedDataParallel.

Here is a quick code snippet and the output:

    if self.id == 0:
        ep = 10
    elif self.id == 1:
        ep = 20

    for i in range(0, ep):

        self._log.info(self.id, "image.. " + str(i))
        image = torch.zeros((1,2,3,256,384)).cuda(self.id)
        input = {"rgb": image}
        self._log.info(self.id, "model.. " + str(i))
        output = self.model(input)
        loss = torch.sum(output["output"][-1])
        time.sleep(0.1)

        self._log.info(self.id, "zerograd.. " + str(i))
        self.optimizer.zero_grad()
        self._log.info(self.id, "backward.. " + str(i))
        loss.backward()
        self._log.info(self.id, "step.. " + str(i))
        self.optimizer.step()
        #print("step: " + str(self.id) + " " + str(i))
        self._log.info(self.id, "done " + str(i))

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

Here, I am running two processes. One finishes faster than the other and is supposed to wait at the barrier for the slower one. But instead, as soon as process 0 calls barrier, it blocks process 1 completely.

Here is the output

[INFO]  [0] zerograd.. 9
[INFO]  [0] backward.. 9
[INFO]  [0] step.. 9
[INFO]  [1] model.. 10
[INFO]  [0] done 9
[INFO]  [0] barrier.. 
[INFO]  [1] zerograd.. 10
[INFO]  [1] backward.. 10
[INFO]  [1] step.. 10
[INFO]  [1] done 10
[INFO]  [1] image.. 11

As you can see, process 1 is stuck right after logging image.. 11.

If I change the code so both processes run the same number of epochs in the for loop, then it works.

I see, that’s the uneven DDP input batches issue. See the previous tracking issue here. We have just finished the design discussion on this and plan to create a ddp.join() context manager to solve this problem. @rvarm1 will post a design RFC later.
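
Once that lands, usage would look roughly like the sketch below (ddp_model, loader, and optimizer are placeholders, and the exact API may differ from the final design):

with ddp_model.join():
    for batch in loader:
        loss = ddp_model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()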

As a short-term unblock, you can either make sure that all DDP instances run through the same number of batches, or detect the early exit proactively in the application, something like:

for batch in get_batch():
    signal = torch.tensor([1])
    work = all_reduce(signal, async_op=True)
    loss = model(batch).sum()
    work.wait()
    if signal.item() < world_size:
        break
    loss.backward()


all_reduce(torch.tensor([0]))

Oh wow… it took a lot of effort to finally find the right GitHub issue. Thanks.

Can you explain more about what the all_reduce operation is doing? Why are you applying it to a torch.tensor([1])?

Here is the allreduce API doc. By default, it sums the provided tensor across all processes. The code snippet above uses allreduce to detect whether any process has finished processing all of its inputs. Processes that are still in the loop contribute 1 to the sum via all_reduce(torch.tensor([1])); a process that has exited contributes 0. So as long as the allreduce sum is smaller than world_size, it means some process has exited the loop, and in that case no other process should launch the backward, otherwise the allreduce inside backward would hang.

The snippet above is basically detecting that condition and then skipping the remaining batches.
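
As a tiny illustration of the summing behavior (a sketch, assuming torch.distributed is imported as dist and an already initialized group with two processes):

# each process that is still in the loop contributes 1
signal = torch.tensor([1])
dist.all_reduce(signal)  # the default reduce op is SUM across all processes
# if both processes are still looping: signal.item() == 2 == world_size
# if one process has exited and contributed 0 instead: signal.item() == 1 < world_size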

Can I check something. Instead of using this approach, can I create a shared-memory torch tensor? As soon as one process finishes its batches, it writes to the tensor, and the other processes read it so they can break out of their loops.

Sure, that will also work. You can use a torch.multiprocessing.Queue to do that, but this solution will not scale across multiple machines.
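
For a single machine, a minimal sketch of the shared-flag idea (DDP setup omitted; names are placeholders). Note that, unlike the all_reduce signal above, this check is not synchronized with DDP's backward, so it is easier to get wrong:

import torch
import torch.multiprocessing as mp

def worker(rank, done_flag, num_batches):
    for _ in range(num_batches[rank]):
        if done_flag.item() > 0:   # some other process has run out of batches
            break                  # stop before launching another backward
        # ... forward, loss.backward(), optimizer.step() ...
    done_flag.fill_(1)             # announce that this process is done

if __name__ == "__main__":
    flag = torch.zeros(1).share_memory_()  # visible to all spawned processes on this machine
    mp.spawn(worker, args=(flag, [10, 20]), nprocs=2, join=True)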

Final question: it's not clear to me why you have another all_reduce at the end.

Hmm… I tried this method and it works, but now the second process hangs at dist.barrier().

Final question: it's not clear to me why you have another all_reduce at the end.

That’s to indicate that one process has exited the loop; otherwise the sum would never be smaller than world_size. It needs to be guarded by a flag, because only the process that exits the loop first (the one that ran out of batches) needs to do it. Something like:

for batch in get_batch():
    ....

if signal.item() >= world_size:
    all_reduce(torch.tensor([0]))

I tried implementing it as you suggested, but my second process gets stuck on the barrier:

    if self.id == 0:
        ep = 10
        sl = 0.1
    elif self.id == 1:
        ep = 15
        sl = 0.2

    for i in range(0, ep):

        # if i == 10:
        #     break

        signal = torch.tensor([1]).to(self.device)
        work = dist.all_reduce(signal, async_op=True)

        self._log.info(self.id, "image.. " + str(i))
        image = torch.zeros((1,2,3,256,384)).cuda(self.id)
        input = {"rgb": image}
        output = self.model(input)
        loss = torch.sum(output["output"][-1])
        
        work.wait()
        if signal.item() < self.cfg.mp.workers:
            break

        self.optimizer.zero_grad()
        self._log.info(self.id, "backward.. " + str(i))
        loss.backward()
        self.optimizer.step()
        print("step: " + str(self.id) + " " + str(i))
        self._log.info(self.id, "done " + str(i))

        time.sleep(sl)

    self._log.info(self.id, "reduce.. ")
    signal = torch.tensor([0]).to(self.device)
    dist.all_reduce(signal)

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

    return

The final all_reduce should only be done by the process that exits first. How do I ensure that?

You can guard that with a flag. For the first process that exits (call it X), its signal tensor will always equal world_size. The other processes will see signal.item() < world_size, because their in-loop all_reduce joins with the after-loop all_reduce on X.
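
In code, that means checking the signal tensor left over from the last in-loop all_reduce, not a freshly created one; a sketch (get_batch, model, optimizer, and world_size are placeholders as above):

for batch in get_batch():
    signal = torch.tensor([1])
    work = dist.all_reduce(signal, async_op=True)
    loss = model(batch).sum()
    work.wait()
    if signal.item() < world_size:
        break
    loss.backward()
    optimizer.step()

# do NOT overwrite signal here; only the process that finished all
# of its batches will still have signal.item() == world_size
if signal.item() >= world_size:
    dist.all_reduce(torch.tensor([0]))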

I tried to do that but it doesn't work. You said that other processes would see signal.item() < world_size, but the first process shows the signal to be 0, so it never blocks on all_reduce. Can you tell me where I'm going wrong?

    if self.id == 0:
        ep = 10
        sl = 0.1
    elif self.id == 1:
        ep = 15
        sl = 0.2

    for i in range(0, ep):

        # if i == 10:
        #     break

        signal = torch.tensor([1]).to(self.device)
        work = dist.all_reduce(signal, async_op=True)

        #self._log.info(self.id, "image.. " + str(i))
        image = torch.zeros((1,2,3,256,384)).cuda(self.id)
        input = {"rgb": image}
        output = self.model(input)
        loss = torch.sum(output["output"][-1])
        
        work.wait()
        self._log.info(self.id, "SIG: " + str(signal.item()))
        if signal.item() < self.cfg.mp.workers:
            self._log.info(self.id, "EXIT: " + str(signal.item()))
            break

        self.optimizer.zero_grad()
        #self._log.info(self.id, "backward.. " + str(i))
        loss.backward()
        self.optimizer.step()
        #print("step: " + str(self.id) + " " + str(i))
        self._log.info(self.id, "done " + str(i))

        time.sleep(sl)

    self._log.info(self.id, "reduce.. ")
    signal = torch.tensor([0]).to(self.device)
    self._log.info(self.id, "FIANL SIGNAL: " + str(signal.item()))
    if signal.item() >= self.cfg.mp.workers:
        dist.all_reduce(signal) # NOT BLOCKING HERE

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

The self-contained code below works for me.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # single-machine rendezvous address for init_process_group
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    for _ in range((rank + 1) * 10):
        signal = torch.tensor([1])
        work = dist.all_reduce(signal, async_op=True)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # backward pass
        work.wait()
        if signal.item() < world_size:
            break
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))

    dist.barrier()
    print(f"{rank} done")


def main():
    world_size = 2
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

Thanks, I finally made it work with your help!
