How to free GPU memory? (and delete memory allocated variables)

I am using a pretrained VGG16 network, and the GPU memory usage (as reported by nvidia-smi) increases with every mini-batch, even when I delete all variables or call torch.cuda.empty_cache() at the end of every iteration. It seems that some variables remain in GPU memory and cause the “out of memory” error. I couldn’t solve the problem using any of the other related posts in this forum.

Will you please help me understand how I can free all possible GPU memory after each mini-batch? If possible, will you please explain why some variables stay in GPU memory and are not released even when I use the “del” command?

Attached below is a minimal example that reproduces the “out of memory” error I get.

Thanks a lot

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import models, transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = dset.ImageFolder(root="......", transform=transform)

model = models.vgg16(pretrained=True)
num_features = model.classifier[6].in_features
features = list(model.classifier.children())[:-1]  # Remove last layer
features.extend([nn.Linear(num_features, 2)])  # Add our layer with 2 outputs
model.classifier = nn.Sequential(*features)  # Replace the model classifier

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
criterion = nn.CrossEntropyLoss()

model = model.cuda()
criterion = criterion.cuda()

train_loader = DataLoader(trainset, batch_size=4, shuffle=True, drop_last=True)
train_iterator = iter(train_loader)

# num_of_mini_Batches is not defined in the posted snippet
for i in range(num_of_mini_Batches):
    img, label = next(train_iterator)

    img = Variable(img.cuda(), requires_grad=True)
    label = Variable(label.cuda())
    optimizer.zero_grad()

    outputs = model(img)
    loss = criterion(outputs, label)
    loss.backward()
    # del loss, model, outputs, optimizer, img, label, train_loader, train_iterator
    # torch.cuda.empty_cache()
    optimizer.step()

Is that the complete code or do you create any logs etc. after the optimization?
You will usually run out of memory if you store a tensor together with its computation graph, e.g. if you use total_loss += loss instead of total_loss += loss.item().
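To make the difference concrete, here is a minimal sketch on the CPU. The tiny nn.Linear model, the random data and the names total_with_graph / total_as_float are stand-ins invented for illustration; they are not part of the code posted above.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)          # stand-in model, not the VGG16 from the post
criterion = nn.MSELoss()

total_with_graph = 0.0   # becomes a tensor that references every iteration's graph
total_as_float = 0.0     # stays a plain Python float

for _ in range(3):
    x = torch.randn(8, 4)
    y = torch.randn(8, 1)
    loss = criterion(model(x), y)

    # Accumulating the tensor keeps each iteration's graph reachable.
    total_with_graph = total_with_graph + loss
    # .item() copies the scalar out, so the graph can be freed.
    total_as_float += loss.item()

print(type(total_with_graph).__name__, total_with_graph.grad_fn is not None)
print(type(total_as_float).__name__)
```

The accumulated tensor still carries a grad_fn, which is why the graphs of all earlier iterations cannot be released, while the float holds no reference to any tensor.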

This is the complete code… In each successive iteration, the memory usage increases, until it runs out of memory (after two iterations).

Will you please help me understand how I can fix it?

Thanks!

That’s strange, as I cannot see any obvious reason for the memory growth.
Also, I used your code with a fake dataset just to make sure I’m not overlooking something, and your code runs just fine. The memory stays at the same level through 1000 iterations.

Maybe it’s a dataset issue?
Do you use a custom dataset?

Before starting the “for” loop, the memory usage is as follows:

Stopping the first iteration before entering the line “loss.backward()” results in the following memory usage:

After the line “loss.backward()”, the known “out of memory” error is shown.

Decreasing the batch size to “2” allows me to run through the first iteration, but the “out of memory” error is then shown during the second iteration (the same holds for a batch size of 1…).

Is it possible that the 4 GB of GPU memory available cannot handle even such a small batch size (with RGB pictures of 224×224)?

Thanks!

Hi iariav - I am using the ants and bees dataset from the transfer learning tutorial.

Got any idea? Thanks a lot!

If the first iteration was successful, the second should also work. You can remove the requires_grad=True argument from your input and try it again. This would give you some more memory.

ptrblck - isn’t the default for requires_grad “True”? (so that even after I delete this part, the variable would still have requires_grad=True?)

After removing the “requires_grad = True”, and with a batch size of “1”, the code runs with a GPU memory usage of 3.4 GB. Increasing the batch size to “2” results in an “out of memory” error in the second iteration… Is it possible that 4 GB of GPU RAM is not enough for a batch size of “2” in this case? (As mentioned in a previous comment, it is the ants and bees dataset from the transfer learning tutorial.)

Looking forward to your answer. Thanks a lot!

It might be, even though I’m wondering why it’s running out of memory in the second iteration.
I’ll have a look at the memory usage a bit later.

No, the default for Variables was requires_grad=False.
You could also update to PyTorch 0.4.0, where Variables and tensors were merged, alongside other bug fixes and new features.

Great, thank you. I will be glad to hear your insights about the memory usage on your computer later (BTW, could you please tell me how much memory your computer uses when the batch size is 4?).

I am using PyTorch 0.4, but still use “Variable” out of habit. I guess the only thing that changes with the PyTorch 0.4 syntax is that I drop the “Variable” wrapper and just write “img = img.cuda()”.
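For reference, a training step without the Variable wrapper might look like the sketch below. The small nn.Linear model, the random batch and the helper name train_step are hypothetical stand-ins so the snippet runs without a GPU or a dataset; the input is moved to the device without setting requires_grad.

```python
import torch
from torch import nn, optim

# Fall back to the CPU so the sketch also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
model = nn.Linear(12, 2).to(device)        # stand-in for the VGG16 model
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001)


def train_step(img, label):
    """Run one optimization step and return the loss as a Python float."""
    img = img.to(device)        # plain tensors, no Variable wrapper
    label = label.to(device)
    optimizer.zero_grad()
    outputs = model(img)
    loss = criterion(outputs, label)
    loss.backward()
    optimizer.step()
    return loss.item()


loss_value = train_step(torch.randn(4, 12), torch.tensor([0, 1, 1, 0]))
print(loss_value)
```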

Thanks a lot again

The training takes ~3777MB on my system for a batch size of 4 (GTX 1070, CUDA9, cuDNN7, compiled from master).

Well, you obviously have a much better GPU than mine :joy:

In any case, even after updating the GPU driver and cuDNN version, the program still gets stuck in the second iteration. Is there any kind of debugging I can perform to better understand the problem?
Will be glad to receive your guidance.

EDIT: it seems like the program now gets stuck before the “optimizer.step()” in the second iteration. The reason for the error is: “denom = exp_avg_sq.sqrt().add_(group[‘eps’])” in the Adam optimizer routine.

Thanks a lot again.

You could try to see the memory usage with the script posted in this thread.
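Independently of that script, PyTorch can also report how much memory its own tensors occupy. The helper below is a hypothetical sketch (the name gpu_memory_mib is made up here) built on torch.cuda.memory_allocated() and torch.cuda.max_memory_allocated(); it returns None on a machine without CUDA.

```python
import torch


def gpu_memory_mib():
    """Return current and peak tensor memory in MiB, or None without CUDA."""
    if not torch.cuda.is_available():
        return None
    mib = 2 ** 20
    return {
        "allocated": torch.cuda.memory_allocated() / mib,
        "max_allocated": torch.cuda.max_memory_allocated() / mib,
    }


stats = gpu_memory_mib()
print(stats)
```

Calling it before and after loss.backward() and optimizer.step() shows which step raises the peak. Note that nvidia-smi typically reports a larger number, since it also counts memory cached by PyTorch's allocator and the CUDA context itself.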

Do you still run out of memory for batch_size=1 or are you currently testing batch_size=4?
Could you temporarily switch to an optimizer without tracking stats, e.g. optim.SGD?
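To illustrate what "tracking stats" means, the sketch below counts the per-parameter buffers each optimizer holds after one step. It uses a tiny nn.Linear stand-in on the CPU, and the helper names buffer_elements and elements_after_one_step are invented for this example. Only buffers with the same shape as their parameter are counted, so scalar step counters are ignored.

```python
import torch
from torch import nn, optim


def buffer_elements(optimizer):
    """Count elements in state tensors shaped like their parameter."""
    total = 0
    for param, state in optimizer.state.items():
        for value in state.values():
            if torch.is_tensor(value) and value.shape == param.shape:
                total += value.numel()
    return total


def elements_after_one_step(optimizer_cls, **kwargs):
    torch.manual_seed(0)
    model = nn.Linear(8, 2)                      # stand-in model
    optimizer = optimizer_cls(model.parameters(), **kwargs)
    model(torch.randn(4, 8)).sum().backward()
    optimizer.step()                             # buffers are created here
    n_params = sum(p.numel() for p in model.parameters())
    return n_params, buffer_elements(optimizer)


n_params, adam_buffers = elements_after_one_step(optim.Adam, lr=0.001)
_, sgd_buffers = elements_after_one_step(optim.SGD, lr=0.001)
print(n_params, adam_buffers, sgd_buffers)
```

Adam keeps two buffers per parameter (exp_avg and exp_avg_sq), while SGD without momentum keeps none, so on top of the weights and gradients Adam needs roughly twice the parameter memory again.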

Thanks a lot for the reference to the memory and cpu usage methods.

After deleting the “requires_grad = True”, the program now runs for batch_size=1 (with 2 GB of the 4 GB of GPU RAM used) but gets stuck in the second iteration when I use batch_size = 2. Calling cpuStats() before and after the line optimizer.step() shows that it still uses 2 GB of GPU RAM, but I get “out of memory” during the optimizer.step() call in the second iteration, with the error reported as:

    denom = exp_avg_sq.sqrt().add_(group['eps'])
RuntimeError: cuda runtime error (2) : out of memory at c:\programdata\miniconda3\conda-bld\pytorch_1524546371102\work\aten\src\thc\generic/THCStorage.cu:58

Changing the optimizer to optim.SGD (i.e. defining optimizer = optim.SGD(model.parameters() , lr=0.001)) indeed allows me to run through the program with a batch size of 10! It also seems like the memory usage stays around 2 GB, even when I increase the batch size even more…
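A rough back-of-the-envelope estimate helps explain the gap between the two optimizers. The sketch below counts the parameters of VGG16 with a 2-output final layer in plain Python, assuming the standard VGG16 layout (3x3 convolutions with bias, a 512*7*7 input to the classifier) and float32 storage. It only covers weights, gradients and Adam's two buffers; activations, cuDNN workspaces, temporaries and the CUDA context are deliberately left out.

```python
# Output channels of the 3x3 conv layers; "M" marks max pooling (no parameters).
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]
# (inputs, outputs) of the fully connected layers, last one replaced by 2 outputs.
CLASSIFIER = [(512 * 7 * 7, 4096), (4096, 4096), (4096, 2)]
BYTES_PER_FLOAT32 = 4


def count_parameters():
    total = 0
    in_channels = 3
    for item in VGG16_CFG:
        if item == "M":
            continue
        total += in_channels * item * 3 * 3 + item   # kernel weights + bias
        in_channels = item
    for fan_in, fan_out in CLASSIFIER:
        total += fan_in * fan_out + fan_out          # weights + bias
    return total


n_params = count_parameters()
weights_mib = n_params * BYTES_PER_FLOAT32 / 2 ** 20
sgd_mib = 2 * weights_mib    # weights + gradients
adam_mib = 4 * weights_mib   # weights + gradients + exp_avg + exp_avg_sq

print(n_params)
print(round(weights_mib), round(sgd_mib), round(adam_mib))
```

Under these assumptions the weights alone take about 0.5 GiB, plain SGD needs about 1 GiB for weights and gradients, and Adam about 2 GiB before a single activation is stored, which leaves little headroom on a 4 GB card.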

Is there any way to use the Adam algorithm without getting “out of memory” in my case?

One more question: when I do not specify “requires_grad = True”, how does the line loss.backward() work (where loss = criterion(outputs, label))? Shouldn’t I set requires_grad=True for at least one of the variables outputs or label (or use “with torch.set_grad_enabled”)?

Thanks a lot!

You could have a look at checkpointing your model, which reduces memory usage at the cost of extra compute.

I don’t think you need the gradients for the input, unless you are trying to manipulate the input itself.
The model parameters automatically require gradients. In a standard classification setup, you don’t need to specify requires_grad=True for either the input or the label.
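A quick way to check this yourself is the sketch below, which uses a small nn.Linear stand-in and random data rather than the VGG16 model from the thread.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                 # stand-in model
criterion = nn.CrossEntropyLoss()

img = torch.randn(3, 4)                 # requires_grad defaults to False
label = torch.tensor([0, 1, 0])

outputs = model(img)
loss = criterion(outputs, label)
loss.backward()

print(img.requires_grad)                # input: no gradient tracked
print(model.weight.requires_grad)       # parameters: tracked by default
print(outputs.requires_grad)            # output depends on the parameters
print(img.grad is None, model.weight.grad is not None)
```

The output of the model requires gradients because it was computed from the parameters, so loss.backward() works and fills in the parameter gradients even though neither the input nor the label asks for one.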

Thanks a lot ptrblck! If you have any reference for an example that uses checkpointing - it will be great. Thanks a lot for all of your help !

@Priya_Goyal created a good example here.
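As a separate, much smaller illustration (not the linked example), the sketch below wraps a block in torch.utils.checkpoint.checkpoint and checks that the gradients match an ordinary forward pass. The model, data and helper names are invented for this sketch, and the use_reentrant=False argument assumes a considerably newer PyTorch release than the 0.4 discussed in this thread.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU())
head = nn.Linear(16, 2)
criterion = nn.CrossEntropyLoss()
x = torch.randn(4, 16)
label = torch.tensor([0, 1, 1, 0])


def loss_plain():
    # Ordinary forward pass: activations inside `block` are kept for backward.
    return criterion(head(block(x)), label)


def loss_checkpointed():
    # Activations inside `block` are dropped and recomputed during backward.
    hidden = checkpoint(block, x, use_reentrant=False)
    return criterion(head(hidden), label)


def block_grads(loss_fn):
    for p in list(block.parameters()) + list(head.parameters()):
        p.grad = None
    loss_fn().backward()
    return [p.grad.clone() for p in block.parameters()]


plain_grads = block_grads(loss_plain)
ckpt_grads = block_grads(loss_checkpointed)
same = all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads))
print(same)
```

The saving grows with the size of the checkpointed block, at the price of running its forward pass a second time during backward. Checkpointing reduces activation memory only; it does not shrink the optimizer's own buffers.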
