Hello,
I’m facing an issue where my code uses far more RAM than I would have expected.
The example code is given below and requires the following packages:
einops, numpy, albumentations, torch, opencv-python (cv2)
It creates 3D volumes and computes a transform to be applied to every slice of each volume. The same transform must be applied to all the slices, hence I used albumentations' ReplayCompose.
My volumes have 40 slices and have shape (40, 1, 512, 512). As float32, I would have expected each volume to take about 40 MB of RAM. Even with a single worker and a batch size of 8, the memory usage grows beyond 16 GB; in one experiment I killed the process once it went above 16 GB. I would have expected a usage of about 8 x 40 MB = 320 MB, certainly not 16 GB.
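For reference, here is the back-of-the-envelope computation behind that expectation (nothing specific to the script, just the raw tensor sizes):

import numpy as np

# One volume: 40 slices of 1 x 512 x 512 float32 values
volume_bytes = 40 * 1 * 512 * 512 * np.dtype(np.float32).itemsize
print(volume_bytes / 1024**2)      # ~40 MB per volume
print(8 * volume_bytes / 1024**2)  # ~320 MB for a batch of 8 volumes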
To test the code:
python test.py --use_transforms
If you omit the --use_transforms option, the albumentations augmentation is discarded and the memory usage is much lower.
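For completeness, this is roughly how the memory usage can be watched while the script runs; it is only a minimal sketch assuming psutil is installed (pid_of_test_py is a placeholder for the PID of the running script), and I simply killed the process once it went above 16 GB:

import psutil

# Resident memory of the script plus its DataLoader worker(s), in GB
main = psutil.Process(pid_of_test_py)  # placeholder: PID of the running test.py
total = main.memory_info().rss + sum(
    child.memory_info().rss for child in main.children(recursive=True)
)
print(total / 1024**3)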
Do you believe such a RAM usage is expected, or is there an issue somewhere?
Thank you for your help.
# Standard imports
import logging
import functools
import operator
import argparse
# External imports
from einops import rearrange
import numpy as np
import albumentations as A
from torch.utils.data import Dataset
from albumentations.pytorch import ToTensorV2
import torch
import tqdm
import cv2

def compute_mean_std(dataset, batch_size=128, num_workers=4):
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # Compute the mean and std over minibatches of the
    # provided dataset
    mean = 0.0
    mean2 = 0.0
    nsamples = 0.0
    for imgs, _ in tqdm.tqdm(loader):
        mean += imgs.sum()
        mean2 += (imgs**2).sum()
        nsamples += functools.reduce(operator.mul, imgs.shape)
    mean /= nsamples
    mean2 /= nsamples
    std = torch.sqrt(mean2 - mean**2)
    return mean.item(), std.item()

class DatasetRandom(Dataset):
    def __init__(self, num_slices, transforms=None):
        super().__init__()
        self.num_slices = num_slices
        self.transforms = transforms

    def __getitem__(self, idx):
        # Build a random volume of num_slices slices of shape (1, 512, 512)
        input_chunk = np.random.random((self.num_slices, 1, 512, 512)).astype(
            np.float32
        )
        output_chunk = np.random.randint(
            low=0, high=2, size=(self.num_slices, 1, 512, 512)
        )
        # Apply the transform on the chunk
        if self.transforms is not None:
            # Wrap the pipeline in ReplayCompose so that the parameters
            # sampled on the first slice can be replayed on all the others
            t = A.ReplayCompose([self.transforms], p=1.0)
            input_tensors = []
            masks = []
            replayed_transform = None
            for i in range(input_chunk.shape[0]):
                if i == 0:
                    transformed = t(image=input_chunk[i], mask=output_chunk[i])
                    replayed_transform = transformed["replay"]
                else:
                    transformed = A.ReplayCompose.replay(
                        replayed_transform,
                        image=input_chunk[i],
                        mask=output_chunk[i],
                    )
                input_tensors.append(transformed["image"])
                masks.append(transformed["mask"])
            input_tensor = torch.stack(input_tensors)
            target_mask = torch.stack(masks)
        else:
            input_tensor = torch.from_numpy(input_chunk)
            target_mask = torch.from_numpy(output_chunk)
        # If requested, we can stack the slices along the channel dimension
        # This allows the sequence to be processed by a standard
        # convolutional network
        input_tensor = rearrange(input_tensor, "t c h w -> (t c) h w")
        target_mask = rearrange(target_mask, "t c h w -> (t c) h w")
        return input_tensor, target_mask

    def __len__(self):
        return 100000

def get_knotbil_dataloaders(
    data_config, preprocess_transforms, augmentation_transforms
):
    batch_size = data_config["batch_size"]
    num_workers = data_config["num_workers"]
    num_slices = data_config["num_slices"]
    use_transforms = data_config["use_transforms"]

    logging.info(" - KnotBil Dataset creation")
    # Compute the normalization metrics on the training fold
    conversion_transform = ToTensorV2(transpose_mask=True)
    if use_transforms:
        transforms = A.Compose(
            [preprocess_transforms, augmentation_transforms, conversion_transform]
        )
    else:
        transforms = None
    normalizing_dataset = DatasetRandom(
        num_slices=num_slices,
        transforms=transforms,
    )
    logging.info(" - Iterating the normalizing dataset")
    mean, std = compute_mean_std(
        normalizing_dataset, batch_size=batch_size, num_workers=num_workers
    )
    return mean, std
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser()
parser.add_argument(
"--use_transforms", action="store_true", help="Use the transform"
)
args = parser.parse_args()
# Only used if use_transform is True
size = 256
preprocess_transforms = A.Compose(
[
A.Resize(size, size),
]
)
augmentation_transforms = A.Compose(
[
A.HorizontalFlip(),
A.Affine(
translate_percent=0.2,
scale=(0.7, 1.3),
keep_ratio=True,
rotate=(-360, 360),
border_mode=cv2.BORDER_CONSTANT,
fill=0,
fill_mask=0,
),
]
)
train_loader, valid_loader, input_size, output_size, normalizing_metrics = (
get_knotbil_dataloaders(
{
"batch_size": 8,
"num_workers": 1,
"num_slices": 40,
"use_transforms": args.use_transforms,
},
preprocess_transforms=preprocess_transforms,
augmentation_transforms=augmentation_transforms,
)
)