Very large use of memory with albumentations ReplayCompose?

Hello,

I’m facing the issue that my code is somehow using a far too large amount of RAMs that I would have expected

The example code is given below and requires

einops numpy albumentations torch cv2

It creates 3d volumes, compute a transform to be applied on every slice of the 3d volume; The same transform must be applied to all the slices, hence I used albumentation ReplayCompose.

My volumes have 40 slices and are (40, 1, 512, 512). As float32, I would have expected this to take 40MB in RAM. Even with a single worker, and a batch size of 8, the memory usage is larger than 16 GB ; I did an experiment where I killed the process when it reached above 16 GB. I would have expected a usage of 8 x 40MB = 320 MB or so , but clearly not as much as 16 GB.

To test the code :

python test.py --use_transforms

If you omit the --use_transforms option, it will discard the albumentations augmentation, and the memory usage will be much lower.

Do you believe that such a RAM usage is expected or that there is an issue somewhere ?

Thank you for your help.

# Standard imports
import logging
import functools
import operator
import argparse

# External imports
from einops import rearrange
import numpy as np
import albumentations as A
from torch.utils.data import Dataset
from albumentations.pytorch import ToTensorV2
import torch
import tqdm
import cv2


def compute_mean_std(dataset, batch_size=128, num_workers=4):
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # Compute the mean and std over minibatches of the
    # provided dataset
    mean = 0.0
    mean2 = 0.0
    nsamples = 0.0
    for imgs, _ in tqdm.tqdm(loader):
        mean += imgs.sum()
        mean2 += (imgs**2).sum()
        nsamples += functools.reduce(operator.mul, imgs.shape)
    mean /= nsamples
    mean2 /= nsamples

    std = torch.sqrt(mean2 - mean**2)

    return mean.item(), std.item()


class DatasetRandom(Dataset):
    def __init__(self, num_slices, transforms=None):
        super()
        self.num_slices = num_slices
        self.transforms = transforms

    def __getitem__(self, idx):

        input_chunk = np.random.random((self.num_slices, 1, 512, 512)).astype(
            np.float32
        )
        output_chunk = np.random.randint(
            low=0, high=2, size=(self.num_slices, 1, 512, 512)
        )

        # Apply the transform on the chunk
        if self.transforms is not None:
            t = A.ReplayCompose([self.transforms], p=1.0)
            input_tensors = []
            masks = []
            replayed_transform = None
            for i in range(input_chunk.shape[0]):
                if i == 0:
                    transformed = t(image=input_chunk[i], mask=output_chunk[i])
                    replayed_transform = transformed["replay"]
                else:
                    transformed = A.ReplayCompose.replay(
                        replayed_transform,
                        image=input_chunk[i],
                        mask=output_chunk[i],
                    )
                input_tensors.append(transformed["image"])
                masks.append(transformed["mask"])

            input_tensor = torch.stack(input_tensors)
            target_mask = torch.stack(masks)
        else:
            input_tensor = torch.from_numpy(input_chunk)
            target_mask = torch.from_numpy(output_chunk)

        # If requested, we can stack the slices along the channel dimension
        # This allows the sequence to be processed by a standard
        # convolutional network
        input_tensor = rearrange(input_tensor, "t c h w -> (t c) h w")
        target_mask = rearrange(target_mask, "t c h w -> (t c) h w")

        return input_tensor, target_mask

    def __len__(self):
        return 100000


def get_knotbil_dataloaders(
    data_config, preprocess_transforms, augmentation_transforms
):
    batch_size = data_config["batch_size"]
    num_workers = data_config["num_workers"]
    num_slices = data_config["num_slices"]
    use_transforms = data_config["use_transforms"]

    logging.info("  - KnotBil Dataset creation")

    # Compute the normalization metrics on the training fold
    conversion_transform = ToTensorV2(transpose_mask=True)
    if use_transforms:
        transforms = A.Compose(
            [preprocess_transforms, augmentation_transforms, conversion_transform]
        )
    else:
        transforms = None

    normalizing_dataset = DatasetRandom(
        num_slices=num_slices,
        transforms=transforms,
    )

    logging.info("      - Iterating the normalizing dataset")
    mean, std = compute_mean_std(
        normalizing_dataset, batch_size=batch_size, num_workers=num_workers
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_transforms", action="store_true", help="Use the transform"
    )
    args = parser.parse_args()

    # Only used if use_transform is True
    size = 256

    preprocess_transforms = A.Compose(
        [
            A.Resize(size, size),
        ]
    )

    augmentation_transforms = A.Compose(
        [
            A.HorizontalFlip(),
            A.Affine(
                translate_percent=0.2,
                scale=(0.7, 1.3),
                keep_ratio=True,
                rotate=(-360, 360),
                border_mode=cv2.BORDER_CONSTANT,
                fill=0,
                fill_mask=0,
            ),
        ]
    )

    train_loader, valid_loader, input_size, output_size, normalizing_metrics = (
        get_knotbil_dataloaders(
            {
                "batch_size": 8,
                "num_workers": 1,
                "num_slices": 40,
                "use_transforms": args.use_transforms,
            },
            preprocess_transforms=preprocess_transforms,
            augmentation_transforms=augmentation_transforms,
        )
    )

I possibly reduced the problem by implementing two almost equivalent pipelines with torchvision and albumentations and filled in an issue on the albumentations github : Much larger memory consumption compared to torchvision v2 ? · Issue #2556 · albumentations-team/albumentations · GitHub

It may have nothing to do with ReplayCompose although I still do not know where the problem is but I suspect there is something wrong with albumentations or in the way I’m using it.