Trying to understand torch.utils.checkpoint

I am trying to understand how to use checkpoints to optimize my training. My basic understanding was that it trades increased compute for a lower memory footprint (by re-computing instead of storing data for the backward pass). Naively, then, I would assume that any time I use it I should see lower memory use and longer compute time. As a first pass I plugged it into a model I am training, which uses a pre-trained resnet with an additional trainable layer after it (I am using the CPC InfoNCE loss to pretrain a vision model). I noticed that it not only produced a significant decrease in memory use but also a substantial decrease in training time: I was able to increase the batch size 8x and also cut training time by 3x. I don’t understand why this would be, and am hoping someone can help me understand what is going on here.

Below is a self-contained script that reproduces this behavior on my machine (sorry it could be more minimal, but I figure the answer will be something I don’t understand about checkpoints, not in the details of my code):

When I set CHECKPOINT = False it uses about 8.5G of GPU memory and runs ~2.5 batches per second. When CHECKPOINT = True it uses 1.5G of GPU memory and runs at ~8.5 batches per second…

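import torch
import torch.utils.checkpoint
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdm

from torch import optim
import torchvision.models as models

dev = "cuda:0"
BATCH_SIZE = 32      # larger sizes OOM on a 2080 Ti when CHECKPOINT = False
CHECKPOINT = True    # toggle to compare the two code paths in MODEL.forward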

class ImageDataset(Dataset):
    def __init__(self,length = 100000,size = 244):
        self.length = length
        self.size = size
    def __len__(self):
        return self.length
    def __getitem__(self,idx,display = False):
        return torch.from_numpy(np.random.randn(2,3,self.size,self.size))
train = ImageDataset()
trainloader = DataLoader(
    train,
    batch_size = BATCH_SIZE,
    num_workers = 24,
    pin_memory = True
)

resnet = models.resnet50(pretrained = False)

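class MODEL(nn.Module):
    def __init__(self,model):
        super().__init__()
        self.model = model
        self.LR = nn.Linear(1000,1000)
    def forward(self,x):
        # x has shape (batch, 2, 3, H, W): two views per sample
        if CHECKPOINT == False:
            o1 = self.model(x[:,0])
            o2 = self.model(x[:,1])
        else:
            # use_reentrant=True is the (only) variant that existed when this was posted
            o1 = torch.utils.checkpoint.checkpoint(self.model,x[:,0],use_reentrant=True)
            o2 = torch.utils.checkpoint.checkpoint(self.model,x[:,1],use_reentrant=True)
        return torch.mean((self.LR(o1)-o2)**2)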
resnet = MODEL(resnet).to(dev)

optimizer = optim.SGD(resnet.parameters(),lr = .001)

for T in tqdm(trainloader):
    optimizer.zero_grad()
    out = torch.mean(resnet(T.float().to(dev)))
    out.backward()
    optimizer.step()


This is good news, but it is a bit surprising indeed.
It will be doing the forward twice instead of once, so it should not become that much faster.
Also, you’re checkpointing the whole model, so the peak memory usage shouldn’t be reduced that much.

How do you measure the memory and time?
Do you have anything else running on the GPU or other users using the machine?

I am measuring the time by looking at the “iterations per second” measure in the tqdm output and I am measuring GPU memory usage by running nvidia-smi in the terminal. The real “measure of GPU memory use” for me is actually the ability to increase the batch size.

For example, on my machine (running on a single RTX 2080 Ti) I can run a batch size of 256 in the above script with CHECKPOINT = True, but I get a CUDA OOM error at that batch size with CHECKPOINT = False (I even get an OOM with CHECKPOINT = False at a batch size of 64; I need to go all the way down to 32 to get it to run). I am totally open to this being the result of some terrible optimization (or lack thereof) in my code that checkpoint is for some reason resolving.

Oh and no, there is nothing else running on this machine, and no other users.

That is surprising indeed.
Memory measured that way (with nvidia-smi) can behave quite surprisingly because of the caching allocator, fragmentation, and other things the GPU does. But this looks like too big a difference to explain that way.

For the runtime, keep in mind that the CUDA API is asynchronous, so the tqdm timing might be surprising as well (for example, loading the next sample in the iterator may actually take more or less time because it has to wait on the GPU).
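A minimal sketch of timing a single training step so the asynchronous CUDA queue is accounted for: synchronize before starting and before stopping the clock, and read the peak tensor memory from PyTorch instead of nvidia-smi (which also counts memory the caching allocator has reserved but is not using). The helper name `time_training_step` is hypothetical; `torch.cuda.synchronize`, `reset_peak_memory_stats` and `max_memory_allocated` are real PyTorch calls. On CPU it just returns the wall time and `None` for memory.

```python
import time

import torch
from torch import nn


def time_training_step(model, batch, optimizer):
    """Time one zero_grad/forward/backward/step; report peak CUDA tensor memory on GPU."""
    device = batch.device
    use_cuda = device.type == "cuda"
    if use_cuda:
        # Drain queued kernels so they are not billed to this step, then reset the peak counter
        torch.cuda.synchronize(device)
        torch.cuda.reset_peak_memory_stats(device)
    start = time.perf_counter()

    optimizer.zero_grad()
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()

    if use_cuda:
        # Wait for the GPU to actually finish before stopping the clock
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    # Peak bytes occupied by tensors, not the allocator's reserved pool that nvidia-smi reports
    peak = torch.cuda.max_memory_allocated(device) if use_cuda else None
    return elapsed, peak


# Usage on a toy model; on a GPU, move the model and batch to "cuda" first
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
elapsed, peak = time_training_step(model, torch.randn(32, 64), opt)
print(f"{elapsed * 1000:.2f} ms, peak memory: {peak}")
```

Averaging this over many steps (after a few warm-up iterations) gives a steadier number than the per-iteration rate tqdm shows.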

Mhmm. Well, when I am training my actual model I know the training time decreases because the wall-clock time of each epoch is shorter; that certainly isn’t being affected by the asynchrony.

And as I said, the batch size I am able to train with (without a CUDA out-of-memory error) is much higher.

At first I thought the speedup was some nuance of data IO with the larger batch size (like, fewer I/O operations from main memory to GPU?) but with the above script and BATCH_SIZE = 32 I see the decrease in memory use (at least as reported by nvidia-smi) and decrease in training time even with a fixed batch size.

I was also able to reproduce the behavior with this script on another machine (with a Quadro RTX 6000). Note, too, that when I run it on my actual model I am loading data off the SSD, just in case you thought it might be something to do with the contrived nature of this example script…

Thanks for your thoughts, I am just looking for understanding. My first thought was “surely I am confused and somehow this is reducing my effective batch size?”, but I am pretty sure it is in fact running the whole model and using the larger batches. It is perhaps also notable that both with and without the checkpoint I am getting high-80s to mid-90s % GPU utilization (as reported by nvidia-smi). I am pretty sure it is reliably slightly higher with the checkpoint than without, which I guess is what you’d expect since presumably it is re-computing the forward pass. It’s just such a dramatic effect that I would like to understand how to leverage it every time I train a model, or understand what I am doing wrong…

checkpoint is for cases where at least one input argument has requires_grad=True. Used like this, with raw data as the only input, it skips training self.model (i.e. the resnet).
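A small CPU sketch of that pitfall, with two `nn.Linear` layers standing in for the resnet and the trainable `LR` layer (the layer sizes are arbitrary assumptions). With the reentrant variant and an input that does not require grad, the checkpointed output has no graph, so only the layer after it receives gradients. Newer PyTorch versions also offer `use_reentrant=False`, which the documentation says does not have this requirement.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
backbone = nn.Linear(8, 8)  # stands in for the checkpointed resnet
head = nn.Linear(8, 8)      # stands in for the trainable layer after it (self.LR)
x = torch.randn(4, 8)       # raw input data: requires_grad is False

# Reentrant variant (what checkpoint did when this thread was written).
# PyTorch warns here that none of the inputs require grad.
out_reentrant = checkpoint(backbone, x, use_reentrant=True)
print(out_reentrant.requires_grad)  # False: no graph was recorded through backbone
head(out_reentrant).pow(2).mean().backward()
grad_reentrant = backbone.weight.grad
print(grad_reentrant)               # None: the backbone is silently not trained
print(head.weight.grad is not None) # True: only the head gets gradients

# Non-reentrant variant: inputs do not need requires_grad=True
out_nonreentrant = checkpoint(backbone, x, use_reentrant=False)
head(out_nonreentrant).pow(2).mean().backward()
grad_nonreentrant = backbone.weight.grad
print(grad_nonreentrant is not None)  # True: backbone parameters now get gradients
```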


Ah, I didn’t think to check whether the parameters were changing. The loss still moves because of the tunable layer after the checkpoint, but the resnet itself is indeed not being trained: if I add a line to print the change in one of its parameters (like a conv weight) at each step, it is zero when the checkpoint is on. This explains everything I observed. Thank you.
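A sketch of that sanity check, generalized to report how far every parameter moves in one optimizer step. The helper `step_and_report` and the toy `backbone`/`head` model are hypothetical stand-ins for the resnet and the layer after it. With the reentrant checkpoint on raw data, the backbone's parameters do not move at all, because SGD skips parameters whose `.grad` is `None`.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def step_and_report(model, loss_fn, optimizer):
    """Run one optimizer step and return the L2 norm of each parameter's change."""
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    optimizer.zero_grad()
    loss_fn().backward()
    optimizer.step()
    return {name: (p.detach() - before[name]).norm().item()
            for name, p in model.named_parameters()}


torch.manual_seed(0)
model = nn.ModuleDict({"backbone": nn.Linear(8, 8), "head": nn.Linear(8, 8)})
x = torch.randn(4, 8)  # raw data, requires_grad=False
opt = torch.optim.SGD(model.parameters(), lr=0.1)


def loss_fn():
    # Same pattern as the script in the question: reentrant checkpoint on raw data
    feats = checkpoint(model["backbone"], x, use_reentrant=True)
    return model["head"](feats).pow(2).mean()


deltas = step_and_report(model, loss_fn, opt)
for name, d in deltas.items():
    print(f"{name}: {d:.6f}")  # backbone.* stay at 0.000000, head.* move
```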

That is also strange because after pretraining using this my supervised learning did slightly better… but I suppose that may have been random chance because it was only slight (surprisingly slight!).

I guess I don’t fully understand how to use a checkpoint, then. Any chance you could give a short example of it in action? When I looked for some, I found another person asking for examples who never got any: “Do you have examples of the usage of torch.utils.checkpoint”

x2 = checkpoint(f2, x1)  # computes f2(x1)

Without checkpoint, x2.grad_fn may keep intermediate tensors created in f2() alive. With checkpoint, x2.grad_fn instead schedules a re-run of f2 during the backward pass, this time with the “gradient tape” (autograd recording) on. The point of this is to have more free memory for f3().
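A runnable sketch of that f1 → f2 → f3 pattern, using small `nn.Sequential` blocks as assumed stand-ins for real layers. Because x1 comes out of f1 (which has parameters), it requires grad, so even the reentrant variant records the recomputation and f2's parameters get gradients. The second half checks that those gradients match a plain, un-checkpointed run.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
f1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())           # runs normally
f2 = nn.Sequential(nn.Linear(32, 32), nn.ReLU(),
                   nn.Linear(32, 32), nn.ReLU())           # checkpointed segment
f3 = nn.Linear(32, 1)                                      # runs normally

x0 = torch.randn(8, 16)
x1 = f1(x0)                                  # requires grad: produced by f1's parameters
x2 = checkpoint(f2, x1, use_reentrant=True)  # f2's intermediates are dropped here...
x3 = f3(x2)
x3.mean().backward()                         # ...and recomputed during this backward
f2_has_grads = all(p.grad is not None for p in f2.parameters())
print(f2_has_grads)  # True
ckpt_grads = [p.grad.clone() for p in f2.parameters()]

# Same computation without checkpointing, for comparison
for m in (f1, f2, f3):
    m.zero_grad()
f3(f2(f1(x0))).mean().backward()
plain_grads = [p.grad for p in f2.parameters()]
same = all(torch.allclose(a, b) for a, b in zip(ckpt_grads, plain_grads))
print(same)  # True: checkpointing changes memory/compute, not the gradients
```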


So does this only work for situations in which f2 is just a function or series of functions, and not itself dependent on parameters that need gradients? If f2 = lambda x: x * parameter, where parameter is some tunable parameter that needs a gradient, will the above code compute gradients for parameter and update it when I call optimizer.step()? If yes, how would I modify the initial code I wrote above so that it works and updates the model parameters (even though I understand that it will not in fact be helpful to use it)?

It updates other gradients, yes. As for the second question, try doing x.requires_grad_(), but I doubt that checkpoint will be useful with such a big function; selective checkpointing of some of the resnet’s layers may do better.
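A sketch using the exact `f2 = lambda x: x * parameter` from the question, showing two ways to get a gradient into `parameter`: the suggested `requires_grad_()` on the input with the reentrant variant, or (in newer PyTorch) `use_reentrant=False`, which does not need any input to require grad. The values of `parameter` and `x` are arbitrary; the expected gradient is sum(x) = 3.

```python
import torch
from torch.utils.checkpoint import checkpoint

parameter = torch.nn.Parameter(torch.tensor(3.0))
f2 = lambda x: x * parameter  # the function from the question

x = torch.tensor([1.0, 2.0])  # raw data, requires_grad=False

# Fix 1: make the input require grad so the reentrant variant records the recompute.
# Side effect: a (small) gradient is also computed for the input itself.
x_in = x.detach().clone().requires_grad_()
checkpoint(f2, x_in, use_reentrant=True).sum().backward()
grad_fix1 = parameter.grad.clone()

# Fix 2: the non-reentrant variant works with inputs that do not require grad
parameter.grad = None
checkpoint(f2, x, use_reentrant=False).sum().backward()
grad_fix2 = parameter.grad.clone()

print(grad_fix1, grad_fix2)  # tensor(3.) tensor(3.)
```

Applied to the MODEL class in the question, the second option would mean passing use_reentrant=False to both checkpoint calls, at which point the resnet trains but the memory and speed advantages seen earlier should largely disappear.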

Yes I understand that now, I am just trying to understand the details. So if I understand correctly checkpoint(f,x) will cause parameters within f not to get updated so long as x is the raw data, or just anything for which requires_grad is False? Is there some design reason for this? It seems like a non-obvious and inconvenient behavior.

Edited to say: Ahhh, is it that f(x) is the thing it throws out on the forward pass (and has to recompute), and instead of checking if anything in f needs gradients, it only checks whether x needs a gradient? That would make sense to me.

Kinda. A layer’s output itself goes to the next layer, and may or may not be released early depending on what is done there. But the output’s “history” references other tensors, and sometimes these are the only live references to them.

Re: checkpoint behaviour. I think it is mostly an implementation quirk, related to how autograd.Function works. But there also may be an assumption that you’d exclude at least the first gradient inducing operation from checkpointing, as you have no “history” (=potentially disposable tensors) at that point, so rerun is just harmful.
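A sketch of the selective approach suggested above: run a model as a list of stages and checkpoint only some of them, leaving the first stage (where there is no history yet) un-checkpointed. The `SelectivelyCheckpointed` wrapper and the toy stages are hypothetical; for torchvision's resnet one might use the stem (conv1, bn1, relu, maxpool) and layer1 to layer4 as stages, but that split is an assumption, not something PyTorch prescribes.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class SelectivelyCheckpointed(nn.Module):
    """Run `stages` in order, checkpointing only the stage indices listed in `ckpt`."""

    def __init__(self, stages, ckpt):
        super().__init__()
        self.stages = nn.ModuleList(stages)
        self.ckpt = set(ckpt)

    def forward(self, x):
        for i, stage in enumerate(self.stages):
            if self.training and i in self.ckpt:
                # Non-reentrant variant: parameters inside `stage` get gradients
                # regardless of whether `x` requires grad.
                x = checkpoint(stage, x, use_reentrant=False)
            else:
                x = stage(x)
        return x


torch.manual_seed(0)
stages = [nn.Sequential(nn.Linear(8, 16), nn.ReLU()),   # "stem": not checkpointed
          nn.Sequential(nn.Linear(16, 16), nn.ReLU()),  # checkpointed
          nn.Sequential(nn.Linear(16, 4))]              # checkpointed
net = SelectivelyCheckpointed(stages, ckpt=[1, 2])
net(torch.randn(5, 8)).sum().backward()
all_trained = all(p.grad is not None for p in net.parameters())
print(all_trained)  # True
```

Which stages to checkpoint is a trade-off: each checkpointed stage frees its intermediate activations after the forward pass but costs one extra forward of that stage during backward.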
