How can I make batch norm not forget the batch statistics it just used?

Related: How does PyTorch's batch norm know whether the forward pass it's doing is for inference or training?
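
For context, here is a minimal sketch of the behavior both questions are about, assuming a plain `nn.BatchNorm1d` layer and a made-up input batch `x` (the shapes and values are only for illustration). As far as I understand it, the layer checks its `training` flag, which `.train()` and `.eval()` toggle: in training mode it normalizes with the current batch's statistics and updates `running_mean` / `running_var`, while in eval mode it normalizes with those stored running statistics. Setting `momentum=1.0` is one way to make the running statistics equal to the statistics of the last batch seen, though it may not be what you want for ordinary training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# momentum=1.0 means: running_stat = 0 * old_running_stat + 1 * batch_stat,
# so the layer "remembers" exactly the last batch it saw in training mode.
bn = nn.BatchNorm1d(3, momentum=1.0)

# Hypothetical input batch: 8 samples, 3 features, shifted and scaled.
x = torch.randn(8, 3) * 2.0 + 5.0

bn.train()       # sets bn.training = True
y_train = bn(x)  # normalizes with batch stats and updates the running stats

bn.eval()        # sets bn.training = False
y_eval = bn(x)   # normalizes with the stored running stats, no update

# The running mean now equals the mean of the batch just used.
mean_matches = torch.allclose(bn.running_mean, x.mean(dim=0), atol=1e-6)
# The running variance is updated with the unbiased batch variance.
var_matches = torch.allclose(bn.running_var, x.var(dim=0, unbiased=True), atol=1e-5)

print(bn.training, mean_matches, var_matches)
```

Note that `y_train` and `y_eval` will still differ slightly, because the training-mode forward pass normalizes with the biased batch variance while the stored running variance is the unbiased estimate.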