torch.utils.checkpoint.checkpoint

Hi,

Can someone demonstrate how to use the new checkpoint feature for part of a model? For example, how would checkpoint be used for the _DenseBlock of the DenseNet implementation, since this block stores intermediate feature maps and becomes memory intensive?

Should the checkpoint be called in the __init__ of the module or in forward?

Pointers would be really helpful.

Regards
Nabarun


Came here to ask the same thing.
And the use of the term ‘checkpoint’ to mean ‘trading compute for memory’ is confusing me. In common CS parlance, ‘checkpointing’ refers to the practice of saving a program’s state so that it can be resumed if a failure occurs. But I don’t see any specification of a file path (for saving) in the torch.utils.checkpoint.checkpoint spec.

Hi,

You need to modify the forward pass to replace how you use the corresponding submodule:

# Original:
out = self.my_block(inp1, inp2, inp3)

# With checkpointing:
from torch.utils.checkpoint import checkpoint
out = checkpoint(self.my_block, inp1, inp2, inp3)
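For the _DenseBlock question above, here is a minimal sketch of the same idea applied inside a module (the layer names and shapes are made up for illustration, not taken from torchvision):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class BlockWithCheckpoint(nn.Module):
    def __init__(self):
        super().__init__()
        # Stand-in for a dense block producing large intermediate feature maps.
        self.my_block = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
        )
        self.head = nn.Conv2d(64, 10, kernel_size=1)

    def forward(self, x):
        # checkpoint is applied here in forward(), not declared in __init__():
        # the block's intermediate activations are dropped after the forward
        # pass and recomputed during the backward pass.
        # (use_reentrant=False requires a newer PyTorch; omit it on older versions.)
        out = checkpoint(self.my_block, x, use_reentrant=False)
        return self.head(out)

This also answers the __init__ vs forward question: the module is built as usual in __init__, and checkpoint wraps the call in forward.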

Hi,

If I understand correctly, the intermediate activations will be stored for the layers which we checkpoint, and for the others they will be recalculated during the backward pass. So basically we need to checkpoint the block whose activations we want stored, and not the block whose activations we don’t want to store. Is my understanding correct?

Regards

No, it’s the opposite.
By default, autograd saves all intermediate results needed for the backward pass.
The checkpointing tool was added to allow not storing all of those intermediate results. You only need it if you don’t have enough memory to fit your model’s activations; otherwise it will just slow your model down. For the function/module given to it, no intermediate results are saved, and they are recomputed during the backward pass.
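To make the trade-off concrete, here is a tiny standalone sketch (assuming a recent PyTorch that accepts the use_reentrant argument): the checkpointed segment runs once in the forward pass, and is re-run during the backward pass to recompute the activations it did not store.

import torch
from torch.utils.checkpoint import checkpoint

calls = 0

def segment(x):
    # Count how many times this segment is executed.
    global calls
    calls += 1
    return torch.relu(x * 2).sum()

x = torch.randn(10, requires_grad=True)

out = checkpoint(segment, x, use_reentrant=False)
print(calls)  # 1 -> ran once in the forward pass, activations not kept

out.backward()
print(calls)  # 2 -> re-run during backward to recompute the activations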


I see. Thanks for the explanation. And thanks for the very useful tool! :grinning:


Hi,
Parameters

function – describes what to run in the forward pass of the model or part of the model. 
It should also know how to handle the inputs passed as the tuple.

For the function parameter I am confused – is it the forward function or part of the forward function?

In other words,

Is out the return value of the forward function, or an intermediate result of the forward function?
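For what it’s worth, my understanding is that function can be any callable – a whole submodule, or a helper that covers only part of forward(). checkpoint returns whatever that callable returns, so if you checkpoint just a segment, out is an intermediate result of the overall forward pass. A made-up sketch (layer names are hypothetical):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(32, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 10)

    def middle(self, x):
        # `function` can be any callable; this one covers only part of forward().
        return torch.relu(self.layer2(torch.relu(x)))

    def forward(self, x):
        x = self.layer1(x)
        # `out` is the checkpointed segment's output, i.e. an intermediate
        # result of forward(), not the module's final return value.
        out = checkpoint(self.middle, x, use_reentrant=False)
        return self.layer3(out)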