Questions about batch normalization layer in EfficientNet (timm) with gradient checkpointing


I have been reading the EfficientNet code in timm, which supports gradient checkpointing. From the tutorial, my understanding is that batch normalization layers need special handling under checkpointing. I therefore have two questions:

  1. If the number of checkpoint segments is n, does that mean running_mean gets updated n times, and should the momentum therefore be set to the n-th root of m?
  2. How does timm's EfficientNet handle batch normalization layers when gradient checkpointing is enabled?
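
To make question 1 concrete, here is a minimal repro of the behavior I am worried about (a sketch, assuming `torch.utils.checkpoint` with the reentrant variant): a `BatchNorm1d` layer in training mode has its forward re-run during the backward pass, so its running statistics are updated twice per training step instead of once.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
bn.train()  # running stats are only updated in training mode
x = torch.randn(8, 4, requires_grad=True)

# With reentrant checkpointing, the forward runs once under no_grad
# (buffers are still updated), then runs again with grad enabled
# during backward recomputation, updating the buffers a second time.
out = checkpoint(bn, x, use_reentrant=True)
out.sum().backward()

print(bn.num_batches_tracked)  # tensor(2): two updates for one step
```

So after a single forward/backward, `num_batches_tracked` is 2 rather than 1, which is what makes me wonder whether the momentum needs to be compensated.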