An IterableDataset implementation for chunked data

Hi everyone,
I have data of size N that is split into M chunks (N >> M). The data is too big to fit into RAM entirely. Since we don’t have random access to the data, I was looking for an implementation of a chunked Dataset that inherits from IterableDataset and supports multiple workers. I didn’t find anything, so I tried to implement it myself:

```python
import numpy as np
import pandas as pd
import torch


class ChunkDatasetIterator:
    """Yields rows of parquet chunks: chunk order is shuffled, then rows within each chunk."""

    def __init__(self, file_paths):
        # Level 1: shuffle the order in which chunks are visited.
        self.path_permutation = np.random.permutation(file_paths)
        self.current_df_index = -1
        self.current_iterator = iter(())  # empty until the first chunk is loaded

    def __iter__(self):
        return self

    def __next__(self):
        # Loop so that empty chunks are skipped instead of ending iteration early.
        while True:
            try:
                # iterrows() yields (index, row) pairs; return only the row.
                return next(self.current_iterator)[1]
            except StopIteration:
                if self.current_df_index == len(self.path_permutation) - 1:
                    raise  # no chunks left
                self.current_df_index += 1
                # Level 2: load the next chunk into RAM and shuffle its rows.
                df = pd.read_parquet(self.path_permutation[self.current_df_index])
                self.current_iterator = df.sample(frac=1).iterrows()


class ChunkDataset(torch.utils.data.IterableDataset):
    def __init__(self, file_paths):
        super().__init__()  # was super(ChunkDataset).__init__(), which skips the parent init
        self.file_paths = list(file_paths)
        # Counts rows by loading every chunk once; reading only the parquet
        # metadata (e.g. with pyarrow) would be cheaper.
        self.dataset_size = sum(len(pd.read_parquet(p)) for p in self.file_paths)

    def __len__(self):
        return self.dataset_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this iterator sees every chunk.
            return ChunkDatasetIterator(self.file_paths)
        # Multi-process loading: worker k takes chunks k, k + num_workers, ...
        return ChunkDatasetIterator(self.file_paths[worker_info.id::worker_info.num_workers])
```

This ChunkDataset creates a ChunkDatasetIterator for each worker and splits the chunks across workers. Each worker then shuffles the order of its chunks and shuffles the rows within each chunk. This two-level shuffling is the best I came up with without random access to the whole dataset.
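To make the sharding and the two-level shuffle concrete, here is a minimal stdlib-only sketch of the same logic, assuming small in-memory lists stand in for the parquet files. `iterate_shard` is a hypothetical helper, not part of PyTorch; it uses an explicitly seeded `random.Random` per worker instead of the global NumPy state.

```python
import random
from typing import Iterator, List, Sequence


def iterate_shard(chunks: Sequence[List[int]], worker_id: int, num_workers: int,
                  rng: random.Random) -> Iterator[int]:
    """Yield the rows one worker would produce (hypothetical helper).

    Each inner list of `chunks` stands in for one parquet file that fits in RAM.
    """
    # Round-robin sharding: worker k owns chunks k, k + num_workers, ...
    my_chunks = list(chunks[worker_id::num_workers])
    # Level 1: shuffle the order in which this worker visits its chunks.
    rng.shuffle(my_chunks)
    for chunk in my_chunks:
        # Level 2: shuffle the rows of the chunk currently held in memory.
        rows = list(chunk)
        rng.shuffle(rows)
        yield from rows


# Ten rows split into four chunks ([0,1,2], [3,4,5], [6,7,8], [9]), read by three workers.
chunks = [list(range(i, min(i + 3, 10))) for i in range(0, 10, 3)]
per_worker = [list(iterate_shard(chunks, w, 3, random.Random(w))) for w in range(3)]
print(per_worker)
```

Every row is produced exactly once per pass across all workers, but rows from two different chunks are never interleaved within one worker's stream, so the randomness is bounded by the chunk size. Older PyTorch releases did not reseed NumPy in each worker, so `np.random.permutation` could return the same chunk order in every worker; a per-worker seeded generator avoids depending on that behavior.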

This code works very well for my use case. Is this generally a good way of handling chunked data with multiple workers? Is there a better way? If this ChunkDataset is a good idea, should I try to make a pull request to the PyTorch project for it?

Thanks for sharing this implementation!
I think you could start with a feature request on GitHub and explain your use case as well as your implementation. Currently you are using third-party libraries such as pandas; I assume that dependency could be removed to allow for a more general use case.
Once the feature request is filed, the code owners will take a look at it.
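One way to drop the pandas dependency is to take the chunk loader as a parameter. This is a minimal sketch, assuming a user-supplied `load_chunk` callable that returns a list of samples for one chunk id; `GenericChunkDataset` and `load_chunk` are hypothetical names, not an existing PyTorch API.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class GenericChunkDataset(IterableDataset):
    """Hypothetical chunked IterableDataset without a pandas dependency.

    load_chunk: user-supplied callable mapping one chunk id (for example a
    file path) to a list of samples small enough to hold in RAM.
    """

    def __init__(self, chunk_ids: Sequence, load_chunk: Callable[[object], List],
                 seed: Optional[int] = None):
        super().__init__()
        self.chunk_ids = list(chunk_ids)
        self.load_chunk = load_chunk
        self.seed = seed  # used only for single-process loading

    def __iter__(self) -> Iterator:
        info = get_worker_info()
        if info is None:
            # Loading in the main process: one "worker" that owns every chunk.
            worker_id, num_workers, seed = 0, 1, self.seed
        else:
            # info.seed is derived from the DataLoader's base seed and differs per worker.
            worker_id, num_workers, seed = info.id, info.num_workers, info.seed
        rng = random.Random(seed)
        # Round-robin sharding; slicing copies, so shuffling leaves self.chunk_ids intact.
        my_chunks = self.chunk_ids[worker_id::num_workers]
        rng.shuffle(my_chunks)  # level 1: chunk order
        for chunk_id in my_chunks:
            samples = list(self.load_chunk(chunk_id))
            rng.shuffle(samples)  # level 2: samples within the loaded chunk
            yield from samples


# Tiny in-memory "chunks" stand in for files.
chunks = {"a": [1, 2, 3], "b": [4, 5], "c": [6]}
dataset = GenericChunkDataset(list(chunks), chunks.__getitem__, seed=0)
loader = DataLoader(dataset, batch_size=None, num_workers=0)
out = [int(x) for x in loader]
print(out)
```

For the parquet case, `load_chunk` could be something like a module-level function returning `pd.read_parquet(path).to_dict("records")`. A module-level function rather than a lambda is safer with `num_workers > 0`, because the dataset may need to be pickled when workers are started with the spawn method.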


Thank you so much, @ptrblck! I’ll submit a feature request for this soon.