Torch multiprocessing


I have some RL code implemented and am using torch.multiprocessing to collect training samples. The code runs fine, but I want to run a separate function every n episodes to check the performance metrics of the currently trained model, and I cannot seem to get this to work. The broad structure of the multiprocessing code is as follows; the part I am stuck on is the last part of the run method (see the code comment in the validation portion):

```python
import torch.multiprocessing as mp

MAX_EPISODE = 1000   # placeholder values
VERBOSE_FREQ = 100

class Worker(mp.Process):
    def __init__(self, global_ep, res_queue):
        super().__init__()              # required when subclassing mp.Process
        self.global_ep = global_ep      # shared mp.Value('i') episode counter
        self.res_queue = res_queue

    def record(self):
        # increment the shared episode counter under its lock
        with self.global_ep.get_lock():
            self.global_ep.value += 1

    def run(self):
        while self.global_ep.value < MAX_EPISODE:
            # run training loop
            self.record()
            if self.global_ep.value % VERBOSE_FREQ == 0:
                # check model performance on validation data
                # run is called on each process separately, however,
                # I only want to use the global trained model to run inference
                # once, after all processes have collected training samples
                pass

if __name__ == "__main__":
    global_ep, res_queue = mp.Value('i', 0), mp.Queue()

    workers = [Worker(global_ep, res_queue) for _ in range(mp.cpu_count())]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```

If I run the validation code where the comment is, it runs once for each active process, since run is called separately in each process.

Any ideas?


Hi @ssha
Are you training the model in each process, or are the processes only collecting samples?
If they are only collecting samples, I would simply create an env instance in the main process and test your model on it. Is that something you have considered?
Have a look at torchrl’s trainer: we use a data collector running on parallel processes that sends batches of data to the main process, and the main process takes care of training the model. You can register a recording hook that evaluates your model on an environment in the main process, i.e., the remote processes are not involved.
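A minimal sketch of the "evaluate in the main process" idea, written with the standard-library `multiprocessing` module (`torch.multiprocessing` is a drop-in wrapper around it, so `import torch.multiprocessing as mp` should work the same). This does not use torchrl's API; `collect_episode`, `evaluate`, and the constants are hypothetical placeholders. Workers only collect episodes and push them onto a queue, while the main process consumes them, would train on them, and runs the evaluation every `EVAL_EVERY` episodes, so it happens once per interval no matter how many workers there are.

```python
import multiprocessing as mp  # torch.multiprocessing is a drop-in replacement
import random

MAX_EPISODE = 20   # hypothetical values for illustration
EVAL_EVERY = 5
NUM_WORKERS = 2

def collect_episode(worker_id):
    # Hypothetical stand-in for rolling out the current policy in an env.
    return {"worker": worker_id, "reward": random.random()}

def worker_loop(worker_id, global_ep, queue):
    # Workers only collect data; they never train or evaluate.
    while True:
        with global_ep.get_lock():
            if global_ep.value >= MAX_EPISODE:
                break
            global_ep.value += 1      # claim one episode slot
        queue.put(collect_episode(worker_id))

def evaluate(num_episodes_seen):
    # Hypothetical validation hook; runs only in the main process.
    return f"evaluated after {num_episodes_seen} episodes"

def main():
    global_ep, queue = mp.Value("i", 0), mp.Queue()
    workers = [mp.Process(target=worker_loop, args=(i, global_ep, queue))
               for i in range(NUM_WORKERS)]
    for w in workers:
        w.start()
    eval_log = []
    for n in range(1, MAX_EPISODE + 1):
        episode = queue.get()         # blocks until some worker sends data
        # a training step on `episode` would go here (hypothetical)
        if n % EVAL_EVERY == 0:
            eval_log.append(evaluate(n))
    for w in workers:
        w.join()                      # safe: the queue has been fully drained
    return eval_log

if __name__ == "__main__":
    print(main())
```

Because only the main process calls `evaluate`, the once-per-process duplication from the question goes away. The main process drains the queue before calling `join()`; joining a process that still has unconsumed items on a queue can deadlock.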
Hope it helps!

You could use a shared memory object (such as a one-dimensional tensor with one entry per worker process) that each worker updates with its number of completed episodes; basically, a per-worker episode counter.

Then have a separate process (or your main process) check that shared tensor and run your metrics every time the total number of completed episodes reaches another multiple of N.
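A sketch of the shared-counter approach, assuming the workers do the training and only need to report progress. It uses `multiprocessing.Array` as the shared memory object; with `torch.multiprocessing`, a tensor created with `torch.zeros(NUM_WORKERS, dtype=torch.int64).share_memory_()` could play the same role. `worker_loop`, `run_metrics`, and the constants are hypothetical. Each worker writes only its own slot, and the main process polls the sum and runs metrics once for every multiple of `N` crossed since its last poll.

```python
import multiprocessing as mp  # torch.multiprocessing can be swapped in unchanged
import time

NUM_WORKERS = 3            # hypothetical values for illustration
EPISODES_PER_WORKER = 10
N = 5                      # run metrics every N completed episodes (all workers combined)

def worker_loop(worker_id, counts):
    for _ in range(EPISODES_PER_WORKER):
        time.sleep(0.01)   # stand-in for one training episode
        # Each worker writes only its own slot, so no other process races on it.
        counts[worker_id] += 1

def checkpoints_crossed(prev_total, total, n):
    # Multiples of n reached since the last poll (several if workers were fast).
    return list(range((prev_total // n + 1) * n, total + 1, n))

def run_metrics(checkpoint):
    # Hypothetical validation hook; runs only in the polling (main) process.
    return f"metrics at {checkpoint} episodes"

def main():
    counts = mp.Array("i", NUM_WORKERS)  # one zero-initialised counter per worker
    workers = [mp.Process(target=worker_loop, args=(i, counts))
               for i in range(NUM_WORKERS)]
    for w in workers:
        w.start()
    results, prev_total = [], 0
    while True:
        # Check liveness before reading, so the final snapshot is complete.
        done = not any(w.is_alive() for w in workers)
        total = sum(counts[:])           # snapshot of all per-worker counters
        for checkpoint in checkpoints_crossed(prev_total, total, N):
            results.append(run_metrics(checkpoint))
        prev_total = total
        if done:
            break
        time.sleep(0.005)
    for w in workers:
        w.join()
    return results

if __name__ == "__main__":
    print(main())
```

Since this polls, metrics run at the first poll after a multiple of `N` is crossed rather than at that exact moment. If the workers should pause while the metrics run, an extra synchronization primitive (for example an `mp.Event`) would be needed.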