Accelerate ImageFolder-based dataset loading

Hello,

First of all, sorry if this question has already been asked. I’ve searched this forum and tried everything I could find, to no avail.

I am training a ViT on an image dataset from Kaggle. The train set contains ~80,000 224×224×3 JPEG images (~2 GB). Training is rather slow because the GPU is barely used: its utilization oscillates rapidly between 0% and 100%.

I am using a Dataset (wrapping ImageFolder) and a DataLoader to test data fetching and transformation:

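import os

import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2 as transforms


class BirdsDataset(Dataset):
    """ Dataset wrapper around ImageFolder for the birds dataset
    (contains 2625 classes) """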

    def __init__(self,
                 data_dir: str = "",
                 transform: transforms.Compose = None) -> None:

        # check if data_dir points towards an existing directory
        if not os.path.isdir(data_dir):
            error_string = f"Directory '{data_dir}' not found."
            raise FileNotFoundError(error_string)

        # set default transform (transforms.ToTensor())
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])

        # generate dataset (ImageFolder raises its own error if no class folders are found)
        self.data = ImageFolder(data_dir, transform=transform)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx) -> tuple[torch.Tensor, int]:
        return self.data[idx]

    @property
    def classes(self) -> list[str]:
        return self.data.classes

The test script (loading and transforming all batches with a DataLoader) is as follows:

    # set up the transformation
    size = (128, 128)
    transformation = transforms.Compose([
        # ToImage() + ToDtype(scale=True) is the transforms.v2 equivalent of ToTensor()
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Resize(size, antialias=True),
    ])

    # creating dataset
    train_data_dir = r"./datasets/birds/train"
    train_dataset = BirdsDataset(
        data_dir=train_data_dir,
        transform=transformation
    )

    # creating dataloader
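    workers = 8
    # note: prefetch_factor is counted per worker, so up to
    # workers * prefetch = 8 * 96 = 768 batches can be queued at once
    prefetch = workers * 12
    batch_size = 256
    pin = False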
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=workers,
                                  prefetch_factor=prefetch,
                                  persistent_workers=True,
                                  pin_memory=pin)

    # single pass over the entire dataset (i.e. one epoch)
    for images, labels in train_dataloader:
        pass

Some (possibly) useful information:

  • Setup: CPU: 5600X, RAM: 32 GB, GPU: 3070 Ti, NVMe SSD, OS: Windows 11
  • The first epoch is slower; subsequent epochs take about 80 s to iterate over the dataset.
  • Using this script, the best configuration so far is num_workers=8 and pin_memory=False. It is still very slow nonetheless.
  • Moving the data to the GPU (images.to(device)) does not add noticeable overhead.
  • More complex transformations increase the total time only slightly.
  • When iterating over train_dataloader, I clearly see “rapid” and “slow” iteration phases: a quick burst of 8-10 batches (with 8 workers), then a pause.
  • All CPU cores are busy during the process.
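A rough back-of-envelope reading of the burst-then-pause pattern, assuming the ~80 s epoch covers ~80,000 images and all 8 workers run at the same rate. In a multi-worker DataLoader each worker assembles an entire batch on its own, so batches tend to arrive in groups of about num_workers, followed by a gap while every worker builds its next batch:

```latex
\begin{aligned}
\text{throughput} &\approx \frac{80\,000\ \text{images}}{80\ \text{s}} = 1000\ \text{images/s} \\
\text{per worker} &\approx \frac{1000\ \text{images/s}}{8} = 125\ \text{images/s} \\
t_{\text{batch}} &\approx \frac{256\ \text{images}}{125\ \text{images/s}} \approx 2\ \text{s}
\end{aligned}
```

Under these assumptions each worker needs roughly 2 s per 256-image batch, which is consistent with about 8 batches arriving almost together and then a pause of a couple of seconds.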

I’ve been using the bottleneck profiler (python -m torch.utils.bottleneck ./test_script.py) to figure out what is going on, but I can’t make sense of its output.

  • 1- In the bottleneck report, the function names are truncated (see below “enumerate(DataLoader)#_MultiProcessingDataLoaderIter…”) and I can’t figure out what is bottlenecking the script.
  • 2- cProfile indicates that “_pickle.Pickler.dump” takes ~24 s of the 62.7 s total (the largest entry): where does that come from? Can I avoid it somehow?
  • 3- Is there any way I can improve my Dataset class? I am using the standard ImageFolder; I have tried dumping the images into an HDF5 file, with no improvement.
  • 4- I can’t figure out whether I am CPU or I/O bound. Is ImageFolder loading lazily in my situation?
  • 5- Any suggestions or leads would be very welcome :).
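On question 4: ImageFolder only indexes file paths up front; each image is opened and decoded with PIL when it is requested. One way to split the remaining per-image cost into disk I/O and CPU decoding is to time the two separately on a sample of files. A minimal sketch; time_read_vs_decode is a hypothetical helper, not part of torchvision:

```python
import io
import time
from pathlib import Path

from PIL import Image


def time_read_vs_decode(paths: list[Path]) -> dict[str, float]:
    """Hypothetical helper: time raw file reads and image decoding separately.

    Reading all bytes first and decoding from memory afterwards separates
    disk I/O from CPU work.
    """
    t0 = time.perf_counter()
    blobs = [p.read_bytes() for p in paths]  # disk I/O only
    t1 = time.perf_counter()
    for blob in blobs:
        # Image.open only parses the header; .convert("RGB") forces the full decode
        Image.open(io.BytesIO(blob)).convert("RGB")
    t2 = time.perf_counter()
    n = max(len(paths), 1)
    return {
        "read_s_per_image": (t1 - t0) / n,
        "decode_s_per_image": (t2 - t1) / n,
        "total_bytes": float(sum(len(b) for b in blobs)),
    }
```

Usage could be `time_read_vs_decode(sorted(Path("./datasets/birds/train").rglob("*.jpg"))[:1000])`, where the *.jpg pattern assumes the files are JPEGs as described above. If decode time per image is much larger than read time, the loader is CPU bound rather than I/O bound. The OS file cache makes a second run look faster, so take the I/O number from files that have not been read recently.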

Thanks a lot!

  cProfile output
--------------------------------------------------------------------------------
         4777907 function calls (4743887 primitive calls) in 62.709 seconds

   Ordered by: internal time
   List reduced from 6636 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       16   23.766    1.485   23.775    1.486 {method 'dump' of '_pickle.Pickler' objects}
      326   23.600    0.072   23.600    0.072 {built-in method _winapi.WaitForMultipleObjects}
   170676    1.262    0.000    2.162    0.000 <frozen ntpath>:154(splitdrive)
617670/617602    0.818    0.000    0.837    0.000 {built-in method builtins.isinstance}
    85324    0.689    0.000    3.074    0.000 <frozen ntpath>:107(join)
411067/408361    0.540    0.000    0.548    0.000 {built-in method builtins.len}
    84635    0.474    0.000    0.821    0.000 C:\Users\agarc\anaconda3\envs\ViT_py311\Lib\site-packages\torchvision\datasets\folder.py:10(has_file_allowed_extension)
   351280    0.459    0.000    0.459    0.000 {method 'append' of 'list' objects}
     7297    0.454    0.000    0.865    0.000 C:\Users\agarc\anaconda3\envs\ViT_py311\Lib\inspect.py:867(cleandoc)
        1    0.396    0.396    5.416    5.416 C:\Users\agarc\anaconda3\envs\ViT_py311\Lib\site-packages\torchvision\datasets\folder.py:48(make_dataset)
     1050    0.367    0.000    0.752    0.001 <frozen os>:345(_walk)
   256505    0.352    0.000    0.352    0.000 {method 'startswith' of 'str' objects}
   261062    0.340    0.000    0.340    0.000 {built-in method nt.fspath}
   121611    0.306    0.000    0.452    0.000 C:\Users\agarc\anaconda3\envs\ViT_py311\Lib\site-packages\torch\_dynamo\allowed_functions.py:183(<genexpr>)
   179566    0.246    0.000    0.246    0.000 {method 'replace' of 'str' objects}


--------------------------------------------------------------------------------
  autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         7.62%     655.037ms         7.63%     655.122ms     655.122ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         7.60%     652.877ms         7.60%     653.144ms     653.144ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         7.03%     603.798ms         7.03%     604.047ms     604.047ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.82%     586.095ms         6.82%     586.315ms     586.315ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.80%     584.298ms         6.80%     584.479ms     584.479ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.79%     583.468ms         6.79%     583.753ms     583.753ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.79%     583.276ms         6.79%     583.556ms     583.556ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.69%     574.790ms         6.69%     575.074ms     575.074ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.66%     572.605ms         6.67%     572.862ms     572.862ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.41%     550.299ms         6.41%     550.427ms     550.427ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.37%     547.527ms         6.38%     547.777ms     547.777ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.18%     530.958ms         6.18%     531.265ms     531.265ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.11%     525.124ms         6.11%     525.368ms     525.368ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.08%     522.613ms         6.09%     522.902ms     522.902ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.04%     518.745ms         6.04%     519.024ms     519.024ms             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.592s

--------------------------------------------------------------------------------
  autograd profiler output (CUDA mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

	Because the autograd profiler uses the CUDA event API,
	the CUDA time column reports approximately max(cuda_time, cpu_time).
	Please ignore this output if your code does not use CUDA.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        12.48%        1.008s        12.48%        1.009s        1.009s        1.008s        12.49%        1.009s        1.009s             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         8.51%     687.610ms         8.52%     689.011ms     689.011ms     685.606ms         8.50%     689.029ms     689.029ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         7.64%     617.761ms         7.65%     618.081ms     618.081ms     617.389ms         7.65%     618.148ms     618.148ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.86%     554.331ms         6.86%     554.761ms     554.761ms     553.872ms         6.86%     554.785ms     554.785ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.21%     502.107ms         6.23%     503.509ms     503.509ms     500.409ms         6.20%     503.536ms     503.536ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.19%     500.580ms         6.20%     501.347ms     501.347ms     499.478ms         6.19%     501.366ms     501.366ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         6.02%     486.816ms         6.03%     487.499ms     487.499ms     486.141ms         6.02%     487.516ms     487.516ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.95%     481.216ms         5.96%     481.894ms     481.894ms     480.095ms         5.95%     481.944ms     481.944ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.92%     478.243ms         5.92%     478.846ms     478.846ms     477.372ms         5.92%     478.866ms     478.866ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.79%     467.898ms         5.79%     468.363ms     468.363ms     467.373ms         5.79%     468.433ms     468.433ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.76%     465.421ms         5.78%     466.865ms     466.865ms     464.731ms         5.76%     467.026ms     467.026ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.74%     463.883ms         5.75%     464.561ms     464.561ms     462.782ms         5.73%     464.576ms     464.576ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.68%     459.279ms         5.69%     459.930ms     459.930ms     458.675ms         5.68%     460.029ms     460.029ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.67%     458.438ms         5.68%     459.287ms     459.287ms     457.496ms         5.67%     459.304ms     459.304ms             1  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...         5.59%     451.472ms         5.60%     452.653ms     452.653ms     449.958ms         5.58%     452.676ms     452.676ms             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 8.084s
Self CUDA time total: 8.070s


Just a small update in case anyone might be interested.

I never managed to load more than ~1k samples/s using ImageFolder (with a standard NVMe SSD).

I ended up storing the data in an HDF5 file, hoping that reading from a single file would help. I also suspect HDF5 stores the samples in a contiguous physical area of the SSD, which may speed up reads (not so sure though…). In addition, I resize the images and convert them to tensors when creating the HDF5 file, so JPEG decoding and resizing are no longer done at load time.

I am loading about 7.5k samples/s at the moment (including .to("cuda")).

import os

import h5py
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm


def create_hdf5_dataset(root_folder: str,
                        hdf5_file: str,
                        target_size: tuple[int, int] = (128, 128),
                        channels: int = 3) -> None:
    """ Create an HDF5 database file from a folder containing images.
    The images are resized to target_size and stored in the HDF5 file.
    This is useful when the dataset cannot fit in RAM and consists of many images."""
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
    ])

    # check if root_folder points towards an existing directory
    if not os.path.isdir(root_folder):
        error_string = f"Directory '{root_folder}' not found."
        raise FileNotFoundError(error_string)

    # generate dataset
    data = ImageFolder(root=root_folder, transform=transform)
    num_images = len(data)
    # batch size for reading images and appending them to the HDF5 file
    # (at least 1, at most 1000)
    chunk_size = max(1, min(num_images // 10, 1000))

    # open the hdf5 file
    with h5py.File(hdf5_file, "w") as file:
        # chunks=None: contiguous (unchunked) storage layout
        img_dataset = file.create_dataset("images", shape=(num_images,
                                                           channels,
                                                           target_size[0],
                                                           target_size[1]),
                                          dtype="float32", chunks=None)
        lbl_dataset = file.create_dataset("labels", shape=(num_images,),
                                          dtype="int64", chunks=None)

        loader = DataLoader(data, batch_size=chunk_size, shuffle=False)

        current_index = 0
        # process in batches to avoid running out of RAM
        for images, labels in tqdm(loader):
            n = images.size(0)  # the last batch may be smaller

            # h5py expects NumPy arrays; the tensors are already float32 / int64
            img_dataset[current_index:current_index + n] = images.numpy()
            lbl_dataset[current_index:current_index + n] = labels.numpy()

            current_index += n
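Because the file stores float32 pixels, it is four times larger than the 8-bit data the images were decoded from. A quick size estimate, as a sketch: hdf5_images_nbytes is a hypothetical helper, and 80,000 is the approximate train-set size from the post.

```python
def hdf5_images_nbytes(num_images: int, channels: int = 3,
                       height: int = 128, width: int = 128,
                       bytes_per_value: int = 4) -> int:
    """Raw size of the 'images' array in bytes, ignoring HDF5 metadata overhead."""
    return num_images * channels * height * width * bytes_per_value


n = 80_000  # approximate train-set size from the post
print(f"float32: {hdf5_images_nbytes(n, bytes_per_value=4) / 1e9:.1f} GB")  # float32: 15.7 GB
print(f"uint8:   {hdf5_images_nbytes(n, bytes_per_value=1) / 1e9:.1f} GB")  # uint8:   3.9 GB
```

One possible variation (untested here) is to store uint8 instead, using transforms.PILToTensor() in place of ToTensor() and dtype="uint8", and to convert to float in __getitem__; that cuts both the file size and the bytes read per sample by 4×.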

Loading is done using the following dataset which I stole from this post:

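from collections.abc import Callable
from typing import Optional

import h5py
import torch
from torch.utils.data import Dataset


class Hdf5Dataset(Dataset):
    """ Custom dataset for loading images from an hdf5 file.
    This system lazy loads the images from the hdf5 file."""

    def __init__(self, hdf5_file: str, transform: Optional[Callable] = None) -> None:
        self.hdf5_file = hdf5_file
        self.transform = transform
        # the file is opened lazily in __getitem__, once per worker process
        self.dataset = None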

        with h5py.File(hdf5_file, "r") as file:
            self.length = len(file["images"])

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx) -> tuple[torch.Tensor, int]:
        """ Returns a tuple of image and label."""
        if self.dataset is None:
            self.dataset = h5py.File(self.hdf5_file, mode="r", swmr=True)

        image = self.dataset["images"][idx]
        label = self.dataset["labels"][idx]

        if self.transform:
            image = self.transform(image)

        return image, label
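The lazy `if self.dataset is None` open matters because, on Windows, DataLoader workers are started with the spawn method and each receives a pickled copy of the dataset; open file handles (h5py ones included, as far as I know) cannot be pickled. This pickling is probably also where most of the `_pickle.Pickler.dump` time in the cProfile output comes from: starting a spawned process writes two pickles to it, which would match the 16 calls for 8 workers. If the dataset is ever indexed in the main process before the workers start, `self.dataset` is no longer None and pickling fails. A generic sketch of the same pattern that also drops the handle when pickled; it reads fixed-size records from a plain binary file instead of HDF5, so the class and its layout are hypothetical:

```python
from typing import BinaryIO, Optional


class LazyRecordFile:
    """Hypothetical example: fixed-size records read through a lazily opened handle."""

    def __init__(self, path: str, record_size: int) -> None:
        self.path = path
        self.record_size = record_size
        self._fh: Optional[BinaryIO] = None  # opened on first access, per process

    def __getitem__(self, idx: int) -> bytes:
        if self._fh is None:
            self._fh = open(self.path, "rb")
        self._fh.seek(idx * self.record_size)
        return self._fh.read(self.record_size)

    def __getstate__(self) -> dict:
        # Drop the handle before pickling; each worker reopens the file itself.
        state = self.__dict__.copy()
        state["_fh"] = None
        return state
```

For Hdf5Dataset, the equivalent would be a `__getstate__` that sets `"dataset"` to None in the copied state.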

if __name__ == "__main__":
    # Benchmark the HDF5 dataset vs. the image dataset

    hdf5_file = r"datasets/birds/train.hdf5"

    batch_size = 32 * 4
    epochs = 4
    num_workers = 6
    prefetch_factor = 10
    persistent_workers = True

    ### loading dataset from hdf5 ###
    print("With hdf5")
    dataset = Hdf5Dataset(hdf5_file)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=num_workers,
                                             persistent_workers=persistent_workers,
                                             prefetch_factor=prefetch_factor,
                                             pin_memory=False)

    for _ in range(epochs):
        for images, labels in tqdm(dataloader):
            images = images.to("cuda")
            labels = labels.to("cuda")
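To compare the ImageFolder and HDF5 pipelines with the same yardstick as the samples/s figures quoted above, a small timing helper could look like this. A sketch: measure_throughput is a hypothetical name, and since worker start-up is included in the first pass even with persistent_workers=True, a second pass gives the steady-state number.

```python
import time
from typing import Iterable, Optional

import torch


def measure_throughput(loader: Iterable, device: str = "cpu",
                       max_batches: Optional[int] = None) -> float:
    """Hypothetical helper: samples per second for one pass over `loader`.

    Assumes every batch is an (images, labels) pair, as in the loops above.
    """
    n_samples = 0
    start = time.perf_counter()
    for i, (images, labels) in enumerate(loader):
        if max_batches is not None and i >= max_batches:
            break
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        n_samples += images.size(0)
    if device.startswith("cuda"):
        # copies may still be pending; wait so they are included in the timing
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return n_samples / elapsed if elapsed > 0 else float("inf")
```

For example, `print(measure_throughput(dataloader, device="cuda"))` after one warm-up epoch.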