Accumulating model outputs blows up CUDA memory?

As I accumulate the model outputs by concatenating them, CUDA memory usage grows far beyond the size of the tensors I'm building, “out” and “lab”.
I am accumulating the outputs because I need gradient accumulation downstream.

I’ve simplified my code down to the PyTorch classifier tutorial example:

import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 16

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


device = "cuda" if torch.cuda.is_available() else "cpu"
out = torch.empty(0, 10).to(device)
lab = torch.empty(0, dtype=torch.int64).to(device)
net = Net().to(device)
for epoch in range(2):  # loop over the dataset multiple times
    for i, (inputs, labels) in enumerate(trainloader, 0):
        # move the batch to the same device as the model
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = net(inputs)
        out = torch.cat((out, outputs), dim=0)
        lab = torch.cat((lab, labels), dim=0)

These tensors should easily fit in memory, yet in my full code CUDA quickly runs out of memory.

In the sample code above everything still fits because it’s a toy example, but the CUDA memory usage grows significantly all the same: after the code finishes, nvidia-smi shows my process using 5868 MiB. Meanwhile “out” is only a 100000 x 10 float32 tensor and “lab” a 100000-element int64 tensor, which together should only be a few MiB.

I am guessing this has something to do with the computation graph, but if I use the standard approach to gradient accumulation, where the gradients themselves are accumulated, this memory growth doesn’t occur. Why is that?

This is expected, since by keeping these output tensors alive you are also storing their entire computation graphs, including the intermediate activations needed to compute the gradients during the backward pass.

You can instead call .backward() on each mini-batch loss and let the gradients accumulate directly in the parameters’ .grad attributes (see the sketch below). If you don’t want to do that for some reason and want to call backward only once on the accumulated outputs, the memory increase is expected, as described above.
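
A minimal sketch of that pattern, built on the training loop above; the loss function, optimizer, and accumulation interval are assumptions for illustration, not part of the original code:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()                       # assumed loss
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)  # assumed optimizer
accum_steps = 4                                         # assumed accumulation interval

for i, (inputs, labels) in enumerate(trainloader):
    inputs, labels = inputs.to(device), labels.to(device)
    outputs = net(inputs)
    # scale the loss so the accumulated gradient matches one large batch
    loss = criterion(outputs, labels) / accum_steps
    loss.backward()            # releases this batch's graph and adds to .grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()       # apply the accumulated gradients
        optimizer.zero_grad()  # reset .grad for the next accumulation window

Because each call to backward() releases that batch’s graph, only the current batch’s activations stay resident on the GPU.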

If I instead scaled up the batch size in the DataLoader, would the same problem occur? I thought the memory increase with a large batch size was due solely to the batch itself, not to a larger computation graph.

Yes, you should see a similar memory increase.

The computation graph stores forward activations, which are needed for the gradient computation, and thus increases the memory. The actual nodes and metadata should be stored on the host and should not increase device memory.
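
A quick way to observe this, reusing net and device from the code above (just a sketch; the batch size is arbitrary and the exact byte counts depend on the model and allocator):

import torch

x = torch.randn(512, 3, 32, 32, device=device)  # illustrative batch

base = torch.cuda.memory_allocated(device)
with torch.no_grad():
    y = net(x)                  # no graph is recorded, intermediates are freed
print("without autograd:", torch.cuda.memory_allocated(device) - base)
del y

base = torch.cuda.memory_allocated(device)
y = net(x)                      # graph recorded; saved activations stay alive with y
print("with autograd:", torch.cuda.memory_allocated(device) - base)

The second delta is noticeably larger because the saved forward activations remain allocated for the eventual backward pass.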


Thanks a lot! So does the computation graph double in size if I double the number of items in the batch?

Yes, since all intermediate activations use the same batch_size in their shape, they will all grow by the same factor.
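
A rough way to check this, again reusing Net and device from above (a sketch only; absolute numbers will vary, but the ratio should be close to the batch-size ratio):

import torch

def graph_memory(batch_size):
    # extra CUDA memory kept alive by the autograd graph for one forward pass
    x = torch.randn(batch_size, 3, 32, 32, device=device)
    base = torch.cuda.memory_allocated(device)
    y = net(x)  # keeps the saved activations alive via the graph attached to y
    return torch.cuda.memory_allocated(device) - base

print(graph_memory(64))
print(graph_memory(128))  # should be roughly twice the previous value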
