Funky batchnorm behavior with policy gradient


I’ve been working on an episodic policy gradient problem, using both CNNs and simpler linear models as the agent, and I’ve noticed an odd dependence on batch normalization that I hope someone here may recognize:

- To see any convergence to a non-trivial solution, batchnorm has to be in train mode with an extremely high momentum (> 0.9) — in other words, you can’t average too many batch statistics together.
- However, if you switch to eval mode on the training set right after training in train mode, the model usually fails entirely.
- Instead, after training, you have to run a few batches in train mode without any learning before switching to eval mode; then everything works fine.
- If you train purely in eval mode the entire time (never enabling train mode), the model fails to converge to a non-trivial solution.
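For concreteness, here is a framework-free sketch of the batch-statistic bookkeeping involved (using the PyTorch-style convention `running = (1 - momentum) * running + momentum * batch_stat`, so *high* momentum means recent batches dominate). The drifting batch means below are made-up numbers standing in for activations that shift as the policy changes:

```python
def update_running_stat(running, batch_stat, momentum):
    # PyTorch-convention EMA update: high momentum => the latest batch dominates
    return (1.0 - momentum) * running + momentum * batch_stat

# hypothetical batch means that drift as the policy updates during training
batch_means = [0.0, 0.5, 1.0, 1.5, 2.0]

r_low, r_high = 0.0, 0.0
for m in batch_means:
    r_low = update_running_stat(r_low, m, momentum=0.1)    # heavy averaging
    r_high = update_running_stat(r_high, m, momentum=0.95)  # barely any averaging

# The high-momentum running mean tracks the latest batch (2.0) closely,
# while the low-momentum one lags far behind on stale statistics.
print(r_low, r_high)
```

This is consistent with the symptoms above: with heavy averaging the running stats describe a mixture of old policies, and running a few no-learning batches before eval simply lets them catch up to the current one.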

My interpretation (which is pure speculation) is that, for whatever reason, the rewards are causing large updates to the batch statistics in a way that makes averaging them out harmful. I don’t really know how to test this hypothesis, though, and I’m not sure what the solution would be. Any help or thoughts would be greatly appreciated.
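One cheap way to probe that hypothesis (a sketch, with batchnorm reduced to its mean bookkeeping and synthetic drifting activations standing in for a real network) is to log the gap between each batch’s own mean and the running mean; a large, persistent gap would explain why eval mode, which normalizes with the running stats, fails right after train-mode training:

```python
import random
import statistics

random.seed(0)
momentum = 0.1
running_mean = 0.0
gaps = []

for step in range(50):
    # simulate activations whose mean drifts as the policy updates
    batch = [random.gauss(0.05 * step, 1.0) for _ in range(32)]
    batch_mean = statistics.fmean(batch)
    running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    # train mode normalizes with batch_mean; eval mode would use running_mean
    gaps.append(abs(batch_mean - running_mean))

# a persistently large gap late in training means eval-mode outputs differ
# from anything the policy was trained on
print(statistics.fmean(gaps[-10:]))
```

In a real model you would log the same quantity per batchnorm layer (in PyTorch, by comparing each layer’s `running_mean`/`running_var` against the current batch’s statistics in a forward hook).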


Hi, bit of a late response, but I can’t help myself 🙂

From what I’ve been reading lately, your hypothesis is correct: policy gradients are generally hurt by batchnorm. As I understand it, figuring out how to normalize deep policy networks under batchnorm is an open research problem.

If you look at a lot of the RL research, much of it is about normalizing the rewards and value targets so that they don’t cause overly large updates to the policy or value function, so it makes sense that bolting on a normalization scheme designed mostly for supervised learning could cause issues.
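For reference, the kind of normalization that last point refers to is typically applied to the returns themselves rather than to activations. A minimal sketch (the returns are made-up numbers; `normalize_returns` is just an illustrative helper, not from any library):

```python
import statistics

def normalize_returns(returns, eps=1e-8):
    # standardize so the policy-gradient scale is insensitive
    # to the raw magnitude of the environment's rewards
    mean = statistics.fmean(returns)
    std = statistics.pstdev(returns)
    return [(r - mean) / (std + eps) for r in returns]

# hypothetical episodic returns from one batch of rollouts
weights = normalize_returns([10.0, 12.0, 8.0, 14.0, 6.0])
print(weights)  # zero-mean, unit-variance weights for the log-prob terms
```

Unlike batchnorm, this normalizes the *learning signal* once per batch instead of maintaining running statistics inside the network, so there is no train/eval-mode mismatch to manage.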
