Accelerate ImageFolder-based dataset loading

Just a small update in case anyone might be interested.

I never managed to load more than ~1k samples/s using ImageFolder (with a standard NVMe SSD).

I ended up placing the data in an HDF5 file. My hope was that accessing a single file would help. I also think HDF5 stores the samples in a localised physical area on the SSD, which improves loading (not so sure though…). On top of that, I resize the images and convert them to tensors once, when creating the HDF5 file, so that work is not repeated every epoch.

I am loading about 7.5k samples/s at the moment (including .to("cuda")).

import os
from collections.abc import Callable

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm


def create_hdf5_dataset(root_folder: str,
                        hdf5_file: str,
                        target_size: tuple[int, int] = (128, 128),
                        channels: int = 3) -> None:
    """ Create an hdf5 database file from a folder containing images.
    The images are resized to target_size, converted to float32 tensors
    and stored in the hdf5 file. This is useful when your dataset cannot
    fit in RAM and the data consists of many small image files."""
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    # check that root_folder points towards an existing directory
    if not os.path.isdir(root_folder):
        raise FileNotFoundError(f"Directory '{root_folder}' not found.")

    # generate the dataset and pick a write batch size
    data = ImageFolder(root=root_folder, transform=transform)
    batch_size = max(1, min(len(data) // 10, 1000))

    num_images = len(data)

    # open the hdf5 file
    with h5py.File(hdf5_file, "w") as file:
        # contiguous datasets (chunks=None): the whole array is stored
        # in one block, which favours large sequential reads
        img_dataset = file.create_dataset("images",
                                          shape=(num_images,
                                                 channels,
                                                 target_size[0],
                                                 target_size[1]),
                                          dtype="float32", chunks=None)
        lbl_dataset = file.create_dataset("labels", shape=(num_images,),
                                          dtype="int64", chunks=None)

        # shuffle=False keeps the on-disk order identical to ImageFolder
        loader = DataLoader(data, batch_size=batch_size, shuffle=False)

        current_index = 0
        # write batch by batch (to avoid RAM overflow)
        for images, labels in tqdm(loader):
            n = images.size(0)  # the last batch may be smaller

            # the DataLoader already yields tensors; write them directly
            img_dataset[current_index:current_index + n] = images.numpy()
            lbl_dataset[current_index:current_index + n] = labels.numpy()

            current_index += n
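
Creating the file is then a one-off call along these lines (the root folder path here is just a guess based on my layout; adjust it to yours):

# one-off conversion; paths are placeholders for my dataset layout
create_hdf5_dataset(root_folder="datasets/birds/train",
                    hdf5_file="datasets/birds/train.hdf5",
                    target_size=(128, 128))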

Loading is done using the following Dataset, which I stole from this post:

    """ Custom dataset for loading images from an hdf5 file.
    This system lazy loads the images from the hdf5 file."""

class Hdf5Dataset(Dataset):
    def __init__(self, hdf5_file: str, transform: transforms = None) -> None:
        self.hdf5_file = hdf5_file
        self.transform = transform
        self.dataset = None

        with h5py.File(hdf5_file, "r") as file:
            self.length = len(file["images"])

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx) -> tuple[torch.Tensor, int]:
        """ Returns a tuple of image and label."""
        if self.dataset is None:
            self.dataset = h5py.File(self.hdf5_file, mode="r", swmr=True)

        image = self.dataset["images"][idx]
        label = self.dataset["labels"][idx]

        if self.transform:
            image = self.transform(image)

        return image, label
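
One thing to keep in mind: since Resize and ToTensor already ran when the file was created, samples come back as float32 numpy arrays in CHW layout, so any transform you pass in must accept arrays rather than PIL images. A minimal sketch with a hypothetical flip_transform for random horizontal flips at load time:

def flip_transform(image: np.ndarray) -> torch.Tensor:
    # arrays are already scaled to [0, 1] and in CHW layout
    img = torch.from_numpy(image)
    if torch.rand(1).item() < 0.5:
        img = torch.flip(img, dims=[-1])  # flip along the width axis
    return img

dataset = Hdf5Dataset("datasets/birds/train.hdf5", transform=flip_transform)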

if __name__ == "__main__":
    # Benchmark the hdf5 dataset against the plain ImageFolder pipeline.
    hdf5_file = r"datasets/birds/train.hdf5"

    batch_size = 32 * 4
    epochs = 4
    num_workers = 6
    prefetch_factor = 10
    persistent_workers = True

    ### loading dataset from hdf5 ###
    print("With hdf5")
    dataset = Hdf5Dataset(hdf5_file)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers,
                            persistent_workers=persistent_workers,
                            prefetch_factor=prefetch_factor,
                            pin_memory=False)

    for _ in range(epochs):
        for images, labels in tqdm(dataloader):
            images = images.to("cuda")
            labels = labels.to("cuda")
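
If you want to reproduce the samples/s number, timing the epoch loop with time.perf_counter is enough; a minimal sketch (drop-in replacement for the inner loop above):

import time

start = time.perf_counter()
n_samples = 0
for images, labels in tqdm(dataloader):
    images = images.to("cuda")
    labels = labels.to("cuda")
    n_samples += images.size(0)
elapsed = time.perf_counter() - start
print(f"{n_samples / elapsed:.0f} samples/s")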