Systematic reasoning and debugging of GPU memory

I would like to be able to approach GPU out-of-memory issues more systematically:

  • Are there some resources that explain roughly when allocations happen and when memory is released?
  • Are there tools that can show which tensors are alive at a given time?

As a concrete example, I am building a network that operates on 3D medical images, and GPU memory is an issue. My model contains DenseNet-like building blocks that look like this:

tmp = torch.cat([lots, of, inputs], 1)
small_output = conv(tmp)

I suspect that tmp eats up a lot of memory while being computationally cheap to recompute.

  • How do I tell PyTorch to recompute tmp whenever it is needed instead of storing it for the backward pass? (A rough sketch of what I have in mind is below.)
  • Once that is solved, how do I profile the improved memory footprint?
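To make the first question concrete, here is an untested sketch of what I imagine, using torch.utils.checkpoint; conv and the input list are the placeholders from the snippet above:

import torch
from torch.utils.checkpoint import checkpoint

def cat_and_conv(conv, *inputs):
    # this function is recomputed during the backward pass,
    # so the concatenated tmp tensor is not kept alive
    tmp = torch.cat(inputs, 1)
    return conv(tmp)

# inside the block's forward, instead of the two lines above:
# small_output = checkpoint(cat_and_conv, conv, *[lots, of, inputs])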

May I suggest starting with this tutorial and digging deeper as needed? https://pytorch.org/tutorials/recipes/recipes/profiler.html
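Roughly following that recipe, something like this should already show per-operator CUDA memory (untested sketch; MyNet and the input shape are placeholders for your network and data):

import torch
import torch.autograd.profiler as profiler

model = MyNet().cuda()                             # placeholder for your network
x = torch.randn(1, 4, 64, 64, 64, device="cuda")   # placeholder 3D input

with profiler.profile(profile_memory=True, record_shapes=True, use_cuda=True) as prof:
    y = model(x)

# operators sorted by the memory they allocated themselves
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))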


I created two variants of a toy net, with and without checkpointing (see the end of this post for the code).
The profiles look essentially like this:


With checkpointing (save_memory=True):

-------- ------------  ------------  ------------  ------------  ------------
    Name      CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls
-------- ------------  ------------  ------------  ------------  ------------
 forward     51.61 Kb        -276 b     394.53 Mb      -3.91 Mb             1
backward         -4 b        -276 b         512 b        -512 b             1
-------- ------------  ------------  ------------  ------------  ------------

Without checkpointing (save_memory=False):

-------- ------------  ------------  ------------  ------------  ------------
    Name      CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls
-------- ------------  ------------  ------------  ------------  ------------
 forward         -4 b        -276 b       3.82 Gb     -39.45 Mb             1
backward    -51.61 Kb     -51.88 Kb     -77.07 Mb     -77.07 Mb             1
-------- ------------  ------------  ------------  ------------  ------------

Great, so it seems checkpointing could save a lot of memory. However, I don't understand exactly what is reported here.

  • What is the difference between CUDA Mem and Self CUDA Mem?
  • What exactly is reported? Is it peak allocated memory? Why are some numbers negative, like -39.45 Mb? Is it the sum of all allocations minus the sum of all deallocations? (For now I cross-check against the allocator statistics directly, see the sketch below.)
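This is how I currently cross-check the profiler output against the CUDA allocator statistics (sketch, reusing model, x and criterion from the toy code below):

torch.cuda.reset_peak_memory_stats()

y = model(x)
print("peak after forward :", torch.cuda.max_memory_allocated() / 2**20, "MiB")

loss = criterion(x, y)
loss.backward()
print("peak after backward:", torch.cuda.max_memory_allocated() / 2**20, "MiB")

# detailed allocator report (current/peak allocated and reserved memory)
print(torch.cuda.memory_summary())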

In case somebody wants to play with it, here is my code:

import torch
from torch import nn
import torch.utils.checkpoint
import torch.autograd.profiler as profiler

class Model(torch.nn.Module):
    
    def __init__(self, ninput, nhidden, ncat, nrepeat, save_memory):
        super().__init__()
        self.save_memory = save_memory
        self.nrepeat = nrepeat
        self.ncat = ncat
        self.layer1 = nn.Linear(ninput, nhidden)
        self.layer2 = nn.Linear(ncat*nhidden, ninput)
        
        
    def forward_loop_body(self, x):
        # blow x up by concatenating ncat copies, then project back down;
        # the concatenated tensor is the memory-hungry intermediate
        x = self.layer1(x)
        x = torch.cat([x for _ in range(self.ncat)], 1)
        x = self.layer2(x)
        return x
        
    def forward(self, x):
        for _ in range(self.nrepeat):
            if self.save_memory and x.requires_grad:
                # drop the intermediates and recompute them during backward
                x = torch.utils.checkpoint.checkpoint(self.forward_loop_body, x)
            else:
                x = self.forward_loop_body(x)
        return x


device = torch.device("cuda:0")
nb = 1024
x = torch.randn(nb,100).to(device)

for save_memory in [True, False]:
    model = Model(ninput=100,
                  nhidden=1000,
                  ncat=100,
                  nrepeat=10,
                  save_memory=save_memory,
                 ).to(device)
    
    criterion = torch.nn.MSELoss()
    
    with profiler.profile(record_shapes=True, profile_memory=True) as prof:
        with profiler.record_function("forward"):
            y = model(x)
        with profiler.record_function("backward"):
            model.zero_grad()
            loss = criterion(y, x)
            loss.backward()
            
    filename = f"profile_save_memory={save_memory}.txt"
    
    with open(filename, "w") as io:
        io.write(prof.key_averages().table())

I am not an expert in CUDA memory profiling, sorry for that.
As I understand from the tutorial:

Note the difference between self cpu time and cpu time: operators can call other operators; self cpu time excludes time spent in children operator calls, while total cpu time includes it.

It should be the same for self and total memory. The negative numbers are most likely releases of memory. For more information, we would probably have to dig into the documentation or maybe the source.
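One thing that might help is sorting the profiler table by the memory columns, so you can compare the self and inclusive numbers per operator (sketch, reusing prof from your script):

# memory an operator allocated itself, excluding child operator calls
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))

# inclusive memory, i.e. including allocations made by child operators
print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))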