Use of torch.utils.checkpoint.checkpoint causes simple model to diverge

I have defined a simple model that trains on the CIFAR10 dataset included in torchvision. This model converges to a loss of ~0.72 after 10 epochs of training. Here is the model definition:

import torch
from torch import nn

class CIFAR10Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(*[
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        ])
        self.cnn_block_2 = nn.Sequential(*[
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        ])
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(*[
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        ])
    
    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.cnn_block_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X

I then defined and trained a new version of this model that uses torch.utils.checkpoint.checkpoint. It has the following definition (the nn.Dropout layers were moved out of the checkpointed blocks, since I understood Dropout to be incompatible with checkpointing):

import torch
import torch.utils.checkpoint
from torch import nn

class CIFAR10Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(*[
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        ])
        self.dropout_1 = nn.Dropout(0.25)
        self.cnn_block_2 = nn.Sequential(*[
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        ])
        self.dropout_2 = nn.Dropout(0.25)
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.linearize = nn.Sequential(*[
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU()
        ])
        self.dropout_3 = nn.Dropout(0.5)
        self.out = nn.Linear(512, 10)
    
    def forward(self, X):
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X)
        X = self.dropout_1(X)
        X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X)
        X = self.dropout_2(X)
        X = self.flatten(X)
        X = self.linearize(X)
        X = self.dropout_3(X)
        X = self.out(X)
        return X

However, this model performs radically worse than its uncheckpointed peer: it only reaches a loss of ~1.59 after 10 epochs!

I don’t understand why this is the case. This is my first time using this API; perhaps someone here who understands it better than I can point out where my error is?

For reference, here is the full code and logs: uncheckpointed version, checkpointed version.

The PyTorch autograd docs state:

If there’s a single input to an operation that requires gradient, its output will also require gradient. Conversely, only if all inputs don’t require gradient, the output also won’t require it.

The input to a model like this one will be a tensor with requires_grad=False (since we are performing gradient descent with respect to the model weights, not with respect to the values of the input sample itself). Unless you manually freeze it, any module with learnable parameters that you apply to that tensor (e.g. nn.Conv2d or nn.Linear, but not nn.Dropout or nn.MaxPool2d) will automatically produce an output tensor with requires_grad=True.

Applying checkpointing will not do this for you. Instead, because of the way checkpointing is implemented, whether the output tensor produced by the checkpointed module has requires_grad=True (and thus takes part in backpropagation) or requires_grad=False (and thus is cut off from it) is determined solely by whether the input tensor has requires_grad=True or requires_grad=False.
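To make the difference concrete, here is a minimal sketch comparing a plain forward pass with a checkpointed one. It assumes a PyTorch version whose checkpoint accepts the use_reentrant keyword, and it passes use_reentrant=True explicitly to select the implementation described above. The small block is a stand-in for illustration, not the model from the question.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Stand-in block with learnable parameters (not the model from the question).
block = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())

# A plain input batch: requires_grad defaults to False.
x = torch.randn(2, 3, 32, 32)

# Ordinary call: the conv weights require grad, so the output does too.
plain_out = block(x)

# Reentrant checkpointing: the output's flag follows the input's flag.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the warning PyTorch may emit here
    ckpt_out = checkpoint(block, x, use_reentrant=True)

print(x.requires_grad)          # False
print(plain_out.requires_grad)  # True
print(ckpt_out.requires_grad)   # False
```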

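This has the side effect (?) that the first block in the model cannot be checkpointed. E.g. in this model the input to self.cnn_block_1 is the input tensor, which has requires_grad=False; so setting X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X) means that no gradients ever reach cnn_block_1 and its weights stay fixed. Because nn.Dropout has no parameters, the input to the checkpointed cnn_block_2 then also has requires_grad=False, so that block is frozen as well.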
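The effect on training can be checked by looking at which parameters have a gradient after backward(). The following is a rough sketch with a scaled-down stand-in for the model (the names block_1, block_2, head and received_grad are made up for this example), again assuming a PyTorch version that accepts use_reentrant and selecting the reentrant implementation explicitly.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Scaled-down stand-ins for the two conv blocks and the linear head.
block_1 = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
block_2 = nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
head = nn.Linear(4 * 8 * 8, 10)

x = torch.randn(2, 3, 32, 32)  # requires_grad=False, like a normal input batch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the warning PyTorch may emit here
    h = checkpoint(block_1, x, use_reentrant=True)
    h = checkpoint(block_2, h, use_reentrant=True)

loss = head(torch.flatten(h, 1)).sum()
loss.backward()

# For each part, record whether any of its parameters received a gradient.
received_grad = {
    name: any(p.grad is not None for p in module.parameters())
    for name, module in [("block_1", block_1), ("block_2", block_2), ("head", head)]
}
print(received_grad)  # {'block_1': False, 'block_2': False, 'head': True}
```

Only the uncheckpointed head receives gradients here, which is consistent with a model that still trains but ends up much worse.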

I don’t think I would have been able to figure this out on my own; I got to this conclusion by Googling around and stumbling across this article, which mentions this issue in the footnotes. Seems like something that should really be added to the docs.

Removing the checkpointing on self.cnn_block_1 fixed this problem and produced a model that converged correctly again.
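For anyone who does want to checkpoint the first block, two other workarounds may be worth trying. Neither comes from this thread, so treat the following as a sketch under stated assumptions: either mark the input tensor with requires_grad_(True) before checkpointing, or pass use_reentrant=False (available only in PyTorch versions whose checkpoint accepts that keyword). The helper conv_gets_grad is a made-up name for this example.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def conv_gets_grad(mark_input: bool, use_reentrant: bool) -> bool:
    """Checkpoint a small conv block and report whether its weights get gradients."""
    block = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU())
    x = torch.randn(2, 3, 8, 8)
    if mark_input:
        # Workaround 1: make the input itself require grad.
        x.requires_grad_(True)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence the warning PyTorch may emit
        # Workaround 2 is use_reentrant=False.
        out = checkpoint(block, x, use_reentrant=use_reentrant)
    if not out.requires_grad:
        return False  # nothing can be backpropagated through this output
    out.sum().backward()
    return all(p.grad is not None for p in block.parameters())


print(conv_gets_grad(mark_input=False, use_reentrant=True))   # False
print(conv_gets_grad(mark_input=True, use_reentrant=True))    # True
print(conv_gets_grad(mark_input=False, use_reentrant=False))  # True
```

Marking the input makes autograd also compute a gradient for the input batch, which costs a little extra work, so leaving the first block uncheckpointed as described above remains the simplest fix.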

Hopefully the next person who runs into this issue will find this thread and not be so confused. :)
