I am trying to implement a very basic version of “Asynchronous one-step Q-learning” (page 3). To do that, I need to train a neural network simultaneously in several processes (or threads, I am not sure yet).
The different processes need to use the same optimizer. There is a local network and a target network that gets updated every N steps (in my small code it gets updated but is not used, for simplicity's sake).
The overall system uses the Hogwild! method, so in theory there is not much need for locking, from what I understand.
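To make the Hogwild! assumption concrete: as far as I understand, `share_memory()` places the parameter tensors in shared memory, whereas the `.grad` tensors created by `backward()` are allocated separately by whichever process runs it. Below is a minimal single-process sketch of that check. The layer sizes are arbitrary and the variable names are only illustrative; `is_shared()` is the standard tensor method.

```python
import torch
import torch.nn as nn

# Arbitrary small layer, standing in for the online network.
net = nn.Linear(10, 4)
net.share_memory()  # moves the parameter storage into shared memory

params_shared = all(p.is_shared() for p in net.parameters())

# Run one backward pass so that .grad gets allocated.
loss = net(torch.ones(1, 10)).sum()
loss.backward()

grads_shared = all(p.grad.is_shared() for p in net.parameters())

print("parameters shared:", params_shared)
print("gradients shared:", grads_shared)
```

If this holds, each worker would accumulate its own gradients and only the parameter updates would land in the shared tensors, which seems to be what a lock-free scheme relies on.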
This is my small snippet to try to understand how I can implement these mechanics:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

INPUT_DIMENSION = 10
OUTPUT_DIMENSION = 4
OPTIMIZER_STEP_FREQUENCY = 10
UPDATE_TARGET_NETWORK_FREQUENCY = 20


class Worker:
    def __init__(self, online_network, target_network, optimizer):
        self.optimizer = optimizer
        self.online_network = online_network
        self.target_network = target_network

    def run(self, global_step, num_steps):
        for i in range(num_steps):
            with global_step.get_lock():
                global_step.value += 1
            data = torch.ones((1, INPUT_DIMENSION))
            prediction = self.online_network(data)
            target = -torch.ones((1, OUTPUT_DIMENSION))
            loss = nn.MSELoss()(prediction, target)
            loss.backward()
            if i % OPTIMIZER_STEP_FREQUENCY == 0:
                self.optimizer_step()
            if i % UPDATE_TARGET_NETWORK_FREQUENCY == 0:
                self.update_target_network()

    def optimizer_step(self):
        self.optimizer.step()
        self.optimizer.zero_grad()

    def update_target_network(self):
        self.target_network.load_state_dict(self.online_network.state_dict())


if __name__ == '__main__':
    online_network = nn.Linear(INPUT_DIMENSION, OUTPUT_DIMENSION)
    online_network.share_memory()
    target_network = nn.Linear(INPUT_DIMENSION, OUTPUT_DIMENSION)
    target_network.load_state_dict(online_network.state_dict())
    target_network.share_memory().eval()
    global_step = mp.Value('i', 0)
    optimizer = optim.SGD(online_network.parameters(), lr=0.005)
    num_processes = 4
    num_steps_per_worker = 30
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=Worker(online_network, target_network, optimizer).run,
                       args=(global_step, num_steps_per_worker,))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    print(global_step.value)
    print(online_network(torch.ones((1, INPUT_DIMENSION))).tolist())
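As a sanity check on the update schedule in the snippet, here is a small pure-Python sketch of the local iterations at which each worker would call `optimizer_step()` and `update_target_network()`. The two constants and the modulo conditions are copied from the code above; the `schedule` helper is hypothetical and exists only for this illustration.

```python
OPTIMIZER_STEP_FREQUENCY = 10
UPDATE_TARGET_NETWORK_FREQUENCY = 20


def schedule(num_steps):
    """Return the local iterations at which each update would fire."""
    optimizer_steps = [i for i in range(num_steps)
                       if i % OPTIMIZER_STEP_FREQUENCY == 0]
    target_updates = [i for i in range(num_steps)
                      if i % UPDATE_TARGET_NETWORK_FREQUENCY == 0]
    return optimizer_steps, target_updates


optimizer_steps, target_updates = schedule(30)
print(optimizer_steps)  # [0, 10, 20]
print(target_updates)   # [0, 20]
```

Two things stand out with 30 steps per worker: the first optimizer step happens at `i == 0`, after a single backward pass, and the gradients accumulated during iterations 21 to 29 are never applied. Both conditions are also keyed on the per-worker counter `i` rather than on `global_step`. If the intent is to step after every N accumulated gradients, a condition such as `(i + 1) % N == 0` might be closer, though that is only a guess at the intended behavior.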
I wanted to know if my way of handling the different variables and networks is okay. I am new to multiprocessing and I am not sure if what I am doing is “good practice”.
Also, I saw in some repositories that a custom class is used to “wrap” the optimizer so that it can be shared. Should I use such a class for my application? Is there a better way to do this in newer versions of PyTorch?
Thanks!