Questions about batch normalization layer in EfficientNet (timm) with gradient checkpointing

Hi,

I have been reading the EfficientNet code in timm, which supports gradient checkpointing. From this tutorial: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb, my understanding is that batch normalization layers need special handling under checkpointing, because the checkpointed forward pass is recomputed during backward. I therefore have two questions:

  1. If the number of checkpoint segments is n, does that mean running_mean is updated n times per step, so the momentum should be set to the n-th root of m?
  2. How does timm's EfficientNet handle batch normalization layers when gradient checkpointing is enabled?
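For context on question 1, here is a minimal sketch (my own toy example, not timm's code) showing the effect the tutorial warns about: with reentrant checkpointing, the segment's forward runs again during backward, so a BatchNorm layer inside it updates its running stats twice per step. Starting from zeroed stats with momentum m, one update gives m * batch_mean, while two updates give m * (2 - m) * batch_mean:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn = nn.BatchNorm1d(4, momentum=0.1)
x = torch.randn(8, 4)

# Plain forward/backward: running_mean is updated once.
bn.reset_running_stats()
inp = x.clone().requires_grad_(True)
bn(inp).sum().backward()
plain_mean = bn.running_mean.clone()

# Checkpointed forward/backward: the segment is re-run during backward,
# so running_mean is updated twice with the same batch statistics.
bn.reset_running_stats()
inp = x.clone().requires_grad_(True)
checkpoint(bn, inp, use_reentrant=True).sum().backward()
ckpt_mean = bn.running_mean.clone()

# rm after one update:  m * batch_mean
# rm after two updates: (1 - m) * m * batch_mean + m * batch_mean
#                     = m * (2 - m) * batch_mean
print(torch.allclose(ckpt_mean, plain_mean * (2 - bn.momentum)))  # True
```

So each checkpointed BN layer sees its stats updated twice (once per forward execution), regardless of how many checkpoint segments there are overall, which is why I am unsure whether the n-th-root-of-momentum correction from the tutorial is the right mental model.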