ResNet-50 takes 10.13GB to run with a batch size of 96

I have been working with a ResNet-50 on images of shape (3, 256, 256), and I'm trying to run it with a batch size of 96, but I keep getting an error saying that CUDA ran out of memory. I have an RTX 3090 with 24GB, so I find that hard to believe. I then tested a very simple setup with the code below, which showed that a forward pass with an input of shape (96, 3, 256, 256) allocates 10.13GB. Is this correct? I thought such a network would need much less memory. I checked, and CUDA is available and being used on my device.

import torch
import torchvision.models as models

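device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def print_allocated_memory():
    # memory currently occupied by tensors on the current CUDA device, in GB
    print("{:.2f} GB".format(torch.cuda.memory_allocated() / 1024 ** 3))

resnet50_input = torch.ones(96, 3, 256, 256).float().to(device)

# newer torchvision versions use weights=None instead of pretrained=False
resnet50 = models.resnet50(pretrained=False).to(device)

encoded = resnet50(resnet50_input)
print_allocated_memory()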


10.13 GB

A memory usage of ~10GB is expected for a ResNet-50 with this input shape.
Note that the input itself, all parameters, and especially the intermediate forward activations use device memory.
While the first two are small, the intermediates (which are needed for the gradient computation) most likely use the majority of the memory.
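To get a feeling for why the intermediates dominate, here is a rough back-of-the-envelope calculation for the first few layers. It is a sketch that assumes float32 tensors (4 bytes per element) and the standard torchvision ResNet-50 stem (a 7x7 stride-2 convolution with 64 output channels, BatchNorm, then a 3x3 stride-2 max pool); the parameter count is the one reported later in this thread.

```python
# Back-of-the-envelope tensor sizes for the start of ResNet-50,
# assuming float32 (4 bytes per element) and the input shape from the question.
BYTES_PER_ELEM = 4  # float32
GIB = 1024 ** 3


def nelement(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


batch = 96
shapes = {
    "input": (batch, 3, 256, 256),
    "conv1 output": (batch, 64, 128, 128),   # 7x7 conv, stride 2, padding 3
    "bn1 output": (batch, 64, 128, 128),     # BatchNorm keeps the shape
    "maxpool output": (batch, 64, 64, 64),   # 3x3 pool, stride 2, padding 1
}

sizes_gib = {name: nelement(s) * BYTES_PER_ELEM / GIB for name, s in shapes.items()}
# total ResNet-50 parameter count as reported in this thread
params_gib = 25_557_032 * BYTES_PER_ELEM / GIB

for name, gib in sizes_gib.items():
    print(f"{name:15s} {gib:.3f} GiB")
print(f"{'all parameters':15s} {params_gib:.3f} GiB")
```

A single early activation (0.375 GiB) is already about four times the size of all parameters combined, and ResNet-50 has dozens more conv and BatchNorm layers after the stem. Since the ReLUs run in place, not every counted tensor is a separate allocation, which is the same approximation the hook-based count makes.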
A quick way of checking this for the whole network is to register forward hooks and count the elements of all output activations.
Note that this approach is not exact, since you would need to check which operations are performed in place etc., but it at least gives you an approximate estimate:

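import torch.nn as nn

nb_acts = 0
def count_output_act(m, input, output):
    # accumulate the number of elements of every hooked layer's output
    global nb_acts
    nb_acts += output.nelement()

for module in resnet50.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
        module.register_forward_hook(count_output_act)

# run a forward pass so that the hooks fire
out = resnet50(resnet50_input)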

Using this approach, I get an estimate of:

nb_params = sum([p.nelement() for p in resnet50.parameters()])
nb_inputs = resnet50_input.nelement()
print('input elem: {}, param elem: {}, forward act: {}, mem usage: {}GB'.format(
    nb_inputs, nb_params, nb_acts, (nb_inputs+nb_params+nb_acts)*4/1024**3))

> 10.127978801727295
input elem: 18874368, param elem: 25557032, forward act: 2787211008, mem usage: 10.548689991235733GB

which comes close to the reported value.
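The estimate can be reproduced by hand from the printed element counts. This small sketch uses a hypothetical helper, estimate_gib, and assumes 4 bytes per float32 element; note that dividing by 1024**3 means the "GB" figures in this thread are strictly GiB. It also shows that the forward activations make up about 98% of the total:

```python
def estimate_gib(*element_counts, bytes_per_elem=4):
    """Convert a sum of tensor element counts into GiB (float32 by default)."""
    return sum(element_counts) * bytes_per_elem / 1024 ** 3


nb_inputs = 18_874_368      # 96 * 3 * 256 * 256
nb_params = 25_557_032      # ResNet-50 parameters
nb_acts = 2_787_211_008     # forward activations counted by the hooks

total = estimate_gib(nb_inputs, nb_params, nb_acts)
acts_only = estimate_gib(nb_acts)
share = 100 * nb_acts / (nb_inputs + nb_params + nb_acts)

print(f"total: {total:.2f} GiB, activations alone: {acts_only:.2f} GiB ({share:.1f}%)")
```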
If you don't want to train the model and would like to save memory by not storing the intermediates, wrap the forward pass in a with torch.no_grad(): block, which brings the allocated memory down to ~0.167GB.
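As a rough sanity check of the ~0.167GB figure: under torch.no_grad() each intermediate can be freed once the next layer has consumed it, so after the forward pass mainly the input, the parameters and the (96, 1000) output logits stay allocated. This sketch assumes float32 and the default 1000-class classifier; it ignores buffers such as the BatchNorm running statistics, and the peak usage during the pass is higher, since at least one layer's input and output must coexist.

```python
# Memory expected to remain allocated after a forward pass under torch.no_grad(),
# assuming float32 and only input, parameters and logits are still alive.
GIB = 1024 ** 3
BYTES_PER_ELEM = 4  # float32

nb_inputs = 96 * 3 * 256 * 256
nb_params = 25_557_032   # ResNet-50 parameters (from the thread)
nb_logits = 96 * 1000    # classifier output for 1000 classes (assumed default)

retained_gib = (nb_inputs + nb_params + nb_logits) * BYTES_PER_ELEM / GIB
print(f"retained after a no_grad forward pass: ~{retained_gib:.3f} GiB")
```

The result lands close to the ~0.167GB quoted above; the small remaining gap is not modeled by this sketch.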

Great! Thank you so much. Glad to know I set everything up correctly.