Using torch.utils.checkpoint.checkpoint with DataParallel

I’m trying to figure out how to use torch.utils.checkpoint.checkpoint to reduce memory consumption during multi-GPU training (nn.DataParallel or nn.parallel.DistributedDataParallel).

But it seems that using checkpointing and multi-GPU training at the same time slows training down considerably. See:


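For concreteness, checkpointing is typically applied segment by segment inside a module's forward. A minimal single-process sketch (the module and layer sizes here are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU())

    def forward(self, x):
        # checkpoint() discards the intermediate activations of each block
        # during forward and recomputes them during backward, trading extra
        # compute for lower peak memory.
        # use_reentrant=False is available in recent PyTorch versions;
        # omit it on older releases.
        x = checkpoint(self.block1, x, use_reentrant=False)
        x = checkpoint(self.block2, x, use_reentrant=False)
        return x

model = Net()
x = torch.randn(8, 64)
out = model(x)
out.sum().backward()  # block activations are recomputed here
```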
I’ve found an unofficial implementation of gradient checkpointing for DenseNet that works fine with multiple GPUs: But I wonder whether there is an efficient way to use checkpointing and multi-GPU training together with the official PyTorch implementation.


DistributedDataParallel currently does not work with checkpoint; see this discussion:

This PR fixes the error, but because communication and computation no longer overlap under the new delay_allreduce mode, it is expected to be slower.

DataParallel might work with checkpoint; have you tried that?
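A minimal sketch of that combination, assuming a toy checkpointed module (names and shapes are made up; DataParallel replicates the module and scatters the batch on each forward, so each replica runs its own checkpointed segment independently):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedNet(nn.Module):
    """Toy model whose feature extractor runs under checkpointing."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)
        )

    def forward(self, x):
        # Activations inside self.features are recomputed during backward.
        # use_reentrant=False requires a recent PyTorch version.
        return checkpoint(self.features, x, use_reentrant=False)

model = CheckpointedNet()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
if torch.cuda.device_count() > 1:
    # DataParallel splits the batch across GPUs; checkpointing then
    # applies per replica.
    model = nn.DataParallel(model)

x = torch.randn(16, 32, device=device)
out = model(x)
out.sum().backward()
```

On a single device the DataParallel wrapper is skipped and the script degenerates to plain checkpointing, which makes it easy to compare memory and speed with and without the wrapper.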