Distributed RL training

Hi Everyone, first post :slight_smile:

I am working on a robotics project where an RL agent (DDPG/SAC/…) interacts with many environments running on multiple processors. The environments' responses are collected batch-wise in each iteration of the program's main loop. In each iteration, after processing the environments' responses, agent.train() is called once.
My problem is that if I increase the number of environments, the agent collects more experience generated by actions from a more novice agent than with fewer environments, i.e. the ratio of environment steps per unit of training increases. A remedy would be to use more GPUs per unit of training, thereby increasing my sweeps over the replay buffer and bringing that ratio back down. In particular, I want to spawn n workers, each with its own slice of the replay-buffer sample, which then collectively update either my loss value or my gradients (I think this depends on the type of optimizer).
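To make the ratio concrete, here is a tiny back-of-the-envelope helper (the function name and parameters are made up for illustration, not part of any library):

```python
def steps_per_update(num_envs, train_calls_per_loop=1):
    """Environment steps collected per gradient update.

    Each main-loop iteration steps every environment once and calls
    agent.train() `train_calls_per_loop` times.
    """
    return num_envs / train_calls_per_loop

# Doubling the environments doubles the steps-per-update ratio...
assert steps_per_update(8) == 8.0
assert steps_per_update(16) == 16.0
# ...unless training is also scaled up (e.g. more workers per loop),
# which restores the original ratio.
assert steps_per_update(16, train_calls_per_loop=2) == 8.0
```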
Is it possible to use the DistributedDataParallel (DDP) class with reinforcement learning, where the multiprocessing workers are needed only for training, after interaction with the environments?
Do I need to choose a specific optimizer (e.g. I've seen SGD used a lot in conjunction with DDP), or can I choose any optimizer?

Furthermore, I'm having trouble with the various tutorials online, which state that I should use the DistributedSampler class. How can I integrate a normal RL replay buffer with it? Is it possible to update the DistributedSampler's underlying dataset at runtime, or can I use DDP without DistributedSampler?
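For context, the kind of per-rank slicing I have in mind is sketched below (a pure-Python stand-in, assuming every worker holds the same buffer contents and a shared sampling seed; no DistributedSampler involved — the class and method names are mine, not from any library):

```python
import random

class ReplayBuffer:
    def __init__(self, capacity=10000, seed=0):
        self.storage = []
        self.capacity = capacity
        self.rng = random.Random(seed)

    def add(self, transition):
        # Drop the oldest transition once capacity is reached.
        if len(self.storage) >= self.capacity:
            self.storage.pop(0)
        self.storage.append(transition)

    def sample_shard(self, batch_size, rank, world_size):
        """Sample a global batch, then return this worker's slice.

        Every rank draws the same batch (shared seed), then keeps
        every `world_size`-th element starting at index `rank`, so
        the shards are disjoint and cover the whole batch.
        """
        batch = self.rng.sample(self.storage, batch_size)
        return batch[rank::world_size]
```

Usage: simulate two ranks by giving each its own buffer object with the same seed and the same transitions; their shards are then disjoint halves of one global batch.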

All the best

Edit: the weight update itself is not affected by distributed gradient synchronization: Comparison Data Parallel Distributed data parallel - #4 by mrshenli

Hey @Thunfisch, I am not familiar with RL, and cannot comment on the optimizer part, but I can share some pointers on using DDP/RPC for RL.

also cc @VitalyFedyunin for DataLoader/Sampler

DDP applications usually have local loss tensors, and only gradients are synchronized during the backward pass. If you need to compute a global loss, you can combine DDP with RPC, using DDP to synchronize gradients and using RPC to calculate global loss and properly propagate gradients back to each DDP instance. Here are pointers:

  1. RPC RL example: Implementing Batch RPC Processing Using Asynchronous Executions — PyTorch Tutorials 1.7.1 documentation
  2. Combine DDP with RPC: Combining Distributed DataParallel with Distributed RPC Framework — PyTorch Tutorials 1.7.1 documentation
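To illustrate what "only gradients are synchronized" means, here is a pure-Python toy (no torch; the all-reduce is simulated with a plain mean, and all names are illustrative): two workers each compute a local loss gradient on their own data shard, the gradients are averaged, and every worker applies the same update, so the replicas stay bitwise identical.

```python
def local_gradient(w, shard):
    # d/dw of the local loss mean((w*x - y)^2) over this worker's shard
    return sum(2 * x * (w * x - y) for x, y in shard) / len(shard)

def ddp_step(weights, shards, lr=0.1):
    """One DDP-style step: local gradients, a simulated all-reduce
    (mean of the gradients), then identical local optimizer updates."""
    grads = [local_gradient(w, s) for w, s in zip(weights, shards)]
    avg = sum(grads) / len(grads)  # stand-in for all_reduce
    return [w - lr * avg for w in weights]

# Data follows y = 2x, split into two per-worker shards.
shards = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
weights = [0.0, 0.0]  # one replica per worker
for _ in range(50):
    weights = ddp_step(weights, shards)

# The key DDP property: all replicas remain exactly in sync...
assert weights[0] == weights[1]
# ...and converge toward the shared solution w = 2.
assert abs(weights[0] - 2.0) < 1e-3
```

The losses themselves never leave the workers; only the averaged gradient is shared, which is why a truly global loss requires the extra RPC machinery from the tutorials above.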