DistributedDataParallel (DDP) ranks out of sync

Recently, I have been training a deep reinforcement learning policy network in a distributed manner with DDP. I am observing confusing behavior where different ranks are at different learning updates at the same time. This does not make sense to me, because I would think that the first call to loss.backward() would wait for all ranks to get through the first batch before returning and continuing on to the optimizer step.
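My mental model is roughly the toy sketch below (this is not my real code; it is just a stand-in model on the gloo backend, and it assumes the launcher sets MASTER_ADDR and MASTER_PORT): DDP hooks the gradient all-reduce into backward(), so I assumed every rank has to reach the same step before optimizer.step() can run.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def toy_train(rank, world_size, n_steps=10):
    # Toy CPU loop on the gloo backend. DDP all-reduces gradients inside
    # backward(), which is why I expected every rank to advance in lockstep.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(torch.nn.Linear(8, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    for step in range(n_steps):
        x, y = torch.randn(4, 8), torch.randn(4, 1)
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()    # gradient all-reduce happens in here
        optimizer.step()   # ...so I assumed all ranks reach the same `step` together
    dist.destroy_process_group()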

A snippet of my actual training loop is below.

with contextlib.ExitStack() as stack:
    if self.problem.agent_train:
        stack.enter_context(self.agent_model.join())
    if self.problem.detector_train:
        stack.enter_context(self.detection_model.join())
    if self.problem.attacker_train:
        stack.enter_context(self.attacker_model.join())

    # Main training loop
    train_iterations = [0 for _ in memories]
    last_checkpoints = [time.time() for _ in memories]

    error = torch.zeros(1)
    error_signal = None
    if self.world_size > 1:
        error_signal = dist.irecv(error, tag=1)

    self.learning_steps_completed = 0
    while self.learning_steps_completed < self.n_learning_steps:
        # Check for timeout
        if (time.time() - start) > self.timeout - 120:
            print("User defined timeout reached... terminating training")
            break

        # Stop learning when an error occurs in an agent
        if self.error_event.is_set():
            had_error = True
            if self.verbose > 0:
                print(f"[Rank {self.rank}] Error in agent", flush=True)
            for i in range(self.world_size):
                if i != self.rank:
                    dist.send(error, dst=i, tag=1)
            break

        if error_signal is not None and error_signal.is_completed():
            had_error = True
            print(f"[Rank {self.rank}] Ending early from error", flush=True)
            break

        # memories = [agent_memory, detection_memory, attacker_memory],
        # so this loop is just over the controller, detector, and attacker
        for i, memory in enumerate(memories):
            train_type, memory = memory
            if memory.has_batch():
                self.learning_steps_completed += 1

                # Sample the replay memory to get a batch of experiences
                # (defined in child class of Memory)
                batch, indexes, weights = memory.sample()

                # Compute algorithm-specific loss (defined in child class)
                loss, priorities = self.compute_loss(batch, weights, train_type)

                # Do optimization step and increment the number of training iterations
                self.optimizers[train_type].zero_grad()
                loss.backward()
                # dist.barrier()
                self.optimizers[train_type].step()
                train_iterations[i] += 1

                # Some periodic collective operations to collect intermediate results

In the above code, loss.backward() is hit at different values of learning_steps_completed on different ranks without the barrier, but not with it.

Could you explain what you mean by this? Are you printing self.learning_steps_completed after loss.backward() on each process and seeing different values across trainers, with some running ahead of others?

Note that if you are using GPUs here, GPU execution is asynchronous, so your training loop is essentially just enqueuing work to the GPU: when loss.backward() returns, it means the work has been enqueued but not necessarily finished. When you call dist.barrier(), it forces a CUDA synchronization and blocks the host until all GPU work is done.
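To illustrate (a standalone sketch, not taken from your code): host-side timing around backward() on a GPU mostly measures enqueue time, and only an explicit synchronization guarantees the work has actually finished.

import time
import torch

# Illustrative only: backward() returns once the CUDA kernels are enqueued;
# torch.cuda.synchronize() blocks the host until they have actually finished.
if torch.cuda.is_available():
    model = torch.nn.Linear(4096, 4096).cuda()
    x = torch.randn(512, 4096, device="cuda")

    t0 = time.time()
    model(x).sum().backward()
    t1 = time.time()              # work has been enqueued by this point

    torch.cuda.synchronize()      # wait for the GPU to actually finish
    t2 = time.time()

    print(f"enqueue took {t1 - t0:.4f}s, finishing took another {t2 - t1:.4f}s")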

Yes, I’m seeing different learning steps on different ranks at the same time. The maximum difference is about 40.
I am not using a GPU at all, just CPU. Do the non-DDP distributed calls (the send/irecv error signaling) cause problems for the join() context manager?

Could it be that this is just an output buffering issue? Could you add a sys.stdout.flush() line after your print statements and see if the output makes sense?
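Something along these lines (log_step is just a hypothetical helper, not part of your code):

import sys

def log_step(rank, step):
    # Hypothetical helper: flush right away so stdio buffering cannot reorder
    # or delay output coming from the different processes.
    print(f"[Rank {rank}] learning_steps_completed = {step}")
    sys.stdout.flush()  # equivalent to print(..., flush=True)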

Do the non-DDP distributed calls cause problems for the join() context manager?

Not sure I followed this; are you referring to lines like self.agent_model.join() in your code? If so, is this the DDP join() method, or is this join() doing something else?

No, it couldn’t be. I already flush the output.

Yes, it is the DDP join(). Because there can be three models training together, I have to enter each model's join() context.

If you are using DDP's join(), I’m assuming this is because you have uneven data across your trainers. If so, isn’t it expected to see a different number of training steps across different trainers?
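As a contrived sketch (none of these names are from your code): if the number of ready batches differs from rank to rank, then under join() a loop like the one below is perfectly legal, and the per-rank step counters will diverge, because the ranks that run out of batches only shadow the collectives of the ranks that are still training.

def train_rank(ddp_model, optimizer, batch_stream, loss_fn):
    # Hypothetical per-rank loop: batch_stream can be shorter on some ranks,
    # and ddp_model.join() lets the ranks that finish early keep matching the
    # collectives of the ranks that are still stepping.
    steps = 0
    with ddp_model.join():
        for x, y in batch_stream:      # length differs from rank to rank
            optimizer.zero_grad()
            loss_fn(ddp_model(x), y).backward()
            optimizer.step()
            steps += 1
    return steps                        # legitimately different across ranks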

Could you also provide a complete repro script that we can try on our end, along with the output you are seeing? It would be much easier for us to troubleshoot the issue that way.
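For reference, here is a skeleton of the kind of self-contained script I mean (gloo on CPU, two processes, deliberately uneven batch counts; all of it is made up rather than taken from your code), which you could adapt to mirror your setup:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    # CPU-only skeleton: gloo backend, two ranks, uneven batch counts.
    dist.init_process_group(
        "gloo",
        init_method="tcp://127.0.0.1:29511",
        rank=rank,
        world_size=world_size,
    )
    model = DDP(torch.nn.Linear(8, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    n_batches = 5 if rank == 0 else 45   # deliberately uneven inputs across ranks
    with model.join():
        for step in range(n_batches):
            optimizer.zero_grad()
            model(torch.randn(4, 8)).sum().backward()
            optimizer.step()
            print(f"[Rank {rank}] step {step + 1}", flush=True)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)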