Torch multiprocessing

Hi,

I have some RL code implemented and am using torch.multiprocessing to collect training samples. The code runs fine, but I want to run a separate function every n episodes to check the performance metrics of the currently trained model, and I cannot get this to work. The broad structure of the multiprocessing code is below; the part I am stuck on is at the end of the run method (see the code comment in the validation section):

```python
import torch.multiprocessing as mp

MAX_EPISODE = 1000  # placeholder values
VERBOSE_FREQ = 100
NUM_WORKERS = 4


class Worker(mp.Process):
    def __init__(self, global_ep, res_queue):
        super().__init__()  # mp.Process must be initialised before start()
        self.global_ep = global_ep  # shared mp.Value('i') episode counter
        self.res_queue = res_queue

    def record(self):
        # Increment the shared episode counter under its lock
        with self.global_ep.get_lock():
            self.global_ep.value += 1

    def run(self):
        while self.global_ep.value < MAX_EPISODE:
            # run training loop ...
            self.record()

            if self.global_ep.value % VERBOSE_FREQ == 0:
                # check model performance on validation data
                # run() is called in each worker process separately, but I
                # only want to run inference with the global trained model
                # once, after all workers have collected training samples.
                pass


if __name__ == "__main__":
    global_ep, res_queue = mp.Value('i', 0), mp.Queue()

    workers = [Worker(global_ep, res_queue) for _ in range(NUM_WORKERS)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```

If I put the validation code in the commented section, it runs once for each active worker process, since run is called separately in each process.
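One possible direction, sketched under assumptions: if the check only needs to run once per VERBOSE_FREQ episodes overall (not once per worker), the counter can be read back while the lock is still held, so each episode number is claimed by exactly one process and only the process that claims a multiple of VERBOSE_FREQ runs the check. The sketch uses the standard-library multiprocessing module, which torch.multiprocessing wraps with the same interface. MAX_EPISODE, VERBOSE_FREQ, worker() and validate() are hypothetical placeholders, not part of any library.

```python
import multiprocessing as mp  # torch.multiprocessing exposes the same API

MAX_EPISODE = 20  # hypothetical values for illustration
VERBOSE_FREQ = 5


def record(global_ep):
    """Increment the shared counter and return the new value.

    The value is read inside the lock, so every process gets a distinct
    episode number and exactly one process sees each multiple of VERBOSE_FREQ.
    """
    with global_ep.get_lock():
        global_ep.value += 1
        return global_ep.value


def validate(ep, n_validations):
    # Hypothetical stand-in for evaluating the shared model.
    with n_validations.get_lock():
        n_validations.value += 1
    print(f"validating after episode {ep}")


def worker(global_ep, n_validations):
    while True:
        ep = record(global_ep)  # claim the next episode number
        if ep > MAX_EPISODE:
            break
        # ... run one training episode here ...
        if ep % VERBOSE_FREQ == 0:
            validate(ep, n_validations)  # only the claiming process gets here


if __name__ == "__main__":
    global_ep = mp.Value('i', 0)
    n_validations = mp.Value('i', 0)
    procs = [mp.Process(target=worker, args=(global_ep, n_validations))
             for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # 20 episodes / VERBOSE_FREQ 5 -> 4 validations in total, not 4 per worker
    print("validations:", n_validations.value)
```

This does not pause the other workers while validation runs. If the check must see a model that no worker is currently updating, some extra synchronization (for example a multiprocessing.Barrier shared by all workers) would be needed, at the cost of every worker waiting at each checkpoint.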

Any ideas?

Thanks!