Recently, I have been training a deep reinforcement learning policy network in a distributed manner with DDP. I am observing confusing behavior where different ranks are at different learning updates at the same time. This does not make sense to me, because I would expect the first call to loss.backward() to wait for all ranks to get through the first batch before returning and continuing on to the optimizer step.
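For context, my mental model of DDP is the usual pattern where backward() performs the gradient all-reduce, so no rank should be able to finish a given step's backward before every other rank has started it. A toy sketch of that expectation (dummy model and data, assuming the default process group is already initialized, e.g. via torchrun; this is not my actual code):

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def toy_train_step():
        # Dummy CPU model and batch, purely to illustrate the expectation.
        model = DDP(torch.nn.Linear(8, 1))
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        x, y = torch.randn(4, 8), torch.randn(4, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        # My expectation: this call all-reduces gradients across ranks, so it
        # should not return on any rank until every rank has reached it.
        loss.backward()
        optimizer.step()
        print(f"[Rank {dist.get_rank()}] finished one update", flush=True)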
A code snippet is below.
with contextlib.ExitStack() as stack:
    if self.problem.agent_train:
        stack.enter_context(self.agent_model.join())
    if self.problem.detector_train:
        stack.enter_context(self.detection_model.join())
    if self.problem.attacker_train:
        stack.enter_context(self.attacker_model.join())

    # Main training loop
    train_iterations = [0 for _ in memories]
    last_checkpoints = [time.time() for _ in memories]
    error = torch.zeros(1)
    error_signal = None
    if self.world_size > 1:
        error_signal = dist.irecv(error, tag=1)
    self.learning_steps_completed = 0
    while self.learning_steps_completed < self.n_learning_steps:
        # Check for timeout
        if (time.time() - start) > self.timeout - 120:
            print("User defined timeout reached... terminating training")
            break
        # Stop learning when an error occurs in an agent
        if self.error_event.is_set():
            had_error = True
            if self.verbose > 0:
                print(f"[Rank {self.rank}] Error in agent", flush=True)
            for i in range(self.world_size):
                if i != self.rank:
                    dist.send(error, dst=i, tag=1)
            break
        if error_signal is not None and error_signal.is_completed():
            had_error = True
            print(f"[Rank {self.rank}] Ending early from error", flush=True)
            break
        # memories = [agent_memory, detection_memory, attacker_memory],
        # so this loop is just over the controller, detector, and attacker
        for i, memory in enumerate(memories):
            train_type, memory = memory
            if memory.has_batch():
                self.learning_steps_completed += 1
                # Sample the replay memory to get a batch of experiences
                # (defined in a child class of Memory)
                batch, indexes, weights = memory.sample()
                # Compute algorithm-specific loss (defined in a child class)
                loss, priorities = self.compute_loss(batch, weights, train_type)
                # Do the optimization step and increment the number of training iterations
                self.optimizers[train_type].zero_grad()
                loss.backward()
                # dist.barrier()
                self.optimizers[train_type].step()
                train_iterations[i] += 1
        # Some periodic collective operations to collect intermediate results
In the code above, loss.backward() is reached at different values of learning_steps_completed across ranks when the dist.barrier() stays commented out, but the ranks stay in lockstep when it is uncommented.
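For reference, this is roughly how I observe the divergence: a timestamped log line immediately before the backward call (a sketch; the print is the only addition to the loop above):

    # Inserted just before loss.backward() in the loop above (sketch).
    print(
        f"[Rank {self.rank}] step {self.learning_steps_completed} "
        f"entering backward at t={time.time():.3f}",
        flush=True,
    )
    loss.backward()

Comparing these timestamps across ranks is how I see some ranks several steps ahead of others unless the barrier is in place.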