Checkpoint_sequential for DataParallel

I’ve been trying to use checkpoint_sequential with DataParallel. Here is my solution:

out = checkpoint_sequential(model.module, num_checkpoints, batch_data)

  • I take the nn.Sequential object wrapped by the DataParallel (model.module) and apply the checkpoints to that. It is hard to verify whether this actually works and applies the checkpoints on all of the GPUs, so I would appreciate some feedback :)
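One caveat worth checking: model.module is the original, unwrapped module, so calling checkpoint_sequential on it directly skips DataParallel's scatter/replicate/gather step, and the forward pass then runs only on the device where model.module lives. A possible alternative, sketched below, is to move the checkpoint_sequential call inside a module's forward and wrap that module in DataParallel, so that each replica applies the checkpoints to its own slice of the batch. This is a minimal sketch rather than a tested multi-GPU recipe: the CheckpointedSequential class, the layer sizes, and the batch shape are made up for illustration, and use_reentrant=False assumes a recent PyTorch (2.x) where checkpoint_sequential accepts that argument.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


class CheckpointedSequential(nn.Module):
    """Hypothetical wrapper: runs an nn.Sequential through checkpoint_sequential."""

    def __init__(self, seq, segments):
        super().__init__()
        self.seq = seq
        self.segments = segments

    def forward(self, x):
        # Called once per replica when the wrapper sits inside DataParallel,
        # so checkpointing happens on whichever device holds this replica.
        return checkpoint_sequential(self.seq, self.segments, x, use_reentrant=False)


# Illustrative layers only; substitute the real nn.Sequential here.
seq = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 4),
)
model = CheckpointedSequential(seq, segments=2)

# Wrap in DataParallel only when several GPUs are visible; otherwise stay on CPU.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model.cuda())

device = next(model.parameters()).device
batch_data = torch.randn(32, 8, device=device)

out = model(batch_data)   # goes through DataParallel.forward when wrapped
out.sum().backward()      # checkpointed segments are recomputed during backward
```

With this layout the call site stays out = model(batch_data), and comparing torch.cuda.max_memory_allocated(i) for each GPU index i with and without checkpointing is one way to see whether every device benefits.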