Torch multiprocessing


I have some RL code that uses torch.multiprocessing to collect training samples in parallel. The code runs fine, but I want to run a separate function every n episodes to check performance metrics of the currently trained model, and I cannot work out how to do this. The broad structure of the multiprocessing code is below; the part I am stuck on is the last block in the run method (see the comment in the validation portion):

import torch.multiprocessing as mp

class Worker(mp.Process):
   def __init__(self, global_ep, res_queue):
      super().__init__()  # mp.Process must be initialised
      self.global_ep = global_ep
      self.res_queue = res_queue

   def record(self):
      with self.global_ep.get_lock():
         self.global_ep.value += 1  # global_ep is an mp.Value, so increment .value
   def run(self):
      while self.global_ep.value < MAX_EPISODE:
         # run training loop

         if self.global_ep.value % VERBOSE_FREQ == 0:
            # check model performance on validation data
            """ run is called on each processor separately; however,
                I only want to use the global trained model to run inference once,
                after all processors have collected training samples """
global_ep, res_queue = mp.Value('i', 0), mp.Queue()

workers = [Worker(global_ep, res_queue) for _ in range(mp.cpu_count())]
for w in workers:
   w.start()
for w in workers:
   w.join()

If I place the validation code at the comment, it runs once for each active process, since run is called separately in every process.
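To make the goal concrete, here is a minimal, self-contained sketch of the kind of gating I am after, written with the standard multiprocessing module (torch.multiprocessing mirrors its API). The names validate, N_ROUNDS, and validated_count are placeholders for this demo, not part of my real code. Each worker hits an mp.Barrier after its batch of episodes; barrier.wait() returns a distinct index to each party, so only the worker that receives 0 runs validation, while the rest wait:

```python
import multiprocessing as mp  # torch.multiprocessing mirrors this API

N_WORKERS = 4
VERBOSE_FREQ = 5   # episodes per worker between validation rounds
N_ROUNDS = 3       # placeholder: total validation rounds for the demo

def validate(validated_count):
    # placeholder for running the global model on validation data
    with validated_count.get_lock():
        validated_count.value += 1

def worker(barrier, validated_count):
    for _ in range(N_ROUNDS):
        for _ in range(VERBOSE_FREQ):
            pass  # collect training samples here
        # block until every worker has finished this round's episodes;
        # wait() returns a distinct index to each party, so exactly one
        # worker (the one that gets 0) runs validation
        if barrier.wait() == 0:
            validate(validated_count)
        barrier.wait()  # hold all workers until validation has finished

if __name__ == "__main__":
    barrier = mp.Barrier(N_WORKERS)
    validated_count = mp.Value('i', 0)
    workers = [mp.Process(target=worker, args=(barrier, validated_count))
               for _ in range(N_WORKERS)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(validated_count.value)  # one validation per round: prints 3
```

The second barrier.wait() is optional; it keeps the other workers paused until validation finishes, so training samples are not collected while inference runs.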

Any ideas?