Checkpoint with BatchNorm running averages

I think this is right. See the tutorial notebook tutorial/Checkpointing_for_PyTorch_models.ipynb (master branch) in the prigoyal/pytorch_memonger repository on GitHub.

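If checkpointing is on, the forward pass of each checkpointed segment runs twice (once in the normal forward, once again during backward), so the BatchNorm running stats get updated twice per step. As I understand it, we need to compensate by computing the running stats with a square-root adjustment to the momentum instead of the plain value.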
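To make the square-root adjustment concrete, here is a minimal sketch in plain Python. It assumes PyTorch's update rule, running = (1 - momentum) * running + momentum * batch_stat, and that the recomputed forward sees the same batch. The helper names `adjusted_momentum` and `update` are made up for illustration, and the exact form used in the linked notebook may differ.

```python
import math


def adjusted_momentum(momentum: float) -> float:
    """Momentum m' such that two updates with m' equal one update with `momentum`.

    Two updates scale the old running value by (1 - m')**2, so we need
    (1 - m')**2 == 1 - momentum, i.e. m' = 1 - sqrt(1 - momentum).
    """
    return 1.0 - math.sqrt(1.0 - momentum)


def update(running: float, batch_stat: float, momentum: float) -> float:
    # Assumed convention (PyTorch style): new stat is weighted by `momentum`.
    return (1.0 - momentum) * running + momentum * batch_stat


running, batch_mean, m = 0.0, 2.0, 0.1

# Reference: one update, as without checkpointing.
once = update(running, batch_mean, m)

# With checkpointing the same batch updates the stats twice.
naive_twice = update(update(running, batch_mean, m), batch_mean, m)

# Same double update, but with the square-root-adjusted momentum.
m_adj = adjusted_momentum(m)
twice = update(update(running, batch_mean, m_adj), batch_mean, m_adj)

print(once, naive_twice, twice)
```

With the unadjusted momentum the double update drifts further toward the batch statistic than intended, while the adjusted momentum reproduces the single-update result. In a real model the adjusted value would go into the `momentum` attribute of each BatchNorm module that sits inside a checkpointed segment.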