I tried to implement board game self-play data generation in parallel, using multiple CPUs to run self-play concurrently. In the parent process I created 4 NN models for 30 CPUs (1 model per 10 CPUs, plus 1 model to train); each model is on a different GPU. (The model is a 20-block ResNet-like architecture with batchnorm.) Pseudocode as follows:
```python
nnet = NN(gpu_num=0)
nnet1 = NN(gpu_num=1)
nnet2 = NN(gpu_num=2)
nnet3 = NN(gpu_num=3)

for i in range(num_iteration):
    nnet1.load_state_dict(nnet.state_dict())
    nnet2.load_state_dict(nnet.state_dict())
    nnet3.load_state_dict(nnet.state_dict())
    samples = parallel_self_play()
    nnet.train(samples)
```
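For reference, a minimal runnable sketch of the weight-syncing step, using a tiny placeholder net instead of my actual NN class (the `nn.Sequential` below is just a stand-in for the 20-block ResNet-like model):

```python
import torch
import torch.nn as nn

# Tiny placeholder standing in for the 20-block ResNet-like NN class
def make_net():
    return nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

nnet = make_net()   # the model that trains
nnet1 = make_net()  # a self-play copy (would live on another GPU)

# load_state_dict copies parameters AND batchnorm buffers
# (running_mean / running_var), so the self-play copies inherit the
# trained model's running statistics too.
nnet1.load_state_dict(nnet.state_dict())
```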
parallel_self_play() is implemented as follows:
```python
pool = mp.Pool(processes=num_cpu)  # 30
results = []
for i in range(self.args.numEps):
    # round-robin episodes across the three self-play models
    if i % 3 == 0:
        net = self.nnet1
    elif i % 3 == 1:
        net = self.nnet2
    else:
        net = self.nnet3
    results.append(pool.apply_async(AsyncSelfPlay, args=(net,)))
# get the actual samples from the AsyncResult objects, then return them
return [r.get() for r in results]
```
My code works perfectly fine with almost 100% GPU utilization throughout the first self-play phase (less than 10 minutes for an iteration), but after the first iteration (training), when I load the new weights into nnet1-3, GPU utilization never reaches 80% again (~30 min to 1 hour per iteration). I noticed a few things while messing around with my code:
The model includes batchnorm layers. Switching the model to train() mode -> training -> switching back to eval() causes the self-play (which uses the model's forward pass) to not use the GPU at all.
If I don't switch from eval() to train() (i.e. I train in eval mode), GPU utilization is lower (30-50%) but not entirely gone.
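To illustrate what the train()/eval() switch actually changes for batchnorm (an aside, not the fix): in train() mode batchnorm normalizes with batch statistics and updates its running statistics, while eval() freezes them and uses the stored running stats; self-play inference should also run under torch.no_grad():

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 5  # data with non-zero mean/variance

bn.train()
_ = bn(x)          # train mode: updates running_mean / running_var

bn.eval()
with torch.no_grad():
    y = bn(x)      # eval mode: uses the (now updated) running stats,
                   # and no_grad() avoids building an autograd graph
```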
If the models other than the main one don't load the weights from the main one, self-play still utilizes 100% GPU, so my guess is that something happens during the training process that changes some state in the model.
This also happens when I use only an 8-CPU / 1-GPU setup and train the model on the fly (no intermediate models).
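One thing I'm considering trying (an assumption about the cause, not something I've verified): snapshotting the weights on the CPU before handing them to the worker processes, so each worker loads a detached copy instead of touching tensors that live on the training GPU. A minimal sketch with a placeholder model:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))  # placeholder model

# Detached CPU snapshot of the weights; safe to pickle and send to workers,
# and independent of any later in-place updates by the optimizer.
cpu_state = {k: v.detach().cpu().clone() for k, v in net.state_dict().items()}

# A worker would rebuild its own model and load the snapshot:
worker_net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
worker_net.load_state_dict(cpu_state)
worker_net.eval()  # workers only run inference for self-play
```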
Can someone guide me on where to fix my code, or on how I should train my model?