Multiprocessing and Reinforcement Learning

I am trying to implement a very basic version of “Asynchronous one-step Q-learning” (page 3). I therefore need to train a neural network simultaneously in several processes (or threads, I am not sure yet).

The different processes need to use the same optimizer. There is an online network and a target network that gets updated every N steps (in my small code it gets updated but not used, for simplicity's sake).

The overall system uses the Hogwild! method, so from what I understand there is, in theory, no need for much locking.

This is my small snippet to try to understand how I can implement these mechanics:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp


INPUT_DIMENSION = 10
OUTPUT_DIMENSION = 4

OPTIMIZER_STEP_FREQUENCY = 10
UPDATE_TARGET_NETWORK_FREQUENCY = 20


class Worker:
    def __init__(self, online_network, target_network, optimizer):
        self.optimizer = optimizer
        self.online_network = online_network
        self.target_network = target_network

    def run(self, global_step, num_steps):
        for i in range(num_steps):

            # Safely increment the shared step counter
            with global_step.get_lock():
                global_step.value += 1

            # Dummy data and target, just to produce a gradient
            data = torch.ones((1, INPUT_DIMENSION))
            prediction = self.online_network(data)
            target = -torch.ones((1, OUTPUT_DIMENSION))

            loss = nn.MSELoss()(prediction, target)
            loss.backward()

            # Apply the locally accumulated gradients to the shared parameters
            if i % OPTIMIZER_STEP_FREQUENCY == 0:
                self.optimizer_step()

            # Periodically sync the target network with the online network
            if i % UPDATE_TARGET_NETWORK_FREQUENCY == 0:
                self.update_target_network()

    def optimizer_step(self):
        self.optimizer.step()
        self.optimizer.zero_grad()

    def update_target_network(self):
        self.target_network.load_state_dict(self.online_network.state_dict())


if __name__ == '__main__':
    online_network = nn.Linear(INPUT_DIMENSION, OUTPUT_DIMENSION)
    # Put the parameters in shared memory so every worker updates the same weights
    online_network.share_memory()

    target_network = nn.Linear(INPUT_DIMENSION, OUTPUT_DIMENSION)
    target_network.load_state_dict(online_network.state_dict())
    # share_memory() returns the module itself, so eval() can be chained
    target_network.share_memory().eval()

    global_step = mp.Value('i', 0)
    optimizer = optim.SGD(online_network.parameters(), lr=0.005)

    num_processes = 4
    num_steps_per_worker = 30

    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=Worker(online_network, target_network, optimizer).run,
                       args=(global_step, num_steps_per_worker,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    print(global_step.value)
    print(online_network(torch.ones((1, INPUT_DIMENSION))).tolist())

I wanted to know if my way of handling the different variables and networks is okay. I am new to multiprocessing and I am not sure if what I am doing is “good practice”.

Also, I saw in some repositories that a custom class is used where the optimizer is “wrapped” so that it can be shared. Should I use such a class for my application? Is there a better way to do that in newer versions of PyTorch?

Thanks!

Hi,

In general the multiprocessing setup looks good to me.

It looks like in the link you mentioned, the authors have implemented a SharedAdam optimizer: https://github.com/g6ling/Reinforcement-Learning-Pytorch-Cartpole/blob/master/parallel/1-Async-Q-Learning/shared_adam.py to keep the optimizer state (e.g. Adam's running averages) in shared memory across processes. If your use case requires this too, it is probably a good approach, as PyTorch does not currently natively support sharing optimizer state across processes.
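For reference, such wrappers typically pre-build the optimizer state and move it into shared memory. Below is only a minimal sketch of that pattern (not the exact code from the linked file), and the way Adam stores its step counter differs between PyTorch versions, so it may need adjusting:

import torch
import torch.optim as optim


class SharedAdam(optim.Adam):
    """Adam whose state tensors are placed in shared memory, so all
    worker processes update the same running averages (sketch only)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        super().__init__(params, lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Pre-initialize the state Adam would otherwise create lazily,
                # then move every state tensor into shared memory
                state['step'] = torch.zeros(1)
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

You would then construct it once in the main process (e.g. optimizer = SharedAdam(online_network.parameters(), lr=0.005)) and pass it to the workers exactly like the SGD optimizer in your snippet.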


Thanks a lot for the feedback!

I see that you are quite experienced with distributed systems. Do you know to what extent the optimizers interfere with each other? E.g. if one worker's optimizer does a step and is then reset with zero_grad(), does that also cancel the gradients accumulated by the other workers? I saw a thread about that and I am quite curious about it.

Also, what do you think is the best way to retrieve the global network parameters:
using .share_memory() and reading them from time to time, or using a multiprocessing.Manager() like in this implementation?

# The dict lives in a separate manager process; workers access it through proxies
mp_manager = mp.Manager()
shared_state = mp_manager.dict()
shared_state["Q_network_state_dict"] = q_network.state_dict()
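For context, by the first option I mean something like this minimal sketch (shared_network and local_network are placeholder names I am making up, not from that implementation):

import torch.nn as nn

# Option 1 (sketch): the global network's parameters live in shared memory
# and each worker copies them into its own local network from time to time
shared_network = nn.Linear(INPUT_DIMENSION, OUTPUT_DIMENSION)
shared_network.share_memory()

local_network = nn.Linear(INPUT_DIMENSION, OUTPUT_DIMENSION)
local_network.load_state_dict(shared_network.state_dict())  # periodic sync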