MLP activation checkpointing

Hi,

I am trying to learn how to use the activation checkpointing provided by torch.utils.checkpoint and torch.distributed.algorithms._checkpoint.checkpoint_wrapper, so I wrote a script that applies checkpointing to a feedforward network and profiles the memory usage at each step:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)

class FeedForward(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.GELU(),
                nn.Linear(intermediate_size, hidden_size)
            ) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(torch.__version__)

hidden_size = 1024
intermediate_size = 4096
num_layers = 8
sequence_length = 1024
batch_size = 128

# Toggle the checkpoint_wrapper line below to compare with and without checkpointing
torch.cuda.reset_peak_memory_stats()
with torch.autograd.profiler.profile(profile_memory=True, use_cuda=True) as prof:
    print("model with checkpointing wrapper")
    model = FeedForward(hidden_size, intermediate_size, num_layers).cuda()
    optimizer = torch.optim.Adam(model.parameters())
    model = checkpoint_wrapper(model)  # comment out this line to disable checkpointing
    model.train()

    print("Peak memory usage after model init: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
    # Variable is deprecated; creating the tensor on the GPU directly also keeps it a leaf
    input_data = torch.randn(batch_size, sequence_length, hidden_size, device="cuda", requires_grad=True)
    print("Peak memory usage after data init: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
    output = model(input_data)
    print("Peak memory usage after forward: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
    optimizer.zero_grad()
    output.sum().backward()
    optimizer.step()
    # note: this print runs after optimizer.step(), so its peak includes the optimizer state
    print("Peak memory usage after backward: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")

With checkpointing enabled, I get:

Peak memory usage after model init:  256.15625 MB
Peak memory usage after data init:  768.15625 MB
Peak memory usage after forward:  5384.28125 MB
Peak memory usage after backward:  40736.4072265625 MB

and with checkpointing disabled:

Peak memory usage after model init:  256.15625 MB
Peak memory usage after data init:  768.15625 MB
Peak memory usage after forward:  37641.28125 MB
Peak memory usage after backward:  40224.4072265625 MB
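
As a sanity check (assuming I am counting the saved tensors right), the forward peak without checkpointing matches a back-of-envelope activation count in fp32 with 128 * 1024 tokens:

128 * 1024 * 1024 * 4 bytes =  512 MB  (block input)
128 * 1024 * 4096 * 4 bytes = 2048 MB  (pre-GELU intermediate, and again for the GELU output)
(512 + 2048 + 2048) MB * 8 blocks + 768 MB baseline = 37632 MB

which is close to the observed 37641 MB, so the forward numbers look right to me.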

The odd part is that the peak memory usage after backward is even higher with checkpointing enabled!
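
To separate the phases, I am thinking of resetting the peak counter before each one so every print reflects only that phase. A minimal sketch, reusing the model, optimizer, and input from above:

# per-phase peaks: reset the counter before each phase and sync before reading it
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
output = model(input_data)
torch.cuda.synchronize()
print("Peak during forward: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")

torch.cuda.reset_peak_memory_stats()
optimizer.zero_grad()
output.sum().backward()
torch.cuda.synchronize()
print("Peak during backward: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")

torch.cuda.reset_peak_memory_stats()
optimizer.step()
torch.cuda.synchronize()
print("Peak during optimizer step: ", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")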

How should I fix this?
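
Also, in case it is relevant: I imported apply_activation_checkpointing but have not used it yet. My understanding is that it can wrap each inner block individually instead of the whole model, roughly like this (untested sketch; check_fn here selects the nn.Sequential blocks):

# sketch: checkpoint each Sequential block rather than the whole model
model = FeedForward(hidden_size, intermediate_size, num_layers).cuda()
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda submodule: isinstance(submodule, nn.Sequential),
)

Would wrapping at that granularity be expected to change the backward peak?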