Gradient checkpointing not saving memory

Hi,

I implemented a Triton kernel that is called inside a torch.autograd.Function. When using it in my training run, I got an OOM. This led me to believe that activation checkpointing doesn’t work with torch.autograd.Function: I guess when I explicitly save tensors to ctx, they actually get saved and torch.utils.checkpoint doesn’t do anything about that. Is that correct? If so, how should I adapt my code so that it also supports checkpointing?

Also, I played around with the checkpoint function using this script on a single V100 GPU (32 GB of VRAM).

import torch
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint
from torch.autograd import Function

class LinearFun(Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.T
    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors

        grad_w = grad_output.T @ inp
        grad_inp = grad_output @ weight
        return grad_inp, grad_w

class NewLinear(Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, inp):
        return LinearFun.apply(inp, self.weight)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = Linear(1024, 1024)
        self.l2 = Linear(1024, 1024)
        self.l3 = Linear(1024, 1)
        # self.l1 = NewLinear(1024, 1024)
        # self.l2 = NewLinear(1024, 1024)
        # self.l3 = NewLinear(1024, 1)
    
    def forward(self, inp):
        inp = self.l1(inp)
        inp = self.l2(inp)
        inp = self.l3(inp)
        return inp


model = Model().cuda()
inp = torch.randn(int(1e6), 1024, device="cuda")

# <---- USE CHECKPOINTING
out = checkpoint(model, inp, use_reentrant=False)
# <---- NOT USING CHECKPOINTING
# out = model(inp)

loss = out.sum()
print('========= before backward =========')
loss.backward()
print('========= after backward =========')

I tweaked the number of inputs to int(1e6), since that is where both scenarios run into an OOM, but I don’t see that checkpointing lets me go to larger inputs.
Am I missing something?
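
For reference, a rough sketch of how the two scenarios could be compared directly via the CUDA peak-memory counters instead of relying on an OOM (model and inp are the ones defined in the script above; this snippet is not part of the original run):

torch.cuda.reset_peak_memory_stats()

out = checkpoint(model, inp, use_reentrant=False)   # or: out = model(inp)
loss = out.sum()
loss.backward()
torch.cuda.synchronize()

print(f"peak memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")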

Your custom autograd Function code looks fine. Saving tensors via ctx.save_for_backward(*tensors) is exactly what you want.

However, you do not want to apply checkpointing to the entire model, because during backward all the tensors that would have been saved for backward in that region end up being rematerialized at once anyway.

To avoid the peak memory spike that you are still observing, you’ll want to apply checkpoint piece by piece so that during backward rematerialization also occurs piece by piece.
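
One convenient way to do this for a purely sequential stack is torch.utils.checkpoint.checkpoint_sequential, which splits an nn.Sequential into segments and wraps every segment except the last in checkpoint. A minimal sketch, assuming the three-layer toy model from your script can be expressed as an nn.Sequential:

import torch
from torch.nn import Linear, Sequential
from torch.utils.checkpoint import checkpoint_sequential

# Same three layers as in the script above, expressed as an nn.Sequential
# so that checkpoint_sequential can split them into segments.
model = Sequential(Linear(1024, 1024), Linear(1024, 1024), Linear(1024, 1)).cuda()
inp = torch.randn(int(1e6), 1024, device="cuda")

# Split into 2 segments; every segment except the last is wrapped in
# checkpoint, so backward rematerializes at most one segment at a time.
out = checkpoint_sequential(model, 2, inp, use_reentrant=False)
loss = out.sum()
loss.backward()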

OK, so in this example there is actually no benefit from activation checkpointing? Or how would you rewrite it so that peak memory consumption is reduced?

Also, I was wondering whether torch avoids saving the tensors to the function’s ctx when gradient checkpointing is enabled.

Yes, checkpoint can interpose logic to avoid eagerly saving tensors, as long as you use the ctx.save_for_backward API.
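
For the curious: this interposition works through saved-tensor hooks (torch.autograd.graph.saved_tensors_hooks), which intercept anything passed to ctx.save_for_backward; as far as I understand, the non-reentrant checkpoint is built on this mechanism. A small sketch (MyMul is just a toy Function standing in for your Triton-backed one):

import torch
from torch.autograd import Function

class MyMul(Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)   # goes through the pack hook below
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors      # retrieved through the unpack hook
        return grad_out * b, grad_out * a

def pack(t):
    print("packing tensor of shape", tuple(t.shape))
    return t

def unpack(t):
    print("unpacking tensor of shape", tuple(t.shape))
    return t

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    out = MyMul.apply(a, b).sum()
out.backward()   # unpack hooks fire here, when the saved tensors are needed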

Ok, nice. Somewhat unrelated, but when calling ctx.save_for_backward, it doesn’t save tensors when I am in eval mode, right? Or do I have to take care of that explicitly?

Calling mod.eval() doesn’t actually disable recording of the autograd graph or saving tensors for backward. You might be thinking of torch.no_grad, or mod.requires_grad_(False).
This might be of interest to you: Autograd mechanics — PyTorch 2.3 documentation
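
A quick way to convince yourself of the difference (toy example, not from the docs):

import torch

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

lin.eval()                 # only switches layer behavior (dropout, batchnorm, ...)
print(lin(x).grad_fn)      # still prints a grad_fn: the graph is recorded
                           # and tensors are saved for backward

with torch.no_grad():      # this is what disables recording/saving
    print(lin(x).grad_fn)  # prints None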

It’s tricky because there are so few layers in this example, but grouping the first two linears into a single checkpoint call and leaving the last linear uncheckpointed should help; see the sketch below.
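
For concreteness, a minimal sketch of the toy Model from the script above rewritten along those lines (the class name CheckpointedModel and the lambda wrapper are just illustrative choices, not anything from torch):

import torch
from torch.nn import Linear
from torch.utils.checkpoint import checkpoint

class CheckpointedModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = Linear(1024, 1024)
        self.l2 = Linear(1024, 1024)
        self.l3 = Linear(1024, 1)

    def forward(self, inp):
        # The first two linears form one checkpointed segment: their
        # activations are dropped after forward and recomputed in backward.
        inp = checkpoint(lambda x: self.l2(self.l1(x)), inp, use_reentrant=False)
        # The last linear (1024 -> 1) is cheap, so it is not checkpointed.
        return self.l3(inp)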

Hi. Yes, I forgot that the graph is still recorded in eval mode. In the meantime I solved my bug; it seems that everything is handled correctly when an autograd.Function is used. Thanks :)