Using torch.utils.checkpoint.checkpoint with DataParallel

Hi,
I’m trying to figure out how to use torch.utils.checkpoint.checkpoint to reduce memory consumption with multi-GPU training (nn.DataParallel or nn.parallel.DistributedDataParallel).

But it seems that using checkpointing and multi-GPU training simultaneously greatly slows down training speed.
I’ve found an unofficial implementation of gradient checkpointing for DenseNet that works fine with multiple GPUs: GitHub - csrhddlam/pytorch-checkpoint. But I wonder whether there is an efficient way to use checkpointing and multi-GPU training simultaneously with the official PyTorch implementation.
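
For reference, here is a minimal sketch of the kind of setup I mean (the block sizes and names are just illustrative; the real model is much deeper):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # two stand-in blocks; the real model is a much deeper network
        self.block1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
        self.block2 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())

    def forward(self, x):
        # recompute each block's activations during backward instead of storing them
        x = checkpoint(self.block1, x)
        x = checkpoint(self.block2, x)
        return x

model = nn.DataParallel(Net().cuda())  # same idea with nn.parallel.DistributedDataParallel
# the input must require grad so the checkpointed segments participate in backward
x = torch.randn(8, 512, device="cuda", requires_grad=True)
model(x).sum().backward()
```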

Thanks!

DistributedDataParallel currently does not work with checkpoint; see this discussion: https://github.com/pytorch/pytorch/issues/24005

This PR (https://github.com/pytorch/pytorch/pull/35083) fixes the error, but because communication and computation no longer overlap under the new delay_allreduce mode, it is expected to be slower.

DataParallel might work with checkpoint. Have you tried that?
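
Something along these lines might be worth trying (a rough, untested sketch using checkpoint_sequential on a toy sequential model; adapt it to your actual architecture):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # a toy stack of sequential blocks standing in for the real network
        self.blocks = nn.Sequential(
            *[nn.Sequential(nn.Linear(512, 512), nn.ReLU()) for _ in range(4)]
        )

    def forward(self, x):
        # checkpoint the stack in 2 segments: only segment boundaries keep activations
        return checkpoint_sequential(self.blocks, 2, x)

model = nn.DataParallel(Net().cuda())
x = torch.randn(8, 512, device="cuda", requires_grad=True)
model(x).sum().backward()
```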