Optimizer step requires GPU memory

When training a model, the optimizer seems to occupy some GPU memory which it never releases again. Let me explain this with an example:

import torchvision.models as models
import torch
from torch import optim, nn

model = models.resnet18(pretrained=True).cuda()

optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

img = torch.rand((1, 3, 224, 224)).cuda()
label = torch.randint(0, 1000, (1,)).cuda()

# GPU Memory usage up to here: 849 MB

optimizer.zero_grad()

res = model(img)
loss = criterion(res, label)

# GPU Memory usage up to here: 855 MB

loss.backward()
# GPU Memory usage up to here: 917 MB

optimizer.step()
# GPU Memory usage up to here: 1021 MB

del loss, res, img, label
torch.cuda.empty_cache()

If I just initialize the model, I get 849 MB of GPU memory usage. Running a forward pass with a single image and then torch.cuda.empty_cache() increases the usage to 855 MB, fair enough.
Running the backward pass and then torch.cuda.empty_cache() increases the memory usage to 917 MB, which makes sense since the gradients are now filled.

Now, running optimizer.step() and then torch.cuda.empty_cache() further increases the memory usage significantly. Could anyone tell me why this is the case?

I sometimes run into problems when alternating between training and validation, because of CUDA OOM errors.

Cheers


optimizer.step() clears the intermediate activations (if not kept by retain_graph=True), not the gradients.
You can still access the gradients using model.layer.weight.grad.
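
For example, right after loss.backward() in the snippet above (model.fc is the final linear layer of the resnet18 used there):

# the gradients survive optimizer.step(); only zero_grad() resets them
print(model.fc.weight.grad.shape)         # torch.Size([1000, 512]) for resnet18
print(model.fc.weight.grad.abs().mean())  # quick sanity check of the gradient magnitude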

Since Python has function scoping (not block scoping), you could probably save some memory by creating separate functions for your training and validation as explained in this post (in case you haven’t done it already).
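
A rough sketch of that pattern, reusing the objects from the snippet above (the function names are just illustrative):

def train_step(model, optimizer, criterion, img, label):
    optimizer.zero_grad()
    loss = criterion(model(img), label)
    loss.backward()
    optimizer.step()
    return loss.item()  # return a Python number, not the tensor holding the graph

def validate(model, criterion, img, label):
    with torch.no_grad():  # no graph is built, so no activations are stored
        loss = criterion(model(img), label)
    return loss.item()

# all intermediate tensors go out of scope when the functions return
train_loss = train_step(model, optimizer, criterion, img, label)
val_loss = validate(model, criterion, img, label)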

Thanks for the reply. So, if step() clears the intermediate activations, why does the memory usage increase? That memory is needed during the weight update is clear, but afterwards empty_cache() should release it again, shouldn't it?

Yes, I am using functions for training and evaluation. I just wanted to give a small example that’s easy to reproduce. But in this small example, I am explicitly deleting the variables, so nothing should be retained.

I think you are right, and you should see the expected behavior if you use an optimizer without internal state.
Currently you are using Adam, which stores running estimates (the first and second moments of the gradients) for each parameter after the first step() call, and these buffers take additional memory.
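
You can inspect that state directly after the first step() (a minimal sketch; exp_avg and exp_avg_sq are the buffer names Adam uses):

# Adam keeps two extra buffers (running mean and variance of the gradients) per parameter
state = optimizer.state[model.fc.weight]
print(state.keys())            # dict_keys(['step', 'exp_avg', 'exp_avg_sq'])
print(state['exp_avg'].shape)  # same shape as the parameter itself
# plain SGD without momentum keeps no such state and would not show this jump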
I would also recommend using the PyTorch methods to check the allocated and cached memory:

torch.cuda.memory_allocated()
torch.cuda.memory_cached()
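
For example, wrapping the step from the snippet above (note that memory_cached() was later renamed to memory_reserved(), so use whichever your PyTorch version provides):

def print_mem(tag):
    print(f'{tag}: allocated {torch.cuda.memory_allocated() / 1024**2:.1f} MB, '
          f'cached {torch.cuda.memory_cached() / 1024**2:.1f} MB')

print_mem('before step')
optimizer.step()
torch.cuda.empty_cache()
print_mem('after step')  # the allocated value now includes Adam's state buffers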