BatchNorm hangs after saving and loading state_dict when training with multiple processes


In my implementation, I train the model with multiple processes, and a concurrent process saves its state dictionary so that I can evaluate/test the model after training is complete. During evaluation, I load the saved state dictionaries. I need this to compute some additional information about my distributed/concurrent training algorithm. It also keeps the test computation off the processors while they are training.
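A minimal sketch of the save/load half of this setup follows. It assumes a single process and an in-memory buffer in place of the file written by the concurrent saver; `make_model`, `snapshot` and `evaluate` are hypothetical names, not the actual code. Copying tensors to the CPU before saving is one possible design choice that keeps the snapshot loadable without CUDA; it is not presented as a fix for the hang.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical small CNN; BatchNorm2d adds running-statistics buffers.
    return nn.Sequential(
        nn.Conv2d(3, 4, kernel_size=3, padding=1),
        nn.BatchNorm2d(4),
        nn.ReLU(),
    )


def snapshot(model: nn.Module) -> bytes:
    # Save detached CPU copies of all parameters and buffers, so loading
    # the snapshot later does not require a CUDA device.
    cpu_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    buf = io.BytesIO()
    torch.save(cpu_state, buf)
    return buf.getvalue()


def evaluate(blob: bytes, x: torch.Tensor) -> torch.Tensor:
    # Rebuild a fresh model, load the snapshot, and run one forward pass
    # in eval mode (BatchNorm then uses the stored running statistics).
    model = make_model()
    state = torch.load(io.BytesIO(blob), map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    with torch.no_grad():
        return model(x)
```

In the real setup, `snapshot` would run in the concurrent saver process and `evaluate` after training has finished.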

This implementation works perfectly for every model I have tried when it runs on a CPU. On a GPU, it also works for each of those models when there is only one training process. Furthermore, it works when I train with multiple processes on a GPU, as long as the model contains no BatchNorm.
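One concrete difference BatchNorm introduces, noted here as background rather than a diagnosis: besides trainable parameters, it registers buffers (`running_mean`, `running_var`, `num_batches_tracked`) that are part of `state_dict()` and are updated in place during every forward pass in train mode, not by the optimizer. The sketch below uses arbitrary layer sizes to list them.

```python
import torch.nn as nn

with_bn = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
without_bn = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))

# Trainable parameters: the tensors an optimizer updates.
param_names = sorted(name for name, _ in with_bn.named_parameters())
# Buffers: state that BatchNorm mutates itself during train-mode forward passes.
buffer_names = sorted(name for name, _ in with_bn.named_buffers())
# Everything BatchNorm adds to the saved state dictionary.
extra_keys = sorted(set(with_bn.state_dict()) - set(without_bn.state_dict()))

print(param_names)
print(buffer_names)
print(extra_keys)
```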

However, when the model contains BatchNorm and I train it with multiple processes on a GPU, the forward pass during the later evaluation hangs at Conv2d (it may actually hang somewhere deeper inside, but in my debugger I could only trace it as far as Conv2d).

What exactly happens with the BatchNorm + multiprocessing combination?