Efficient dataloader for a sharded dataset

Hi,

I have a bit of an issue coming up with a good design for efficiently loading a sharded dataset. I’m struggling to map the way the data is laid out on disk onto the PyTorch Dataset/DataLoader abstractions in a way that minimises expensive I/O operations (e.g., file open/close) wherever possible.

Please correct me and let me know if anything is unclear. English is not my first language, and I have a hard time organising my thoughts when writing them down.

Context

I am working with the EarthView dataset, specifically Sentinel-2 data. The dataset consists of a series of parquet files, each containing around 50 rows; every row holds a time series of satellite images and corresponds to a single training sample.

This data will be preprocessed before training by resampling all images in a single time series to 384x384 and turning them into a 4D numpy array of shape (T, C, H, W) = (10, 12, 384, 384). These arrays are then stacked into a 5D array of shape (num_rows, T, C, H, W) and saved to an HDF5 file. Each HDF5 file therefore contains the same data as the corresponding parquet file, but with its contents resampled. This preprocessing step happens once before training, so it does not need to be repeated.
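For concreteness, the one-off preprocessing step could be sketched like this. `load_resampled_rows` is a hypothetical placeholder for the EarthView-specific parquet parsing and resampling, and `"samples"` is an assumed name for the HDF5 dataset:

```python
import numpy as np
import h5py

T, C, H, W = 10, 12, 384, 384  # per-sample shape from the post

def load_resampled_rows(parquet_path, sample_shape=(T, C, H, W), num_rows=2):
    """Hypothetical stand-in for the dataset-specific parquet parsing and
    resampling; yields one float32 array of sample_shape per row."""
    for _ in range(num_rows):
        yield np.zeros(sample_shape, dtype=np.float32)

def preprocess_shard(parquet_path, hdf5_path, sample_shape=(T, C, H, W)):
    """One-off step: stack every row of one parquet shard into a 5D array
    of shape (num_rows, T, C, H, W) and write it to a mirror HDF5 file."""
    rows = np.stack(list(load_resampled_rows(parquet_path, sample_shape)))
    with h5py.File(hdf5_path, "w") as f:
        # One chunk per sample, so reading one sample later touches
        # exactly one contiguous chunk on disk.
        f.create_dataset("samples", data=rows, chunks=(1,) + sample_shape)
    return rows.shape
```

Chunking by single sample is worth considering here, since the training loader will read one row at a time.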

Proposed approach

My current idea is to treat each HDF5 file as a shard and introduce randomness at the file level:

  1. Maintain a list of HDF5 files on disk.
  2. At the start of an epoch, randomly permute this list to determine in which order files are accessed.
  3. Have the DataLoader workers iterate over this permuted list and load samples from each file, ideally also permuting the order in which individual samples are accessed within a file.
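The three steps above could be sketched roughly as follows, assuming the HDF5 dataset inside each shard is called `"samples"`. The class only implements `__len__`/`__getitem__` (the map-style Dataset protocol), and it opens files lazily and keeps at most one shard open at a time, so consecutive indices from the same shard reuse the open handle; with `num_workers > 0`, each worker holds its own handle, and opening HDF5 files only after worker start is the usual way to avoid h5py/multiprocessing issues:

```python
import numpy as np
import h5py

class ShardedHDF5Dataset:
    """Map-style dataset over many HDF5 shards. Duck-types the interface
    torch.utils.data.DataLoader expects from a map-style Dataset."""

    def __init__(self, hdf5_paths, dataset_name="samples"):
        self.paths = list(hdf5_paths)
        self.name = dataset_name
        # Count rows per shard once up front (could be cached to JSON).
        self.lengths = []
        for p in self.paths:
            with h5py.File(p, "r") as f:
                self.lengths.append(int(f[self.name].shape[0]))
        self.offsets = np.cumsum([0] + self.lengths)
        self._open_path = None
        self._open_file = None

    def __len__(self):
        return int(self.offsets[-1])

    def __getitem__(self, idx):
        # Map a global index to (shard, row-within-shard).
        shard = int(np.searchsorted(self.offsets, idx, side="right")) - 1
        row = idx - self.offsets[shard]
        path = self.paths[shard]
        if path != self._open_path:  # reopen only when the shard changes
            if self._open_file is not None:
                self._open_file.close()
            self._open_file = h5py.File(path, "r")
            self._open_path = path
        return self._open_file[self.name][row]  # one (T, C, H, W) array

def shard_shuffled_indices(lengths, rng):
    """Two-level shuffle: permute shard order, then permute row order
    within each shard, keeping each shard's rows contiguous."""
    offsets = np.cumsum([0] + list(lengths))
    order = rng.permutation(len(lengths))
    return np.concatenate(
        [offsets[s] + rng.permutation(lengths[s]) for s in order]
    )
```

A per-epoch index list from `shard_shuffled_indices` could then be fed to the DataLoader via a custom sampler, so each file is opened once per epoch while samples still arrive in randomised order.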

What I am unsure about is whether the Dataset class should return a single sample from __getitem__, or whether it would make sense to return the whole collection of samples contained in an HDF5 file.

Conclusion

To conclude, my specific questions are:

  1. Would it be better to use a Dataset or IterableDataset? My data is finite and does not need to be streamed, so I assume a Dataset is more appropriate.
  2. Should the data be shuffled at file level, within each file, or using a combination of both to balance I/O efficiency and training randomness?
  3. Should __getitem__ return a single sample, or the collection of samples corresponding to a full HDF5 shard?

Let me know if I’m missing important aspects of your use case, but my current recommendations would be:

  1. Dataset sounds like the right approach given you know the number of samples and want to shuffle them.
  2. I would vote for the most “randomness” you can achieve while not entirely slowing down your data loading pipeline.
  3. Same as above: whatever makes it easier to maintain the randomness and shuffling.

Thanks for your answer!

Since each shard contains a slightly different number of samples, I do not know exactly how many samples I have in total. But I could always do one pass over the data the first time around, get the number of samples from the parquet or HDF5 files, and store that in a JSON file, so I do not have to repeat this every time I run the code.
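That counting-and-caching pass could look something like this minimal sketch (the cache path, the `"samples"` dataset name, and the function name are all illustrative):

```python
import json
import os
import h5py

def shard_lengths(hdf5_paths, cache_path="shard_lengths.json"):
    """Count samples per shard once and cache the result as JSON,
    so later runs skip opening every file just to read its length."""
    if os.path.exists(cache_path):
        with open(cache_path) as f:
            cached = json.load(f)
        # Only trust the cache if it covers exactly the current shards.
        if set(cached) == {str(p) for p in hdf5_paths}:
            return [cached[str(p)] for p in hdf5_paths]
    lengths = {}
    for p in hdf5_paths:
        with h5py.File(p, "r") as f:
            lengths[str(p)] = int(f["samples"].shape[0])
    with open(cache_path, "w") as f:
        json.dump(lengths, f)
    return [lengths[str(p)] for p in hdf5_paths]
```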