Gradient Checkpointing does not reduce memory usage

Hi all,

I’m trying to train a model on my GPU (RTX 2080 Super) using Gradient Checkpointing to significantly reduce VRAM usage. I’m using torch.utils.checkpoint.checkpoint.

The model I want to apply it to is a simple CNN with a flatten layer at the end. Although I think I applied it correctly, I’m not getting any reduction in memory usage. The memory usage with Gradient Checkpointing is the same as without it; however, I do see an increase in the time per epoch (which is expected given the nature of Gradient Checkpointing).

I’ve been searching through the forums but I didn’t find anything related to this.

The model definition is the following:

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class my_model(torch.nn.Module):

    def __init__(self):
        super(my_model, self).__init__()

        # Input shape [19, 22, 20] -> Output shape [3258]
        self.conv1 = torch.nn.Conv2d(in_channels = 20,
                                     out_channels = 50,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv2 = torch.nn.Conv2d(in_channels = 50,
                                     out_channels = 25,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv3 = torch.nn.Conv2d(in_channels = 25,
                                     out_channels = 10,
                                     kernel_size = 3,
                                     padding = 1)

        self.fc1 = torch.nn.Linear(in_features = 19 * 22 * 10,
                                   out_features = 3258)

    def flat(self, x):
        return torch.flatten(x, 1)

    def forward(self, x):

        x = self.conv1(x)
        x = F.relu(x)

        x = checkpoint(self.conv2, x)
        x = checkpoint(F.relu, x)

        x = checkpoint(self.conv3, x)
        x = checkpoint(F.relu, x)

        x = checkpoint(self.flat, x)
        x = checkpoint(self.fc1, x)

        return x

I understand that when you checkpoint an operation, its intermediate activations are not saved in memory; they are recomputed when the backward pass needs them.
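A minimal sketch of that recompute behaviour, assuming a recent PyTorch where checkpoint accepts use_reentrant (the 1.3.1 release used in this thread does not). The hypothetical CountingBlock records how many times its forward runs: without checkpointing it runs once; with checkpointing it runs once in the forward pass and once more during backward, because its inner activations were discarded.

```python
import torch
from torch.utils.checkpoint import checkpoint


class CountingBlock(torch.nn.Module):
    """Hypothetical conv + relu block that counts how often forward() runs."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(50, 25, kernel_size=3, padding=1)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return torch.relu(self.conv(x))


def count_forward_calls(use_checkpoint):
    torch.manual_seed(0)
    block = CountingBlock()
    x = torch.randn(4, 50, 19, 22, requires_grad=True)
    if use_checkpoint:
        # Inner activations are dropped here and recomputed during backward
        out = checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)
    out.sum().backward()
    return block.calls


print(count_forward_calls(False))
print(count_forward_calls(True))
```

The extra forward call is exactly the time overhead observed per epoch; whether it buys any memory depends on how much was stored inside the checkpointed region.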

For demonstration purposes, I define some random training and validation data:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):

    def __init__(self, data, target, transform = None):
        self.data = torch.from_numpy(data).float()
        self.target = torch.from_numpy(target).float()
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        y = self.target[index]

        if self.transform:
            x = self.transform(x)

        return x, y

    def __len__(self):
        return len(self.data)

# Random data
x = np.random.randn(8766, 20, 19, 22) # 19 x 22 x 20 array
y = np.random.randn(8766, 3258) # 3258 vector

# Generate train and validation indices
validation_split = 0.1
idxs = list(range(x.shape[0]))
np.random.shuffle(idxs)

split = int(np.floor(validation_split * len(idxs)))
train_idxs, val_idxs = idxs[split:], idxs[:split]

Train and Validation DataLoaders:

# Train dataloader
dataset_train = MyDataset(data = x[train_idxs, :, :, :],
                          target = y[train_idxs, :])
loader_train = DataLoader(dataset_train,
                          batch_size = 100,
                          shuffle = True,
                          num_workers = 1)

# Validation dataloader
dataset_val = MyDataset(data = x[val_idxs, :, :, :],
                        target = y[val_idxs, :])
loader_val = DataLoader(dataset_val,
                        batch_size = 100,
                        shuffle = True,
                        num_workers = 1)

Finally, I train the model:

import time
import torch.optim as optim

# Instantiate the model
model = my_model().cuda()

# Initialize lowest loss value registered (validation)
lowest_loss_val = 10000

# Initialize the number of epochs without a decrease in loss (validation)
epochs_loss_val = 1

# List for the per epoch losses in train
avg_train_loss_list = []

# List for the per epoch losses in validation
avg_val_loss_list = []

# Model hyperparameters
epochs = 100000
learning_rate = 0.0001
patience = 30 # Early stopping patience

# Loss function and optimizer
loss_function = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(),
                       lr = learning_rate)

# Iterate over epochs
for epoch in range(epochs):

    # Record start time of epoch
    x1 = time.time()

    # Training step
    model.train()
    loss_train_list = []

    for batch_idx, (data, target) in enumerate(loader_train):

        # Clean gradients value
        optimizer.zero_grad()

        # Forward pass
        output = model(data.cuda())

        # Loss calculation
        loss = loss_function(output, target.cuda())

        # Append batch loss
        loss_train_list.append(loss.item())

        # Gradients calculation
        loss.backward()

        # Optimizer update
        optimizer.step()

    # Calculate total time spent on epoch
    x2 = time.time() - x1

    # Training loss averaged over batches (loss per epoch)
    avg_train_loss_list.append(np.average(loss_train_list))

    # Validation step (early stopping purposes)
    model.eval()
    loss_val_list = []

    # No gradients are needed for validation, so don't build the autograd graph
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(loader_val):

            # Forward pass
            output = model(data.cuda())

            # Loss calculation
            loss = loss_function(output, target.cuda())

            # Append batch loss
            loss_val_list.append(loss.item())

    # Validation loss averaged over batches
    avg_val_loss_list.append(np.average(loss_val_list))

    # Log info
    print('Epoch {} | Training Loss {} | Validation Loss {} | Elapsed seconds {}'.format(epoch + 1,
                                                                    round(avg_train_loss_list[-1], 5),
                                                                    round(avg_val_loss_list[-1], 5),
                                                                    round(x2, 5)))

    # Early stopping check
    # If the latest loss is lower than the lowest one so far
    if avg_val_loss_list[-1] < lowest_loss_val:

        lowest_loss_val = avg_val_loss_list[-1]
        epochs_loss_val = 1

        # Save model
        torch.save(model.state_dict(), 'model.pt')

    # If the validation loss has not decreased for `patience` epochs
    elif epochs_loss_val > patience:

        print('{} epochs without loss reduction. Finishing training...'.format(epochs_loss_val - 1))
        print('Final loss: {}'.format(round(lowest_loss_val, 5)))
        break # Break training loop

    epochs_loss_val += 1

I’ve been trying, without success, to find the reason why this is not behaving as expected. Is my implementation wrong? Or can’t my model’s architecture benefit from Gradient Checkpointing? I would appreciate any help.

I’m running the code on PyTorch 1.3.1.

Thank you in advance!!

Hi,

The point of checkpointing is to take multiple operations and remove the buffers between them.
If you checkpoint a single operation (like just a relu), it will have no effect other than slowing down your process.

You want to checkpoint either a whole conv + relu block, or even all your convs together.
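One way to checkpoint several layers as a unit without hand-writing block methods is torch.utils.checkpoint.checkpoint_sequential, which splits an nn.Sequential into segments. A sketch, assuming the conv stack is rewritten as a hypothetical nn.Sequential named convs and a recent PyTorch that accepts use_reentrant=False here:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential


class MyModelSeq(torch.nn.Module):
    """Same layers as my_model, with the conv stack as an nn.Sequential (hypothetical rewrite)."""

    def __init__(self):
        super().__init__()
        self.convs = torch.nn.Sequential(
            torch.nn.Conv2d(20, 50, kernel_size=3, padding=1), torch.nn.ReLU(),
            torch.nn.Conv2d(50, 25, kernel_size=3, padding=1), torch.nn.ReLU(),
            torch.nn.Conv2d(25, 10, kernel_size=3, padding=1), torch.nn.ReLU(),
        )
        self.fc1 = torch.nn.Linear(19 * 22 * 10, 3258)

    def forward(self, x):
        # Split the 6 modules into 2 segments: the first is checkpointed,
        # the last one runs normally because backward needs it right away
        x = checkpoint_sequential(self.convs, 2, x, use_reentrant=False)
        return self.fc1(torch.flatten(x, 1))


model = MyModelSeq()
out = model(torch.randn(4, 20, 19, 22))
out.pow(2).mean().backward()
print(out.shape)
```

More segments mean fewer stored boundaries but more recomputation; for a network this shallow the choice barely matters.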

The models I’m trying to train don’t fit in my GPU memory. I’ve already implemented Gradient Accumulation and FP16 training but that’s not enough, which is why I’m using Gradient Checkpointing. I don’t care about computation time, I just want to fit my model in memory.

I tried checkpointing the conv + relu operations:

class my_model(torch.nn.Module):

    def __init__(self):
        super(my_model, self).__init__()

        # Input shape [19, 22, 20] -> Output shape [3258]
        self.conv1 = torch.nn.Conv2d(in_channels = 20,
                                     out_channels = 50,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv2 = torch.nn.Conv2d(in_channels = 50,
                                     out_channels = 25,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv3 = torch.nn.Conv2d(in_channels = 25,
                                     out_channels = 10,
                                     kernel_size = 3,
                                     padding = 1)

        self.fc1 = torch.nn.Linear(in_features = 19 * 22 * 10,
                                   out_features = 3258)

    def flat(self, x):
        return torch.flatten(x, 1)

    def block1(self, x):
        return F.relu(self.conv2(x))

    def block2(self, x):
        return F.relu(self.conv3(x))

    def block3(self, x):
        return self.fc1(self.flat(x))

    def forward(self, x):

        x = self.conv1(x)
        x = F.relu(x)

        x = checkpoint(self.block1, x)

        x = checkpoint(self.block2, x)

        x = checkpoint(self.block3, x)

        return x

And even all the convs together:

class my_model(torch.nn.Module):

    def __init__(self):
        super(my_model, self).__init__()

        # Input shape [19, 22, 20] -> Output shape [3258]
        self.conv1 = torch.nn.Conv2d(in_channels = 20,
                                     out_channels = 50,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv2 = torch.nn.Conv2d(in_channels = 50,
                                     out_channels = 25,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv3 = torch.nn.Conv2d(in_channels = 25,
                                     out_channels = 10,
                                     kernel_size = 3,
                                     padding = 1)

        self.fc1 = torch.nn.Linear(in_features = 19 * 22 * 10,
                                   out_features = 3258)

    def flat(self, x):
        return torch.flatten(x, 1)

    def block1(self, x):

        x = self.conv2(x)
        x = F.relu(x)

        x = self.conv3(x)
        x = F.relu(x)

        return x

    def forward(self, x):

        x = self.conv1(x)
        x = F.relu(x)

        x = checkpoint(self.block1, x)

        x = self.flat(x)
        x = self.fc1(x)

        return x

But the memory usage is exactly the same; however, as I said, I notice a small increase in time per epoch.

I was thinking that maybe the activations that are not stored when using the checkpoint function were not big enough for the memory saving to be meaningful. I checked this by setting the requires_grad attribute of the input tensor in the forward method to True, and I did see an increase in memory, so my theory was wrong.

I have no idea what is failing in my models… I’d appreciate any help!

Hi,

If your images are of size [19, 22], then I would expect the intermediate states to be very small indeed, especially for a batch size of 100 (if the sample in your first post is accurate).
For example, the output of conv1 has about 2 million elements (100 * 50 * 19 * 22), i.e. roughly 8 MB in float32.
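The arithmetic behind that estimate, as a quick check. It assumes float32 (4 bytes per element) and counts only the main layer outputs of my_model, not every buffer autograd may keep:

```python
BATCH = 100
BYTES_PER_FLOAT32 = 4

# Output shapes (channels, height, width) of each layer of my_model for 19 x 22 inputs
layer_outputs = {
    "conv1": (50, 19, 22),
    "conv2": (25, 19, 22),
    "conv3": (10, 19, 22),
    "fc1": (3258,),
}


def numel(shape):
    n = 1
    for d in shape:
        n *= d
    return n


for name, shape in layer_outputs.items():
    elems = BATCH * numel(shape)
    print(f"{name}: {elems:,} elements, {elems * BYTES_PER_FLOAT32 / 1e6:.2f} MB")

total_elems = sum(BATCH * numel(s) for s in layer_outputs.values())
print(f"total: {total_elems:,} elements, {total_elems * BYTES_PER_FLOAT32 / 1e6:.2f} MB")
```

Even all together this is on the order of 15 MB, so removing some of it cannot make a visible dent in an 8 GB card.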

How much memory does your GPU have?

Hi,

My GPU (RTX 2080 Super) has 8 GB of memory. Without any memory optimization technique, training uses 1317 MiB; with Gradient Accumulation (an effective batch size of 100, accumulated over batches of 1 element) it uses 1097 MB; and with FP16 training (using the half() method) it uses 987 MB. Gradient Checkpointing gives no decrease.
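For reference, a minimal sketch of the gradient accumulation pattern described here, assuming a mean-reduced loss such as MSELoss: each micro-batch loss is divided by the number of micro-batches so the summed gradients match the full-batch gradient. The tiny Linear model and the helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def full_batch_grad(model, x, y):
    model.zero_grad()
    F.mse_loss(model(x), y).backward()
    return model.weight.grad.clone()


def accumulated_grad(model, x, y, micro_batch=1):
    model.zero_grad()
    chunks = list(zip(x.split(micro_batch), y.split(micro_batch)))
    for xb, yb in chunks:
        # Scale each micro-batch loss so the accumulated gradient equals the full-batch mean
        loss = F.mse_loss(model(xb), yb) / len(chunks)
        loss.backward()  # gradients add up in .grad until the next zero_grad()
    return model.weight.grad.clone()


torch.manual_seed(0)
model = torch.nn.Linear(8, 3)
x, y = torch.randn(10, 8), torch.randn(10, 3)
print(torch.allclose(full_batch_grad(model, x, y), accumulated_grad(model, x, y), atol=1e-6))
```

In a training loop, optimizer.step() and optimizer.zero_grad() would then run once per accumulated batch rather than once per micro-batch. This only shrinks activation memory; parameters, gradients and optimizer state stay the same size.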

As you can see, this model fits in memory; however, a variation of it (the one I’m really interested in fitting in memory) does not. This variation follows a similar architecture:

class my_model_full(torch.nn.Module):

    def __init__(self):
        super(my_model_full, self).__init__()

        # Input shape [45, 43, 20]
        self.conv1 = torch.nn.Conv2d(in_channels = 20,
                                     out_channels = 50,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv2 = torch.nn.Conv2d(in_channels = 50,
                                     out_channels = 25,
                                     kernel_size = 3,
                                     padding = 1)

        self.conv3 = torch.nn.Conv2d(in_channels = 25,
                                     out_channels = 10,
                                     kernel_size = 3,
                                     padding = 1)

        self.param1 = torch.nn.Linear(in_features = 45 * 43 * 10,
                                      out_features = 29913)

        self.param2 = torch.nn.Linear(in_features = 45 * 43 * 10,
                                      out_features = 29913)

        self.param3 = torch.nn.Linear(in_features = 45 * 43 * 10,
                                      out_features = 29913)

    def flat(self, x):
        return torch.flatten(x, 1)

    def cat_torch(self, x1, x2, x3):
        return torch.cat((x1, x2, x3), dim = 1)

    def forward(self, x):

        x.requires_grad = True

        x = checkpoint(self.conv1, x)
        x = checkpoint(F.relu, x)

        x = checkpoint(self.conv2, x)
        x = checkpoint(F.relu, x)

        x = checkpoint(self.conv3, x)
        x = checkpoint(F.relu, x)

        x = checkpoint(self.flat, x)

        x1 = checkpoint(self.param1, x)
        x1 = checkpoint(torch.sigmoid, x1)

        x2 = checkpoint(self.param2, x)

        x3 = checkpoint(self.param3, x)

        x = checkpoint(self.cat_torch, x1, x2, x3)

        return x

By chance, today I had access to a Tesla V100 with 32 GB of memory. Without any memory optimization technique I get an OOM error, and the same happens with Gradient Checkpointing. With Gradient Accumulation it uses 32027 MB and with FP16 training it uses 16943 MB. As you can see, I finally managed to fit the model in memory with Gradient Accumulation (in exchange for a big increase in time per epoch), but I still don’t get any memory reduction with Gradient Checkpointing…

I’m still curious about this issue. Appreciate any help!

As mentioned above, checkpointing a single layer will have no effect. This is expected.

You will need to checkpoint multiple layers at once so that checkpointing can reduce the number of buffers saved by the network.

Note that in this new network, your linear layers are huge: each one has about 579M weights (19350 * 29913), which takes roughly 2.3 GB in float32 just to store the weights, plus the same again for their gradients. So that might explain why your model uses so much memory.
You can try to checkpoint the three self.param layers and the cat together to remove the largest buffers. But the buffers in this net are very small compared to the weights, so the difference will be quite small.
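Counting the parameters of my_model_full explicitly makes the point concrete. This back-of-the-envelope sketch assumes float32 and the Adam optimizer from the first post (which keeps two extra state tensors per parameter); it covers parameters, gradients and optimizer state only, so activations, the CUDA context and allocator overhead come on top.

```python
BYTES_PER_FLOAT32 = 4
in_features = 45 * 43 * 10  # flattened conv3 output
out_features = 29913

# Parameter counts per layer: weight + bias
convs = [20 * 50 * 3 * 3 + 50, 50 * 25 * 3 * 3 + 25, 25 * 10 * 3 * 3 + 10]
linear = in_features * out_features + out_features
total_params = sum(convs) + 3 * linear


def gb(n_floats):
    return n_floats * BYTES_PER_FLOAT32 / 1e9


print(f"one Linear layer: {linear:,} params, {gb(linear):.2f} GB")
print(f"weights: {gb(total_params):.2f} GB")
print(f"weights + grads: {gb(2 * total_params):.2f} GB")
print(f"weights + grads + Adam state: {gb(4 * total_params):.2f} GB")
```

Under these assumptions the fixed cost is already close to 28 GB before a single activation is stored, which is why checkpointing (which only removes activations) barely changes the total.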

Also, on GPU we use a custom caching allocator that does not return memory to the OS when it is released. So you will need to use torch.cuda.memory_allocated() to check how much memory is actually used by tensors, as nvidia-smi reports the memory held by the allocator (plus the CUDA context) and will therefore show a larger number.
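A minimal sketch of measuring one training step with the allocator’s own counters instead of nvidia-smi. torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated exist in recent PyTorch (older releases such as 1.3 had reset_max_memory_allocated instead); the helper name peak_step_memory_mib is hypothetical, and it returns None when no GPU is available.

```python
import torch
import torch.nn.functional as F


def peak_step_memory_mib(model, batch, target, loss_fn=F.mse_loss):
    """Peak memory (MiB) held by tensors during one forward + backward pass."""
    if not torch.cuda.is_available():
        return None
    device = torch.device("cuda")
    model = model.to(device)
    batch, target = batch.to(device), target.to(device)
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device)
    loss = loss_fn(model(batch), target)
    loss.backward()
    torch.cuda.synchronize()
    # Counts only memory occupied by tensors, not the allocator cache or CUDA context
    return torch.cuda.max_memory_allocated(device) / 2**20


model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(20 * 19 * 22, 3258))
print(peak_step_memory_mib(model, torch.randn(100, 20, 19, 22), torch.randn(100, 3258)))
```

Calling it once per model variant (no checkpointing, per-op checkpointing, block checkpointing) gives comparable numbers; note that the peak includes the resident parameters and gradients, which checkpointing does not touch.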

I implemented the checkpointing as you suggested and, after some tests, I saw a small (really small) decrease in GPU memory (for this I used some functions from torch.cuda).

When I started implementing checkpointing I expected a more significant difference in GPU memory usage; I suppose my architecture is not a good fit for this technique.

Thank you for the help!