I have a toy example for DDP (DistributedDataParallel) with two processes, where one process runs more training iterations than the other:
import time

import torch
import torch.distributed as dist

# Each rank runs a different number of iterations at a different pace.
if self.id == 0:
    ep = 10
    sl = 0.1
elif self.id == 1:
    ep = 15
    sl = 0.2

for i in range(ep):
    self._log.info(self.id, "image.. " + str(i))
    image = torch.zeros((1, 2, 3, 256, 384))  # dummy all-zeros input batch
    inputs = {"rgb": image}

    self._log.info(self.id, "model.. " + str(i))
    output = self.model(inputs)  # self.model is wrapped in DistributedDataParallel
    loss = torch.sum(output["output"][-1])

    self._log.info(self.id, "zerograd.. " + str(i))
    self.optimizer.zero_grad()
    self._log.info(self.id, "backward.. " + str(i))
    loss.backward()  # DDP all-reduces gradients across ranks here
    self._log.info(self.id, "step.. " + str(i))
    self.optimizer.step()

    print("step: " + str(self.id) + " " + str(i))
    self._log.info(self.id, "done " + str(i))
    time.sleep(sl)

self._log.info(self.id, "barrier.. ")
dist.barrier()
But as soon as the first process (the one with fewer iterations) hits dist.barrier(), loss.backward() blocks in the second process, presumably because DDP all-reduces gradients inside backward() and the finished rank no longer participates in the collective. How do I get around this? I could force all processes to run the same number of iterations, but that's not what I want.
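
The closest thing I have found so far is DDP's join() context manager for training with uneven inputs (available since PyTorch 1.7). My understanding is that ranks which finish early keep "shadowing" the collective communication of the ranks still running, so backward() no longer hangs. A minimal sketch of how I think it would apply to my loop above (assuming self.model is the DistributedDataParallel-wrapped model):

# Sketch only: join() shadows DDP collectives for ranks that exit early.
with self.model.join():
    for i in range(ep):
        image = torch.zeros((1, 2, 3, 256, 384))
        output = self.model({"rgb": image})
        loss = torch.sum(output["output"][-1])

        self.optimizer.zero_grad()
        loss.backward()  # finished ranks shadow this all-reduce
        self.optimizer.step()
        time.sleep(sl)

dist.barrier()  # all ranks have exited join() before this point

Would this be the right approach, or is there a better way?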