Hi everyone,
I tried to use torch.utils.checkpoint together with DDP. However, after the first iteration, the program hung. I read a thread on this forum from last year where someone said that DDP and checkpointing haven't worked together yet. Is that still true? Any suggestions for my case? Thank you.
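For reference, my setup is roughly the following (a minimal sketch, not my actual model; the layer sizes and launcher details are placeholders):

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(128, 128)
        self.block2 = nn.Linear(128, 128)

    def forward(self, x):
        # Recompute block1 in the backward pass instead of storing its activations.
        x = checkpoint(self.block1, x)
        return self.block2(x)

def main():
    dist.init_process_group("nccl")  # launched via torchrun / torch.distributed.launch
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)

    model = DDP(Net().cuda(rank), device_ids=[rank])
    opt = torch.optim.SGD(model.parameters(), lr=0.01)

    for step in range(10):
        # requires_grad on the input avoids the "none of the inputs
        # require grad" warning from the checkpoint API
        x = torch.randn(32, 128, device=rank, requires_grad=True)
        model(x).sum().backward()  # hangs after the first iteration
        opt.step()
        opt.zero_grad()

if __name__ == "__main__":
    main()
```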
Hi,
I am afraid this is true.
We are working on a solution for this in 1.10.
We currently have a prototype API, _set_static_graph, which can be applied to DDP if your training is static across all iterations (i.e., there is no conditional execution in the model). Documentation: pytorch/distributed.py at master · pytorch/pytorch · GitHub.
With static-graph training, DDP records the number of times each parameter expects to receive a gradient and memorizes it, which resolves the issue around activation checkpointing and should make it work.
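Usage looks roughly like this (a minimal sketch; _set_static_graph is a private, prototype method on the DDP instance and may change, and model/rank here are placeholders for your own setup):

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# model: your nn.Module that calls torch.utils.checkpoint internally;
# the process group is assumed to be initialized already, and rank is
# this process's local rank.
ddp_model = DDP(model.cuda(rank), device_ids=[rank])

# Prototype, private API: tells DDP that the autograd graph is identical
# on every iteration, so it can record how many gradients each parameter
# produces (with activation checkpointing that count can be > 1) and
# fire the allreduce only after the last one.
ddp_model._set_static_graph()
```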
It's not clear from your post what the issue is. Why did your code hang? That is essential information to include here. Did you try any of the following:
- Getting Started with Distributed Data Parallel — PyTorch Tutorials 1.10.1+cu102 documentation
- Checkpointing DDP.module instead of DDP itself - #2 by mrshenli
- Checkpointing DDP.module instead of DDP itself - #3 by Brando_Miranda
If none of those worked, can you provide more details? Your original post does not describe enough to pinpoint the problem, and things can hang for many reasons, especially in complicated multiprocessing code.
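For reference, the save/load pattern from the second and third links above looks roughly like this (a minimal sketch; ddp_model, model, and rank are placeholders for your own setup):

```python
import torch
import torch.distributed as dist

# Save on rank 0 only; ddp_model.module is the wrapped model, so its
# state_dict has no "module." key prefix and loads into a plain model.
if dist.get_rank() == 0:
    torch.save(ddp_model.module.state_dict(), "model.pt")
dist.barrier()  # make sure the file is fully written before other ranks read it

# Load on every rank, mapping tensors to that rank's device, then wrap in DDP.
state = torch.load("model.pt", map_location=f"cuda:{rank}")
model.load_state_dict(state)
```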
Hi @albanD, did you find a solution to this in 1.10.0?