Multiprocessing / DDP - Barrier Blocks loss.backward()

I have a toy example for DDP with two processes, where one process runs more “iterations” than the other:

    import time

    import torch
    import torch.distributed as dist

    # Uneven workloads: rank 0 runs 10 iterations, rank 1 runs 15.
    if self.id == 0:
        ep = 10
        sl = 0.1
    elif self.id == 1:
        ep = 15
        sl = 0.2

    for i in range(ep):

        # Uncommenting this caps rank 1 at 10 iterations too,
        # which equalizes the workloads and avoids the hang.
        # if i == 10:
        #     break

        self._log.info(self.id, f"image.. {i}")
        image = torch.zeros((1, 2, 3, 256, 384))
        inputs = {"rgb": image}  # renamed from `input` to avoid shadowing the builtin
        self._log.info(self.id, f"model.. {i}")
        output = self.model(inputs)  # self.model is assumed to be DDP-wrapped
        loss = torch.sum(output["output"][-1])

        self._log.info(self.id, f"zerograd.. {i}")
        self.optimizer.zero_grad()
        self._log.info(self.id, f"backward.. {i}")
        loss.backward()  # under DDP, gradients are all-reduced across ranks here
        self._log.info(self.id, f"step.. {i}")
        self.optimizer.step()
        print(f"step: {self.id} {i}")
        self._log.info(self.id, f"done {i}")

        time.sleep(sl)

    self._log.info(self.id, "barrier.. ")
    dist.barrier()

But as soon as the first process (the one with fewer iterations) reaches dist.barrier(), loss.backward() blocks in the second process! How do I get around this? I could force all processes to run the same number of iterations, but that's not something I want.
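A likely explanation: under DDP, loss.backward() is not purely local. It launches an all-reduce of the gradients, and a collective only completes once every rank in the process group has called it. When rank 0 leaves its loop after 10 iterations and enters dist.barrier() (itself a collective), rank 1's next backward() waits for a gradient all-reduce from rank 0 that never arrives, so both ranks hang. DistributedDataParallel offers a join() context manager for this uneven-inputs case (PyTorch 1.7+): a rank that runs out of iterations keeps shadowing the all-reduces of the ranks still training until all of them finish. The sketch below is a minimal illustration, not the asker's setup; it assumes the gloo backend on CPU, a hypothetical nn.Linear(4, 1) stand-in for the real model, hypothetical helper names (worker, main, _free_port), and the fork start method so it runs as a plain script on Linux.

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port for the rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank: int, world_size: int, port: int, queue) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    # gloo runs on CPU, so no GPU is needed for this sketch.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Hypothetical stand-in for the real model.
    model = DDP(nn.Linear(4, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Uneven workloads, as in the question: rank 0 runs 10 steps, rank 1 runs 15.
    num_iters = 10 if rank == 0 else 15
    steps = 0
    # A rank that leaves this loop early keeps shadowing the gradient
    # all-reduces of the ranks still training until every rank has finished.
    with model.join():
        for _ in range(num_iters):
            optimizer.zero_grad()
            loss = model(torch.zeros(2, 4)).sum()
            loss.backward()  # DDP all-reduces gradients across ranks here
            optimizer.step()
            steps += 1

    # Every rank has left the join() block, so the barrier can complete.
    dist.barrier()
    queue.put((rank, steps))
    dist.destroy_process_group()


def main(world_size: int = 2) -> dict:
    # "fork" keeps the sketch runnable as a plain script on Linux; with
    # "spawn", call main() only under an if __name__ == "__main__": guard.
    ctx = mp.get_context("fork")
    queue = ctx.SimpleQueue()
    mp.start_processes(worker, args=(world_size, _free_port(), queue),
                       nprocs=world_size, join=True, start_method="fork")
    # Map rank -> number of optimizer steps that rank completed.
    return dict(queue.get() for _ in range(world_size))


if __name__ == "__main__":
    print(main())
```

One design consequence worth knowing: with the default divide_by_initial_world_size=True, the gradients from rank 1's five extra iterations are still averaged over the original world size of 2, so those steps are effectively scaled down; join(divide_by_initial_world_size=False) averages over the ranks that have not yet joined instead.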

Is this the same issue as “Multiprocessing - Barrier Blocks all Processes”?
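Whatever the linked thread's exact setup, the rule at work in this one is that every rank must issue the same collectives (barrier, all-reduce, and the gradient all-reduce hidden inside DDP's backward()) in the same order. A fallback that needs no join() support is to agree collectively, once per iteration, on whether any rank is out of work, so that all ranks leave the loop on the same iteration. The catch is that it trims every rank to the shortest workload (both ranks stop after 10 steps here), which is the trade-off the question wants to avoid, so treat it mainly as an illustration of the matching rule. This is a minimal sketch assuming the gloo backend on CPU, a hypothetical nn.Linear(4, 1) stand-in model, hypothetical helper names (worker, main, _free_port), and the fork start method (Linux).

```python
import os
import socket

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _free_port() -> int:
    # Hypothetical helper: ask the OS for an unused TCP port for the rendezvous.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank: int, world_size: int, port: int, queue) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(nn.Linear(4, 1))  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    num_iters = 10 if rank == 0 else 15  # uneven workloads, as in the question

    steps = 0
    while True:
        # Each rank reports 1 if it still has work, else 0. MIN gives 0 as
        # soon as any rank is done, so every rank breaks on the same iteration
        # and the collectives below stay matched.
        has_work = torch.tensor([1 if steps < num_iters else 0])
        dist.all_reduce(has_work, op=dist.ReduceOp.MIN)
        if has_work.item() == 0:
            break
        optimizer.zero_grad()
        model(torch.zeros(2, 4)).sum().backward()  # gradient all-reduce
        optimizer.step()
        steps += 1

    dist.barrier()  # reached by both ranks, so it completes
    queue.put((rank, steps))
    dist.destroy_process_group()


def main(world_size: int = 2) -> dict:
    ctx = mp.get_context("fork")
    queue = ctx.SimpleQueue()
    mp.start_processes(worker, args=(world_size, _free_port(), queue),
                       nprocs=world_size, join=True, start_method="fork")
    # Map rank -> number of optimizer steps that rank completed.
    return dict(queue.get() for _ in range(world_size))


if __name__ == "__main__":
    print(main())
```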

BTW, for future questions on torch.distributed, could you please add a “distributed” tag? People working on the distributed package monitor that channel.