Weird batch size memory behavior

I have two versions of the same object detection dataset in COCO annotation format with different image sizes: one is 233MB (images of variable size) and the other is 13MB (max side is 300px).

For BOTH datasets I cannot fit more than 5 images per batch with the fasterrcnn_resnet50_fpn model on a 24GB GPU.

This is how the memory consumption looks (commas separate values measured on different runs):

Before executing script = 2.3GB

small ds 13 MB:
batch = 1 → 5.8GB
batch = 2 → 8.8GB
batch = 3 → 12.4GB
batch = 4 → 12.6GB
batch = 5 → 22.6GB, 18.8GB, 23.4GB, 22.9GB
batch = 8 → 23GB

large ds 233MB:
batch = 1 → 6.5GB, 6.7GB
batch = 2 → 9.6GB
batch = 3 → 11.3GB, 13.2GB, 11.3GB
batch = 4 → 12.7GB, 15.5GB
batch = 5 → 22.7GB
batch = 8 → 22.9GB

I wonder why there is no difference between the two datasets (probably because they are both small enough) and why the behavior is so odd: increasing the batch size from 5 to 8 takes almost no extra GPU memory, while going from 4 to 5 takes many GBs.
In any case, why does adding one extra image per batch cause such a huge increase in memory consumption? Is it because of the specific model architecture? I remember packing hundreds of images per batch on a 24GB GPU for ImageNet classification models. I am trying to understand why in this scenario I can use only tiny batches of 5…
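
For reference, this is roughly how I would check peak memory per batch size in isolation — a minimal sketch with random 800×800 inputs and dummy targets, not my actual training loop, so the batch sizes, image size and targets here are just placeholders:

    import torch
    import torchvision

    device = torch.device("cuda")
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn().to(device)
    model.train()

    for bs in (1, 2, 3, 4, 5, 8):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        # random images and one dummy box per image, just to drive a training step
        images = [torch.rand(3, 800, 800, device=device) for _ in range(bs)]
        targets = [
            {"boxes": torch.tensor([[10.0, 10.0, 100.0, 100.0]], device=device),
             "labels": torch.tensor([1], device=device)}
            for _ in range(bs)
        ]
        loss_dict = model(images, targets)  # in train mode the model returns a dict of losses
        sum(loss_dict.values()).backward()
        model.zero_grad(set_to_none=True)
        peak_alloc = torch.cuda.max_memory_allocated(device) / 1024 ** 3
        peak_reserved = torch.cuda.max_memory_reserved(device) / 1024 ** 3
        print(f"batch={bs}: peak allocated {peak_alloc:.1f}GB, peak reserved {peak_reserved:.1f}GB")

Comparing peak allocated against peak reserved should at least show whether the jumps come from the activations themselves or from the caching allocator holding on to extra blocks.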

Image transformations consist of just converting the image to a Tensor:

import torchvision


def tensor_transform() -> torchvision.transforms.Compose:
    # only converts the PIL image to a tensor; no resizing or normalization
    custom_transforms = [torchvision.transforms.ToTensor()]
    return torchvision.transforms.Compose(custom_transforms)
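
Just to confirm the transform does nothing beyond the conversion, a quick sanity check (illustrative only, with a hypothetical 300×200 image):

    from PIL import Image

    # ToTensor gives a CxHxW float tensor in [0, 1] and does not resize the image
    transform = tensor_transform()
    img = Image.new("RGB", (300, 200))   # hypothetical 300x200 image
    t = transform(img)
    print(t.shape, t.dtype)              # torch.Size([3, 200, 300]) torch.float32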

and the dataset is created as:

    cppe5 = Cppe5(
        root=train_data_dir,
        annotation=train_annotation_file,
        transforms=tensor_transform(),
    )

    return torch.utils.data.DataLoader(
        cppe5,
        batch_size=train_batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
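
The collate_fn itself isn't shown in this snippet (the actual one is in the full dataloader code linked below); for detection datasets that return (image, target) pairs, it is typically just something like this:

    def collate_fn(batch):
        # keep images as a list instead of stacking, since their sizes can differ
        return tuple(zip(*batch))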

The entire dataloader code is here.

Thank you!

How is the memory consumption being measured? Are you using memory stats? torch.cuda.memory_stats — PyTorch 1.12 documentation
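
If the figures above come from nvidia-smi, note that they also include memory held by PyTorch's caching allocator; PyTorch's own counters can be queried like this (minimal sketch):

    import torch

    allocated = torch.cuda.memory_allocated() / 1024 ** 3   # memory used by live tensors
    reserved = torch.cuda.memory_reserved() / 1024 ** 3     # memory held by the caching allocator
    stats = torch.cuda.memory_stats()                        # detailed counters
    print(f"allocated {allocated:.1f}GB, reserved {reserved:.1f}GB")
    print(stats["allocated_bytes.all.peak"], stats["reserved_bytes.all.peak"])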