Accelerate backward in DDP REINFORCE

I am running on a single machine with one 24-core CPU and one A100 GPU, which all subprocesses share. Each subprocess runs the following (stripped-down) code:

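import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam

# general setup
dist.init_process_group('gloo', rank=rank, world_size=num_envs)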

# set up local model and optimizer
local_model = DDP(model.to(device), device_ids=[device])  # DDP expects the module already on `device`
optim = Adam(local_model.parameters(), lr=.001)

# save the trajectory from an episode:
# - action_lgprobs[i] is the log probability of the action sampled at step i
# - entropies[i] is the entropy of the distribution returned by the model at step i
# - rewards[i] is the reward assigned to the i'th action by the environment
# action_lgprobs and entropies are attached to the computational graph.
action_lgprobs, entropies, rewards = ...  # episode rollout omitted from this stripped-down snippet

returns = compute_returns(rewards, discount)

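# compute baselines: mean return across ranks
baselines = returns.clone()
dist.all_reduce(baselines)  # sums over ranks; assumes every rank's episode has the same length
baselines /= num_envs
advantages = returns - baselines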

# compute loss
policy_loss = -advantages @ action_lgprobs
# subtract the entropy bonus so that minimizing the loss raises entropy (exploration)
entropy_loss = -entropy_weight * entropies.sum()
loss = policy_loss + entropy_loss

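# update model
optim.zero_grad()
torch.cuda.synchronize()  # finish queued forward-pass kernels so they are not billed to backward
start = time.perf_counter()
loss.backward()  # slow!
torch.cuda.synchronize()  # wait for backward kernels to finish before stopping the clock
end = time.perf_counter()
print('BACKWARD TIME:', end - start)
optim.step()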
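Since compute_returns is not shown, here is a minimal pure-Python sketch of the quantities the snippet computes, under stated assumptions: compute_returns is taken to be the usual discounted return (G_i = r_i + discount * G_{i+1}), the baseline is assumed to be the mean return across ranks (returns summed with an all-reduce, then divided by num_envs), and the entropy term is subtracted, the common convention when a positive entropy_weight is meant as an exploration bonus. Plain lists stand in for tensors, ranks are simulated in one process, and mean_baseline and reinforce_loss are hypothetical helper names.

```python
from typing import List


def compute_returns(rewards: List[float], discount: float) -> List[float]:
    """Discounted return G_i = r_i + discount * G_{i+1}, computed back to front.

    Hypothetical stand-in for the post's compute_returns (which presumably
    returns a torch tensor rather than a list).
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for i in reversed(range(len(rewards))):
        running = rewards[i] + discount * running
        returns[i] = running
    return returns


def mean_baseline(per_rank_returns: List[List[float]]) -> List[float]:
    """Element-wise mean over ranks: the result of summing across ranks
    (as an all-reduce would) and dividing by num_envs, simulated in one
    process. Assumes all episodes have the same length."""
    num_envs = len(per_rank_returns)
    return [sum(col) / num_envs for col in zip(*per_rank_returns)]


def reinforce_loss(advantages: List[float], action_lgprobs: List[float],
                   entropies: List[float], entropy_weight: float) -> float:
    """-(sum_i A_i * log pi(a_i)) - entropy_weight * sum_i H_i"""
    policy_loss = -sum(a * lp for a, lp in zip(advantages, action_lgprobs))
    entropy_loss = -entropy_weight * sum(entropies)
    return policy_loss + entropy_loss


if __name__ == "__main__":
    # Two simulated ranks, three-step episodes, discount 0.5.
    r0 = compute_returns([1.0, 0.0, 2.0], discount=0.5)  # [1.5, 1.0, 2.0]
    r1 = compute_returns([0.0, 1.0, 0.0], discount=0.5)  # [0.5, 1.0, 0.0]
    baseline = mean_baseline([r0, r1])                    # [1.0, 1.0, 1.0]
    adv0 = [g - b for g, b in zip(r0, baseline)]          # [0.5, 0.0, 1.0]
    print(r0, r1, baseline, adv0)
```

Because the advantages are built only from rewards, they carry no autograd history; only action_lgprobs and entropies feed gradients into the model.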


I am finding that loss.backward() scales poorly with the number of subprocesses. For example, with one subprocess backward() takes ~3 s, while with 16 subprocesses it takes ~24 s on every one of them (each subprocess prints the same duration). Is there anything I can do to speed it up? Any ideas are appreciated.
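For scale, the timings above can be turned into aggregate throughput. This small sketch assumes each subprocess completes one backward() per episode and that the ~24 s runs concurrently on all 16 subprocesses (each prints the same duration); episodes_per_second is a hypothetical helper that ignores rollout and optimizer time.

```python
def episodes_per_second(num_procs: int, backward_seconds: float) -> float:
    """Aggregate backward passes per second if each of num_procs processes
    finishes one backward() every backward_seconds, all running concurrently."""
    return num_procs / backward_seconds


one = episodes_per_second(1, 3.0)     # about 0.333 backward passes per second
many = episodes_per_second(16, 24.0)  # about 0.667 backward passes per second
print(f"per-call slowdown: {24.0 / 3.0:.1f}x, throughput gain: {many / one:.1f}x")
# prints "per-call slowdown: 8.0x, throughput gain: 2.0x"
```

So each call becomes 8x slower while the total number of backward passes completed per second roughly doubles: the scaling is strongly sublinear, but not a net loss.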

Instead of using DDP, would it work to:

  1. aggregate the loss using dist.all_reduce(loss)
  2. only run loss.backward on rank 0
  3. broadcast the gradients to all other ranks

Would the reduce operation preserve the computational graphs from all the ranks? It seems much faster to do a single backward() call, if that is possible.
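Whatever the answer on graphs, the gradient arithmetic of steps 1-3 matches what DDP already computes: averaging each rank's gradient equals differentiating the averaged loss, because differentiation is linear (up to a 1/num_envs factor, since dist.all_reduce sums by default while DDP averages). The toy pure-Python check below uses a hypothetical scalar model w with per-rank loss (w*x - y)**2, unrelated to the post's policy network, and a central finite difference stands in for a single backward() over the aggregated loss.

```python
from typing import List, Tuple


def local_grad(w: float, x: float, y: float) -> float:
    """d/dw of one rank's loss (w*x - y)**2, i.e. what that rank's own
    backward() would compute from its own graph."""
    return 2.0 * x * (w * x - y)


def averaged_local_grads(w: float, data: List[Tuple[float, float]]) -> float:
    """DDP-style result: each rank differentiates its own loss, then the
    gradients are summed across ranks and divided by the world size."""
    return sum(local_grad(w, x, y) for x, y in data) / len(data)


def grad_of_averaged_loss(w: float, data: List[Tuple[float, float]],
                          eps: float = 1e-6) -> float:
    """Central finite difference of the mean loss over all ranks, standing in
    for one backward() on a loss aggregated across ranks."""
    def mean_loss(v: float) -> float:
        return sum((v * x - y) ** 2 for x, y in data) / len(data)
    return (mean_loss(w + eps) - mean_loss(w - eps)) / (2 * eps)


# One (x, y) pair per simulated rank.
data = [(1.0, 2.0), (2.0, 1.0), (3.0, 0.5)]
print(averaged_local_grads(0.5, data), grad_of_averaged_loss(0.5, data))
```

The check covers only the math; it says nothing about whether autograd graphs can cross process boundaries.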