Efficiently handling large-scale vision datasets

Hi, I’m trying to pretrain a ViT with a self-supervised training framework (DINOv2) on a large dataset (between 100M and 1B JPEG images, all 256×256 pixels). I based my data handling and loading on the DINOv2 dataset class (here).

This approach stores the images in tarball files and the metadata in a single .npy file (for each image: its start and end byte offsets, plus which tarball it is stored in). The code snippet is below.

Unfortunately, I am facing very slow data loading times:

  1. Large tarball files: some of the tarballs I work with contain as many as 6M images. I suspect this increases RAM usage, which could explain the slow data loading – and even the out-of-memory errors – I run into.

  2. To mitigate this, I split the large tarballs into smaller ones (1 GB each). This offers some relief by reducing the memory footprint during data loading, but it doesn’t scale well with the batch size: the larger the batch, the more tarball files have to be opened and closed concurrently, which seems to add significant overhead and slows data loading down.

I’ve looked into alternative tools (WebDataset, TorchData) without success, so I’m reaching out for any advice or alternative strategies for handling large-scale vision datasets. Thank you!

Dataset code
import numpy as np

from io import BytesIO
from typing import Any
from PIL import Image
from pathlib import Path

from mmap import ACCESS_READ, mmap
from typing import Any, Callable, Optional, Tuple
from torchvision.datasets import VisionDataset
from functools import lru_cache


class Decoder:
    """Common interface for objects that wrap one stored field of a sample.

    Concrete subclasses keep the raw payload and turn it into a usable value
    in :meth:`decode`.
    """

    def decode(self) -> Any:
        """Produce the decoded value. Subclasses must override this."""
        raise NotImplementedError()


class ImageDataDecoder(Decoder):
    """Decode raw encoded image bytes (e.g. a JPEG payload) into an RGB PIL image."""

    def __init__(self, image_data: bytes) -> None:
        # Encoded image bytes, e.g. a byte range sliced out of a tarball mmap.
        self._image_data = image_data

    def decode(self) -> Image.Image:
        """Decode the stored bytes and return a new RGB ``PIL.Image.Image``.

        Errors raised by PIL for undecodable data (e.g.
        ``PIL.UnidentifiedImageError``) propagate to the caller.
        """
        # ``Image.open`` is lazy. ``convert`` forces the full decode and returns
        # a new image (a copy even when the mode is already RGB), so the source
        # image can be closed as soon as the converted copy exists.
        with Image.open(BytesIO(self._image_data)) as source:
            return source.convert(mode="RGB")


class TargetDecoder(Decoder):
    """Pass-through decoder: the stored target is returned as-is."""

    def __init__(self, target: Any):
        # Kept untouched; ``decode`` hands back this exact object.
        self._target = target

    def decode(self) -> Any:
        """Return the wrapped target unchanged."""
        wrapped = self._target
        return wrapped


_DEFAULT_MMAP_CACHE_SIZE = 16  # Warning: This can exhaust file descriptors


def _get_tarball_path(dataset_name: str) -> str:
    return f"{dataset_name}.tar"


def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
    @lru_cache(maxsize=mmap_cache_size)
    def _mmap_tarball(dataset_name: str) -> mmap:
        tarball_path = _get_tarball_path(dataset_name)
        tarball_full_path = Path(tarballs_root, tarball_path)
        with open(tarball_full_path) as f:
            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

    return _mmap_tarball


class FoundationDataset(VisionDataset):
    """Image dataset backed by memory-mapped tarballs and ``.npy`` index files.

    Files read from ``root``:

    * ``pretrain_entries.npy`` -- one row per sample, read column-wise as
      ``[target, file_idx, start_offset, end_offset, dataset_idx]``; loaded
      with ``mmap_mode="r"``.
    * ``dataset_indices.npy`` -- pickled mapping ``dataset_idx -> dataset_name``.
    * ``<dataset_name>_file_indices.npy`` -- pickled mapping
      ``file_idx -> filepath`` (only needed by :meth:`get_image_data`).
    * ``<dataset_name>.tar`` -- bytes ``[start_offset, end_offset)`` of this
      file hold the encoded image of a sample.

    NOTE(review): the index files are loaded with ``allow_pickle=True``, which
    can execute arbitrary code from a malicious file; only use trusted data.
    """

    def __init__(
        self,
        *,
        root: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
    ) -> None:
        """
        Args:
            root: directory holding the tarballs and ``.npy`` index files.
            transforms, transform, target_transform: forwarded to
                ``VisionDataset``.
            mmap_cache_size: maximum number of tarball mmaps kept open at once
                (per process).
        """
        super().__init__(root, transforms, transform, target_transform)
        self._get_entries()
        self._get_dataset_names()
        # NOTE(review): this is an lru_cache-wrapped local closure, which the
        # standard pickle module cannot serialize. DataLoader workers therefore
        # presumably need the "fork" start method -- confirm before using
        # "spawn"/"forkserver".
        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)

    @property
    def _tarballs_root(self) -> str:
        return self.root

    @property
    def _entries_name(self) -> str:
        return "pretrain_entries.npy"

    def _get_entries(self) -> None:
        """Load the per-sample entries table into ``self._entries``."""
        self._entries = self._load_entries(self._entries_name)

    def _load_entries(self, _entries_name: str) -> np.ndarray:
        entries_path = Path(self.root, _entries_name)
        # mmap_mode="r": rows are paged in on access instead of loading the
        # whole table into RAM up front.
        return np.load(entries_path, mmap_mode="r")

    def _get_filepaths_dict(self, dataset_name: str) -> dict:
        return self._load_filepaths_dict(dataset_name)

    def _load_filepaths_dict(self, dataset_name: str) -> dict:
        """Read and unpickle the ``file_idx -> filepath`` mapping of ``dataset_name``.

        The whole file is loaded on every call, so keep this off the
        per-sample hot path.
        """
        filepaths_dict_path = Path(self.root, f"{dataset_name}_file_indices.npy")
        return np.load(filepaths_dict_path, allow_pickle=True).item()

    def _get_dataset_names(self) -> None:
        """Load the ``dataset_idx -> dataset_name`` mapping into ``self._dataset_names``."""
        self._dataset_names = self._load_dataset_names()

    def _load_dataset_names(self) -> dict:
        dataset_dict_path = Path(self.root, "dataset_indices.npy")
        return np.load(dataset_dict_path, allow_pickle=True).item()

    def _get_image_bytes(self, index: int) -> bytes:
        """Return the encoded image bytes of sample ``index`` (no file-path lookup)."""
        entry = self._entries[index]
        start_offset, end_offset, dataset_idx = entry[2], entry[3], entry[4]
        dataset_name = self._dataset_names[dataset_idx]
        # Slicing an mmap yields a bytes copy of just this sample's byte range.
        return self._mmap_tarball(dataset_name)[start_offset:end_offset]

    def get_image_data(self, index: int) -> Tuple[bytes, Path]:
        """Return ``(encoded_image_bytes, original_filepath)`` for sample ``index``.

        Loads the dataset's file-path index from disk on each call; use it for
        inspection, not in the training loop.
        """
        entry = self._entries[index]
        file_idx, dataset_idx = entry[1], entry[4]
        dataset_name = self._dataset_names[dataset_idx]
        filepath = self._get_filepaths_dict(dataset_name)[file_idx]
        return self._get_image_bytes(index), Path(filepath)

    def get_target(self, index: int) -> Any:
        """Return the target (column 0 of the entry) of sample ``index`` as an int."""
        return int(self._entries[index][0])

    def get_targets(self) -> np.ndarray:
        """Return column 0 (targets) of the entries table for all samples."""
        return self._entries[:, 0]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for sample ``index``, after ``self.transforms``.

        Raises:
            RuntimeError: if the image bytes cannot be read or decoded.
        """
        try:
            # Deliberately not get_image_data(): that also reads and unpickles
            # the dataset's entire file-path index from disk, a full load per
            # sample whose result would be discarded here.
            image_data = self._get_image_bytes(index)
            image = ImageDataDecoder(image_data).decode()
        except Exception as e:
            raise RuntimeError(f"can not read image for sample {index} ({e})") from e
        target = self.get_target(index)
        target = TargetDecoder(target).decode()

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        return len(self._entries)