Use of torch.utils.checkpoint.checkpoint causes simple model to diverge

The PyTorch autograd docs state:

If there’s a single input to an operation that requires gradient, its output will also require gradient. Conversely, only if all inputs don’t require gradient, the output also won’t require it.

The input to a model like this one will be a tensor with requires_grad=False (since we are performing gradient descent with respect to the model weights, not with respect to the values of the input sample itself). Unless you manually freeze its parameters, any module containing learnable parameters (e.g. nn.Conv2d or nn.Linear, but not nn.Dropout or nn.MaxPool2d) that you apply to that tensor will automatically produce an output with requires_grad=True.
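A minimal sketch of the propagation rule described above, using a random tensor as a stand-in for an input batch (the shapes are arbitrary):

```python
import torch
import torch.nn as nn

# A plain data tensor, like an input batch, does not require grad by default.
x = torch.randn(4, 8)
print(x.requires_grad)  # False

# nn.Linear has learnable weights that require grad, so its output requires grad.
linear = nn.Linear(8, 3)
y = linear(x)
print(y.requires_grad)  # True

# A parameter-free module leaves the flag as it was on the input.
dropout = nn.Dropout(p=0.5)
z = dropout(x)
print(z.requires_grad)  # False
```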

Applying checkpointing will not do this for you (at least with the reentrant implementation, use_reentrant=True). Because of the way checkpointing is implemented, whether the output tensor produced by the checkpointed module has requires_grad=True (so gradients flow back into the module's parameters) or requires_grad=False (so the module is effectively frozen) is determined solely by whether the input tensor has requires_grad=True or requires_grad=False.

This has the side effect (?) that the first block in the model cannot be checkpointed. E.g. in this model the input to self.cnn_block_1 is the input tensor, which has requires_grad=False, so setting X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X) means no gradients ever reach the parameters of cnn_block_1, and its weights stay fixed for the whole training run.
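A small reproduction of the problem, as a sketch: block_1 and head below are hypothetical stand-ins for self.cnn_block_1 and the rest of the model, and use_reentrant=True is passed explicitly to select the reentrant implementation (this assumes a PyTorch version recent enough to accept that argument). PyTorch emits a warning in this situation that gradients will be None; it is silenced here only to keep the output short.

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-ins for the first block and the rest of the model.
block_1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 1)

x = torch.randn(4, 8)  # requires_grad=False, like a real input batch

with warnings.catch_warnings():
    # The reentrant checkpoint warns that no input requires grad.
    warnings.simplefilter("ignore")
    h = checkpoint(block_1, x, use_reentrant=True)

print(h.requires_grad)  # False: block_1's parameters are cut out of the graph

loss = head(h).pow(2).mean()
loss.backward()

print(head.weight.grad is None)        # False: the head still gets gradients
print(block_1[0].weight.grad is None)  # True: block_1 never gets gradients
```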

I don’t think I would have been able to figure this out on my own; I got to this conclusion by Googling around and stumbling across this article, which mentions this issue in the footnotes. Seems like something that should really be added to the docs.

Removing the checkpointing on self.cnn_block_1 fixed this problem and produced a model that converged correctly again.
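A sketch comparing the fix used here with two alternatives, under the same hypothetical block_1/head setup as a stand-in for the real model. The "non_reentrant" option relies on use_reentrant=False, which the current torch.utils.checkpoint documentation says does not require any input to have requires_grad=True; the "input_requires_grad" option keeps the reentrant variant but marks the input as requiring grad, which also computes an otherwise unnecessary gradient for the input.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def block_1_gets_grad(mode: str) -> bool:
    """Run one forward/backward pass and report whether block_1's weight got a gradient."""
    torch.manual_seed(0)
    block_1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())  # hypothetical first block
    head = nn.Linear(16, 1)                                # hypothetical rest of model
    x = torch.randn(4, 8)  # real input batch: requires_grad=False

    if mode == "no_checkpoint":
        # The fix from this post: simply don't checkpoint the first block.
        h = block_1(x)
    elif mode == "non_reentrant":
        # Non-reentrant checkpointing tracks the block's parameters itself.
        h = checkpoint(block_1, x, use_reentrant=False)
    elif mode == "input_requires_grad":
        # Reentrant variant, but with an input that requires grad.
        h = checkpoint(block_1, x.requires_grad_(), use_reentrant=True)
    else:
        raise ValueError(f"unknown mode: {mode}")

    head(h).pow(2).mean().backward()
    return block_1[0].weight.grad is not None


for mode in ("no_checkpoint", "non_reentrant", "input_requires_grad"):
    print(mode, block_1_gets_grad(mode))  # each prints True
```

Of the three, use_reentrant=False keeps the memory savings on the first block without touching the input.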
