I noticed that the default stager with pinned and shared memory on the main branch (pytorch/torch/distributed/checkpoint/staging.py) can stage a model's state_dict asynchronously. How does the stager handle layers whose state is modified during the forward pass, like BatchNorm (its running-stats buffers are updated in place during forward)?
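To make the scenario concrete, here is a minimal sketch of what I mean, assuming `torch.distributed.checkpoint.async_save` as the entry point that exercises the stager (please correct me if that's not the right path) and a single-process run; the `checkpoint_id` is just a placeholder:

```python
import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).to(device)
model.train()  # in train mode, BatchNorm updates running_mean/running_var
x = torch.randn(4, 3, 32, 32, device=device)

# Kick off an async checkpoint: staging copies the state_dict, including
# the BatchNorm running-stats buffers, into pinned/shared host memory.
future = dcp.async_save(model.state_dict(), checkpoint_id="ckpt_step0")

# The very next forward pass mutates running_mean/running_var in place.
# My question: if the staging copy has not fully materialized its snapshot
# before this line runs, can the staged state_dict end up inconsistent?
_ = model(x)

future.result()  # wait for the background persistence to complete
```

In other words, is there a synchronization point that guarantees the device-to-host staging copy is complete before training (and thus in-place buffer mutation) resumes, or is the user expected to wait before running the next forward pass?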
For reference, torchtitan's zero-overhead checkpointing is described in: [Distributed w/ TorchTitan] Optimizing Checkpointing Efficiency with PyTorch DCP