Currently, saving checkpoints synchronously blocks training significantly in LLM settings, since a full model and optimizer state can take a long time to serialize and write. We need an asynchronous checkpoint-saving feature.
Projects like JAX (see "Save and load checkpoints"), PyTorch Lightning (see "Distributed checkpoints (expert)" in the 2.0.7 documentation), and Microsoft Nebula have already implemented such a feature.
It doesn’t seem overly complex, and I’m wondering if the PyTorch community has any plans to introduce this feature. Could you provide some insights, or let me know if I’m missing something?
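To illustrate the pattern being requested, here is a minimal, framework-agnostic sketch using plain Python threading: the state is snapshotted on the main thread (in PyTorch this would typically mean copying `model.state_dict()` tensors to CPU), and only the disk write happens in the background. The function name `async_save` and the use of `pickle` are illustrative assumptions, not an existing PyTorch API.

```python
import copy
import pickle
import threading

def async_save(state, path):
    """Snapshot the state, then persist it in a background thread,
    so the training loop is only blocked for the in-memory copy."""
    # The copy must happen before training resumes; otherwise the
    # background thread could serialize a half-updated state.
    snapshot = copy.deepcopy(state)

    def _write():
        with open(path, "wb") as f:
            pickle.dump(snapshot, f)

    t = threading.Thread(target=_write, daemon=True)
    t.start()
    return t  # caller can join() before taking the next checkpoint

# Usage: pretend this dict is a model/optimizer state dict.
state = {"step": 100, "weights": [0.1, 0.2, 0.3]}
t = async_save(state, "ckpt.pkl")
state["step"] = 101  # training continues immediately
t.join()             # wait only when necessary

with open("ckpt.pkl", "rb") as f:
    restored = pickle.load(f)
# restored holds the snapshot taken at save time (step 100)
```

In a real implementation the expensive part, writing to storage, is fully overlapped with training, and the remaining blocking cost is the device-to-host copy, which can be further reduced with pinned memory and non-blocking transfers.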