No effect / increased memory use when using checkpoint_sequential

Hi all,

I was toying around with checkpoint_sequential and didn't notice any substantial difference (if anything, a slight 1-2% increase in memory use).

Before moving on to my more complicated models, I tried this with AlexNet and VGG-16 (maybe those are too small to see an effect?). Below is the shorter AlexNet as an example:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


class AlexNet(nn.Module):

    def __init__(self, num_classes):
        super(AlexNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )
        
        self.feature_modules = list(self.features.children())
        self.classifier_modules = list(self.classifier.children())

    def forward(self, x):
        
        # checkpointing needs at least one input that requires grad;
        # otherwise the checkpointed segments are detached from the graph
        x.requires_grad = True
        
        # run the feature extractor in two checkpointed segments
        x = checkpoint_sequential(functions=self.feature_modules,
                                  segments=2,
                                  input=x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        # checkpoint the classifier head as well, again in two segments
        logits = checkpoint_sequential(functions=self.classifier_modules,
                                       segments=2,
                                       input=x)
        
        return logits

Is there something weird about how I am using checkpoint_sequential? The code example in the torch.utils.checkpoint documentation (PyTorch 1.7.0) shows

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)

which is essentially what I am doing above.
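For comparison, here is a minimal, self-contained sketch of that documented pattern, passing the nn.Sequential container itself rather than a list of its modules (the toy layer sizes here are just placeholders):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

x = torch.randn(8, 1024, requires_grad=True)  # input must require grad
out = checkpoint_sequential(model, 2, x)      # 2 segments, Sequential passed directly
out.sum().backward()

As far as I can tell, passing a list of submodules (as in my forward pass above) should behave the same way as passing the container itself.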

PS: I was monitoring max memory use with nvidia-smi, with and without checkpoint_sequential.

I also tried

import tracemalloc


# note: tracemalloc only traces allocations made through Python's
# memory allocator, so it won't capture GPU memory
tracemalloc.start()

# train code

current, peak = tracemalloc.get_traced_memory()  # values are in bytes
print(f"{current / 1024**3:0.2f} GB, {peak / 1024**3:0.2f} GB")
tracemalloc.stop()

for the scenarios with and without checkpoint_sequential. The case with checkpoint_sequential seems to consume slightly more (not less) memory, which is odd.
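Since nvidia-smi also counts the CUDA context and the memory PyTorch's caching allocator reserves but doesn't use, and tracemalloc only sees Python-side allocations, here is a minimal sketch of how I could instead read the peak memory actually allocated for tensors on the GPU (assuming everything runs on the default CUDA device):

import torch

torch.cuda.reset_peak_memory_stats()

# train code

peak = torch.cuda.max_memory_allocated()  # bytes allocated for tensors at the peak
print(f"peak GPU memory: {peak / 1024**3:0.2f} GB")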

EDIT: I added a full example with the code above here: deeplearning-models/gradient-checkpointing-alexnet.ipynb at master · rasbt/deeplearning-models · GitHub

EDIT2: Reducing the number of segments to 1 improved memory efficiency by 20%.
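(That is, passing segments=1 to the same checkpoint_sequential calls in the forward pass above, e.g.

        x = checkpoint_sequential(functions=self.feature_modules,
                                  segments=1,
                                  input=x)

which checkpoints the whole feature block as a single segment, so only its input is stored and all intermediate activations are recomputed during the backward pass.)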