How much memory saving should I expect from torch.utils.checkpoint?

In this tutorial the batch size increased from 8 to 32,
and from 24 to 132.

On my model I got only a 2x increase, from 32 to 64.
Is this OK, or am I doing something wrong?
How can I check whether checkpointing works as intended?

The model is a transformer with RNN layers.

If you see a memory reduction and an increased computation cost, then checkpointing should be working correctly.
The memory savings might depend on where you put the checkpoints in your model and cannot be generalized, if I'm not mistaken.
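
One way to confirm the "increased computation cost" part is to count how often a checkpointed module's forward pass runs. The following is only a minimal sketch on CPU: `CountingBlock` and `count_forward_calls` are hypothetical names made up for illustration, and it assumes a PyTorch release recent enough to accept the `use_reentrant` argument of `checkpoint`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Toy layer (hypothetical) that counts how often its forward pass runs."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        return torch.relu(self.linear(x))


def count_forward_calls(use_checkpoint):
    """Run one forward/backward step and return the number of forward calls."""
    torch.manual_seed(0)
    block = CountingBlock(16)
    x = torch.randn(4, 16, requires_grad=True)
    if use_checkpoint:
        # Assumption: a PyTorch version that supports use_reentrant=False.
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return block.forward_calls


plain_calls = count_forward_calls(use_checkpoint=False)
ckpt_calls = count_forward_calls(use_checkpoint=True)
print("forward calls without checkpoint:", plain_calls)
print("forward calls with checkpoint:", ckpt_calls)
```

If checkpointing is active, the checkpointed run should report one more forward call than the plain run, because the block is recomputed during backward. For the memory side on a GPU, comparing `torch.cuda.max_memory_allocated()` after a training step with and without checkpointing (calling `torch.cuda.reset_peak_memory_stats()` in between) gives a similar sanity check.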

I tried putting a checkpoint around each transformer layer,
like this:

    for layer in self.layers:
        output, self_attn, encoder_attn = checkpoint(
            layer, output, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask)

and around the encoder/decoder modules, like this:

        mel_outputs, gate_outputs = checkpoint(
            self.decoder, mels_input, encoder_outputs, mels_attn_len_mask, texts_attn_len_mask)
        decoder_self_attn, decoder_encoder_attn = None, None
        mel_outputs, gate_outputs, decoder_self_attn, decoder_encoder_attn = self.decoder(
            mels_input, encoder_outputs, mels_attn_len_mask, texts_attn_len_mask)

With the second approach, although more layers are covered by the checkpoint, the model fails with a batch size of 64.
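
A possible explanation for this, offered only as a rough cost model and not as a measurement of this model: a checkpointed segment stores just its input during forward, but during backward the whole segment is recomputed and its internal activations are held at once. The sketch below assumes every layer costs the same; `peak_activation_units` and the parameter `inner_units_per_layer` are made-up names and numbers for illustration.

```python
import math


def peak_activation_units(num_layers, inner_units_per_layer, segment_size=None):
    """Rough peak activation memory in arbitrary units (illustrative model only).

    Assumptions: every layer keeps `inner_units_per_layer` units of internal
    activations for backward, and each checkpointed segment stores 1 unit
    (its input) during forward. `segment_size=None` means no checkpointing.
    """
    if segment_size is None:
        # Without checkpointing, all internal activations stay alive.
        return num_layers * inner_units_per_layer
    num_segments = math.ceil(num_layers / segment_size)
    # Stored segment inputs, plus the internals of the one segment
    # that is being recomputed during backward.
    return num_segments + segment_size * inner_units_per_layer


no_ckpt = peak_activation_units(12, 10)
per_layer = peak_activation_units(12, 10, segment_size=1)
per_three = peak_activation_units(12, 10, segment_size=3)
whole_model = peak_activation_units(12, 10, segment_size=12)
print(no_ckpt, per_layer, per_three, whole_model)
```

Under these assumptions, one checkpoint around a whole encoder or decoder barely lowers the backward peak, while checkpointing each layer separately lowers it a lot. Real savings also depend on parameters, gradients, optimizer state, and uneven layer sizes, which this model ignores.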

I also tried memory profiling. Is it OK to have so much memory use in the backward pass?
Maybe I'm doing something wrong?

335    15.90M       18.00M    0.00B    0.00B      for epoch in range(epoch_offset, hparams.epochs):
346   180.23M      246.00M  164.34M  228.00M          for i, batch in enumerate(train_loader):
347   182.84M      310.00M    2.60M   64.00M              x, y = parse_batch(batch)
348   433.07M        3.05G  250.23M    2.75G              y_pred = model(x, y)
349   404.03M      532.00M  -29.04M   -2.53G              loss, losses = criterion(x, y, y_pred)
361   404.03M      500.00M    0.00B    0.00B              if hparams.fp16_run:
362                                                           with amp.scale_loss(loss, optimizer) as scaled_loss:
363                                                               scaled_loss.backward()
364                                                       else:
365    14.43G       14.64G   14.03G   14.15G                  loss.backward()

It turned out, and I'm not sure why, that batch size has a great impact on the model:
Batch sizes 32, 64, and 76
So I want to maximize the batch size with checkpointing, but I cannot find guides on how to do it right.
