DDP and Gradient checkpointing

Hi everyone,
I tried to use torch.utils.checkpoint together with DDP. However, after the first iteration, the program hung. Last year I read a thread on this forum in which someone said that DDP and checkpointing did not work together yet. Is that still true? Any suggestions for my case? Thank you.

Hi,

I am afraid this is true.
We are working on a solution for 1.10.

We currently have a prototype API, _set_static_graph, which can be applied to DDP if your training is static across all iterations (i.e., there is no conditional execution in the model). Documentation: pytorch/distributed.py at master in the pytorch/pytorch repository on GitHub.
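A minimal sketch of how the prototype call might be wired up, assuming a recent PyTorch (1.11 or later, where checkpoint accepts use_reentrant; on 1.9/1.10 that argument has to be dropped). The toy model CheckpointedNet and the helper train are hypothetical, and the single-process gloo group only exists so the sketch runs on one CPU machine. _set_static_graph() is a private, prototype method; it is called right after wrapping and before the first forward pass. Later releases also expose the same mode through the static_graph=True constructor argument.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.checkpoint import checkpoint


class CheckpointedNet(nn.Module):
    """Hypothetical toy model whose middle block is activation-checkpointed."""

    def __init__(self):
        super().__init__()
        self.inp = nn.Linear(8, 16)
        self.block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
        self.out = nn.Linear(16, 1)

    def forward(self, x):
        # Same modules on every iteration, no data-dependent branches:
        # this is the "static graph" condition the prototype API relies on.
        h = self.inp(x)  # h requires grad, which the reentrant checkpoint needs
        h = checkpoint(self.block, h, use_reentrant=True)
        return self.out(h)


def train(steps=3):
    # Single-process gloo group so the sketch runs anywhere; a real job
    # would start one process per GPU (e.g. with torchrun).
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        torch.manual_seed(0)
        model = DDP(CheckpointedNet())
        # Prototype API from this thread; must run before the first forward.
        # Newer versions also accept DDP(..., static_graph=True).
        model._set_static_graph()
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        losses = []
        for _ in range(steps):
            x, y = torch.randn(4, 8), torch.randn(4, 1)
            loss = nn.functional.mse_loss(model(x), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        return model, losses
    finally:
        dist.destroy_process_group()
```

With world_size=1 there is nothing to synchronize, so this only checks that the wiring runs; a hang like the one in the question would only show up with several ranks.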

With static graph training, DDP records the number of times each parameter is expected to receive a gradient and memorizes it. This solves the issue with activation checkpointing and should make the two work together.
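To see why a per-parameter count matters, the sketch below counts how often the gradient hook fires for each parameter of a single nn.Linear that is applied twice. It only mimics the idea with Tensor.register_hook (it is not how DDP's reducer works internally), the helper count_grad_arrivals is hypothetical, and it assumes PyTorch 1.11+ for use_reentrant. With reentrant checkpointing, each segment is recomputed and backpropagated in its own nested backward pass, so a shared parameter can receive its gradient more than once per iteration. Our reading is that this is the kind of mismatch a default DDP run trips over, and why remembering the expected count helps.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def count_grad_arrivals(use_checkpoint: bool) -> dict:
    """Count gradient-hook calls per parameter in one backward pass.

    A single (hypothetical) Linear layer is applied twice, so its
    parameters take part in the forward pass twice.
    """
    torch.manual_seed(0)
    shared = nn.Linear(4, 4)
    counts = {name: 0 for name, _ in shared.named_parameters()}

    for name, p in shared.named_parameters():
        # The default argument binds the current name; returning None
        # leaves the gradient itself unchanged.
        def hook(grad, name=name):
            counts[name] += 1

        p.register_hook(hook)

    # The reentrant checkpoint needs at least one input that requires grad.
    x = torch.randn(2, 4, requires_grad=True)
    if use_checkpoint:
        # Each reentrant checkpoint replays its segment and runs a separate,
        # nested backward pass inside the outer one.
        h = checkpoint(shared, x, use_reentrant=True)
        h = checkpoint(shared, h, use_reentrant=True)
    else:
        # One ordinary graph: both uses are summed before the hook runs.
        h = shared(shared(x))
    h.sum().backward()
    return counts
```

In our understanding, count_grad_arrivals(False) reports each parameter once, while count_grad_arrivals(True) reports each parameter twice.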


I don’t understand what the issue is. Why did your code hang? That is essential information to include here. Did you try any of the following:

If none of them worked, can you provide more details? In particular, your original post does not contain enough information to tell what the problem is. Things can hang for many reasons, especially in complicated multiprocessing code.

Hi @albanD, did you find a solution to this in 1.10.0?