Hi all,
I was toying around with checkpoint_sequential and didn’t notice any substantial difference (if anything, a slight 1-2% increase in memory use).
a) Before moving onto my more complicated models, I tried this with AlexNet and VGG-16 (maybe those are too small to see an effect?). Below is the shorter AlexNet as an example:
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


class AlexNet(nn.Module):

    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )
        # Plain Python lists of the layers, passed to checkpoint_sequential below
        self.feature_modules = [module for k, module in self.features._modules.items()]
        self.classifier_modules = [module for k, module in self.classifier._modules.items()]

    def forward(self, x):
        # The input has to require grad so gradients flow through the checkpointed segments
        x.requires_grad = True
        x = checkpoint_sequential(functions=self.feature_modules,
                                  segments=2,
                                  input=x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        logits = checkpoint_sequential(functions=self.classifier_modules,
                                       segments=2,
                                       input=x)
        return logits
Is there something weird about how I am using checkpoint_sequential? The code example in the torch.utils.checkpoint documentation (PyTorch 1.7.0) shows
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
which is essentially what I am doing above.
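To rule out a usage problem separately from the memory question, here is a minimal sketch that checks whether checkpoint_sequential gives the same outputs and gradients as a plain forward pass. The tiny nn.Sequential model and the tensor shapes are made up for illustration, and the use_reentrant=False argument assumes a recent PyTorch release (the 1.7.0 signature does not have it).

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Hypothetical toy model (no dropout, so both passes are deterministic)
model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 4),
)
x = torch.randn(8, 16)

# Reference: ordinary forward and backward pass
out_plain = model(x)
out_plain.sum().backward()
grads_plain = [p.grad.clone() for p in model.parameters()]

# Same computation, split into 2 checkpointed segments
# (use_reentrant=False is an assumption about the installed PyTorch version)
model.zero_grad()
out_ckpt = checkpoint_sequential(model, 2, x, use_reentrant=False)
out_ckpt.sum().backward()
grads_ckpt = [p.grad.clone() for p in model.parameters()]

outputs_match = torch.allclose(out_plain, out_ckpt)
grads_match = all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_ckpt))
print("outputs match:", outputs_match, "| gradients match:", grads_match)
```

If both checks pass, the call itself is probably fine and the remaining question is only how much activation memory the checkpointing actually saves for a given model.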
PS: I was monitoring max memory use with nvidia-smi, with and without checkpoint_sequential.
I also tried
import tracemalloc
tracemalloc.start()
# train code
current, peak = tracemalloc.get_traced_memory()  # both values are in bytes
print(f"{current / 1e9:0.2f} GB, {peak / 1e9:0.2f} GB")
tracemalloc.stop()
for the scenario with and without checkpoint_sequential. The case with checkpoint_sequential seems to consume slightly more (not less) memory, which is odd!?
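One thing worth keeping in mind about these measurements: tracemalloc traces memory allocated by Python on the host, so it does not see GPU tensors, and nvidia-smi reports what PyTorch's caching allocator has reserved rather than what tensors currently occupy. A rough sketch of an alternative is to read PyTorch's own peak counter via torch.cuda.max_memory_allocated. The helper name peak_memory_mb, the toy conv stack, and the input shape below are hypothetical, and use_reentrant=False again assumes a recent PyTorch release.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


def peak_memory_mb(step_fn, device):
    """Run step_fn once; return peak tensor memory in MB, or None without CUDA."""
    if device.type != "cuda":
        step_fn()
        return None
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)  # start a fresh peak measurement
    step_fn()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 1024**2


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy network, only meant to illustrate the measurement
layers = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
).to(device)
x = torch.randn(8, 3, 64, 64, device=device)


def plain_step():
    layers.zero_grad()
    layers(x).sum().backward()


def checkpointed_step():
    layers.zero_grad()
    # use_reentrant=False is an assumption about the installed PyTorch version
    checkpoint_sequential(layers, 2, x, use_reentrant=False).sum().backward()


plain_mb = peak_memory_mb(plain_step, device)
ckpt_mb = peak_memory_mb(checkpointed_step, device)

if plain_mb is None:
    print("CUDA not available; both steps ran on CPU without a memory reading.")
else:
    print(f"plain: {plain_mb:.1f} MB | checkpointed: {ckpt_mb:.1f} MB")
```

The same helper could wrap one training step of the AlexNet model, with and without checkpointing, to compare the two cases on equal terms.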
EDIT: I added a full example with the code above here: deeplearning-models/gradient-checkpointing-alexnet.ipynb at master · rasbt/deeplearning-models · GitHub
EDIT2: Reducing the number of segments to 1 improved memory efficiency by 20%.