Low GPU memory usage

Hello, I am trying to train an object detection model using the ssdlite320_mobilenet_v3_large model available in the torchvision library. I am using a dataset of 30,000 images with a batch size of 16.
The problem is that when I look at the GPU performance in Task Manager on Windows, it shows very low GPU memory usage (3 GB). I have an NVIDIA RTX A6000 with 48 GB.
I also tried increasing the batch size to 256 images; the memory usage is higher (14 GB), but in my opinion that is still too low for this case.
I tried training another model (a UNet) that I used for another project and, with the same batch size (16), the memory used is 20 GB. So I concluded there is an error in my code, which I was not able to find by myself.
Can you help me with this issue?
Below is my code:

import torch
from torch.utils.data import DataLoader
from torchvision.models import MobileNet_V3_Large_Weights
import time, datetime
from tqdm import tqdm
from torchvision.utils import draw_bounding_boxes
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models.detection import ssdlite320_mobilenet_v3_large

import training.transforms as T
from training.utils import SSD_Dataset

def collate_fn(batch):
    return tuple(zip(*batch))

def get_transform(train, img_size, num_channels=1, hflip_prob=0.5):

    if train:
        #return DetectionPresetTrain(img_size, num_channels, hflip_prob=hflip_prob)
        return T.Compose(
            [
                T.ClipBoxesToImage(img_size),
                T.RandomHorizontalFlip(hflip_prob),
                T.PILToTensor(num_channels),
                T.ConvertImageDtype(torch.float),
                T.Normalize(),
            ]
        )
    else:
        #return DetectionPresetEval(img_size, num_channels)
        return T.Compose(
            [
                T.ClipBoxesToImage(img_size),
                T.PILToTensor(num_channels),
                T.ConvertImageDtype(torch.float),
                T.Normalize(),
            ]
        )
        
def train_one_epoch(model, opt, dataset, dataloader, device, epoch, epochs):

    model.train()

    train_losses = 0
    with tqdm(total=len(dataset), desc=f'Epoch {epoch+1}/{epochs}', unit='img') as pbar:

        for images, targets in dataloader:

            # move images and targets to the GPU
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # in train mode the torchvision detection models return a dict of losses
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            loss_value = losses.item()

            train_losses += loss_value

            opt.zero_grad()

            losses.backward()
            opt.step()

            pbar.update(len(images))
            pbar.set_postfix(**{'loss (batch)': loss_value})

    return train_losses

def evaluate(model, dataloader, dataset, device, epoch, epochs):

    model.eval()

    accuracy = []

    with tqdm(total= len(dataset), desc= f'Validation Epoch {epoch+1}/{epochs}', unit='img', leave=False) as pbar:

        for images, _ in dataloader:

            images = [image.to(device) for image in images]

            with torch.no_grad():

                # in eval mode the model returns one prediction dict per image
                predictions = model(images)

                for i in range(len(images)):
                    scores = predictions[i]['scores']

                    # NOTE: this is the mean detection confidence, not a real accuracy
                    accuracy.append(torch.mean(scores))

            pbar.update(len(images))
            pbar.set_postfix(**{'avg accuracy': torch.mean(torch.stack(accuracy)).item()})

        if (epoch+1) % 50 == 0:

            # draw the predictions for the first image of the last batch and save it
            boxes = predictions[0]['boxes']
            scores = predictions[0]['scores']
            inf_image = draw_bounding_boxes((images[0] * 255).to(torch.uint8).cpu(),
                                            boxes[scores > 0.5].cpu(), width=1)

            inf_img_pil = Image.fromarray(inf_image.permute(1, 2, 0).numpy())

            inf_img_pil.save('inf_img_%d.png' % (epoch+1))

    return accuracy

def train_and_test_model(MP, i):

    """  
    Main function for training and testing models

    Arguments:

    MP: dict
        parameters dictionary
    i: int
        iteration number (when iterating over hyper params)
    
    """

    print('Loading Data')

    img_train_path = MP['img_train_dir']
    img_val_path = MP['img_val_dir']
    train_anno_path = MP['train_anno_dir']
    val_anno_path = MP['val_anno_dir']
    bs = MP['bs']
    PIN_MEMORY = MP['pin_memory']
    DEVICE = MP['device']
    num_channels = MP['num_channels']
    img_size = MP['dim'] 
    n_train = MP['n_train']
    n_val = MP['n_dev']
    epoch_save = MP['epoch_save']

    train_dataset = SSD_Dataset(img_train_path, train_anno_path, n_train, transform=get_transform(True, img_size, num_channels))
    val_dataset = SSD_Dataset(img_val_path, val_anno_path, n_val, transform=get_transform(False, img_size, num_channels))

    print('Creating Dataloaders')

    train_collate_fn = collate_fn

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=bs, collate_fn=train_collate_fn, pin_memory=PIN_MEMORY)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=bs, collate_fn=train_collate_fn, pin_memory=PIN_MEMORY)

    print('Loading Model')

    lr = MP['lr']
    wd = MP['wd']
    output_dir = MP['output_dir']
    #model = ssdlite_v3.ssdlite320_mobilenet_v3_large(weights=None, num_classes=2, weights_backbone=MobileNet_V3_Large_Weights.DEFAULT, trainable_backbone_layers=3).to(DEVICE)
    model = ssdlite320_mobilenet_v3_large(
        weights=None,
        num_classes=2,
        weights_backbone=MobileNet_V3_Large_Weights.DEFAULT,
        trainable_backbone_layers=3,
    ).to(DEVICE)
    optimizer = MP['optimizer']
    parameters = [p for p in model.parameters() if p.requires_grad]

    print(parameters[0].device)

    if optimizer == 'sgd':
        opt = torch.optim.SGD(
            parameters,
            lr=lr,
            momentum=0.9,
            weight_decay=wd,
        )
    elif optimizer == 'adamw':
        opt = torch.optim.AdamW(parameters, lr=lr, weight_decay=wd)

    # callbacks
    # NOTE: EarlyStopping is a custom helper that is not imported above; since its
    # calls further below are commented out, the instantiation is commented out too.
    #early_stopping = EarlyStopping(tolerance=7, verbose=True, path=output_dir)

    print('Start Training')

    start_epoch = MP['start_epoch']
    epochs = MP['epochs']

    H = {'train_loss': [], 'accuracy': []}

    train_steps = len(train_dataset)// bs
    #val_steps = len(val_dataset)// bs

    start_time = time.time()

    for epoch in range(start_epoch, epochs):
        
        epoch_time = time.time()
        train_loss = train_one_epoch(model, opt, train_dataset, train_dataloader, DEVICE, epoch, epochs)

        accuracy = evaluate(model, val_dataloader, val_dataset, DEVICE, epoch, epochs)

        train_loss_avg = train_loss / train_steps
        #val_loss_avg = val_loss / val_steps

        print('Train_Loss for epoch %d: %f' %(epoch+1, train_loss_avg))

        H['train_loss'].append(train_loss_avg)
        H['accuracy'].append(max(accuracy).cpu().detach().numpy())
        
        # epoch_len = len(str(epochs))
        # print_msg = (f'[{epoch+1:>{epoch_len}}/{epochs:>{epoch_len}}] ' +
        #              f'train_loss: {train_loss_avg:.5f} ' +
        #              f'valid_loss: {val_loss_avg:.5f}')

        # print(print_msg)

        #early_stopping(max(accuracy).cpu().detach().numpy(), model)

        # if early_stopping.early_stop:
        #     print('early stop at epoch:', i)
        #     break

        time_elapsed = time.time() - epoch_time
        print(f'Epoch Training completed in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        #print(f'Best accuracy: {early_stopping.val_loss_min:.6f}')
        print(f'Epoch Best Classification Accuracy: {max(accuracy).cpu().detach().numpy():.4f}')

        if (epoch+1) in epoch_save:
            print('saving model at epoch %d' %(epoch+1))
            torch.save(model.state_dict(), output_dir)

        # load the last checkpoint with the best model
        #model.load_state_dict(torch.load(output_dir))
    torch.save(model.state_dict(), output_dir)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str} seconds")

Since the model being used is a MobileNet, I would assume it was carefully designed not to waste system resources. From the paper:

MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.

Given that its intended use case is running on mobile platforms, I would expect a small memory footprint.

@ptrblck Yes, but I imagine that holds for inference. When I need to train this model on a large dataset, low memory usage seems to go hand in hand with slow training (at the speed I am getting, I imagine it will take more than five days to train instead of maybe one or two days). I don’t know if this is how it is supposed to be.

I don’t quite understand this logic. Why would a small model train slower than a large model?
Also, couldn’t you increase the batch size?

That’s what I was asking about. I thought there was an issue in my code because of this. I would think the training speed should be at least equal to, if not higher than, in the UNet case.
With the MobileNet the speed is approximately 10 img/s, while for the UNet it is 20 img/s, i.e. double, which is very strange.

As I wrote in my first message, I tried increasing the batch size to 256 and the memory usage of course increases to 14 GB, but it is still lower than what the UNet uses.

That’s why I don’t know whether it is supposed to be this way or there is something wrong in my code.
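
For what it’s worth, this is roughly how the img/s figures above could be measured (a simplified sketch, not my exact code; it times a few forward/backward passes and synchronizes with the GPU so asynchronous kernel launches don’t skew the numbers):

import time
import torch

def measure_throughput(model, dataloader, device, num_batches=20):
    # rough training throughput in images/second over a few batches
    # (sketch only: the optimizer step is omitted and gradients are not cleared)
    model.train()
    n_images = 0
    torch.cuda.synchronize()          # finish any pending GPU work first
    start = time.time()
    for i, (images, targets) in enumerate(dataloader):
        if i == num_batches:
            break
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        sum(loss for loss in loss_dict.values()).backward()
        n_images += len(images)
    torch.cuda.synchronize()          # wait for the GPU before stopping the clock
    return n_images / (time.time() - start)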

Besides the size of the model, the number and type of layers also matter for the speed.
E.g. 100k linear layers with a size of 10x10 would see a large overhead from dispatching, kernel launches, etc., while 10 huge linear layers (using the same amount of memory) could execute faster.
In case your model does indeed suffer from this dispatching overhead, you could use CUDA Graphs as described here.
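
A minimal capture/replay sketch, assuming static input shapes and a plain model/loss_fn setup (static_input, static_target, loss_fn, and loader are placeholders; a detection model that takes variable-length target lists would more likely need torch.cuda.make_graphed_callables on parts of the network instead of whole-network capture):

import torch

# warmup in a side stream so lazy initializations don't get baked into the graph
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# capture one full training iteration into a CUDA graph
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = loss_fn(model(static_input), static_target)
    static_loss.backward()
    optimizer.step()

# replay: copy new data into the static tensors, then launch the whole
# captured iteration with a single CPU-side call
for data, target in loader:
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()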

The model I am using is the same one that is on the pytorch/vision GitHub page, and it does not seem to suffer from the dispatching overhead (at least from what I have understood).
I will try to debug again and see if I manage to find a solution; otherwise I will just have to wait for it to train :sweat_smile: