Hi everyone,
I have data of size N that is split into M chunks (N >> M). The data is too big to fit into RAM entirely. Since I don't have random access to the data, I was looking for an implementation of a chunk Dataset that inherits from IterableDataset and supports multiple workers. I didn't find anything, so I tried to implement it myself:
import numpy as np
import pandas as pd
import torch

class ChunkDatasetIterator:
    def __init__(self, file_paths):
        # First level of shuffling: visit the chunk files in random order.
        self.path_permutation = np.random.permutation(file_paths)
        self.current_df_index = -1

    def __iter__(self):
        return self

    def __next__(self):
        # Lazily open the first chunk on the first call.
        if self.current_df_index == -1:
            if len(self.path_permutation) == 0:
                raise StopIteration
            self._load_next_chunk()
        while True:
            try:
                # iterrows() yields (index, row) pairs; we only want the row.
                return next(self.current_iterator)[1]
            except StopIteration:
                # Current chunk is exhausted; move on to the next one, if any.
                if self.current_df_index == len(self.path_permutation) - 1:
                    raise
                self._load_next_chunk()

    def _load_next_chunk(self):
        self.current_df_index += 1
        # Second level of shuffling: sample(frac=1) shuffles the rows
        # within the chunk before iterating over them.
        self.current_iterator = (
            pd.read_parquet(self.path_permutation[self.current_df_index])
            .sample(frac=1)
            .iterrows()
        )
class ChunkDataset(torch.utils.data.IterableDataset):
    def __init__(self, file_paths):
        super().__init__()
        self.file_paths = file_paths
        # Precompute the total number of rows across all chunks
        # (this could also be read from the parquet metadata instead
        # of loading each file).
        self.dataset_size = 0
        for file_path in file_paths:
            self.dataset_size += len(pd.read_parquet(file_path))

    def __len__(self):
        return self.dataset_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: one iterator over all chunks.
            return ChunkDatasetIterator(self.file_paths)
        # Multi-process loading: give each worker a disjoint round-robin
        # subset of the chunk files so no row is yielded twice.
        return ChunkDatasetIterator(
            [path for ind, path in enumerate(self.file_paths)
             if ind % worker_info.num_workers == worker_info.id])
This ChunkDataset creates a ChunkDatasetIterator for each worker and splits the chunk files across workers. Each worker then shuffles the order of its chunks and shuffles the rows within each chunk (two levels of shuffling, which is the best I could come up with without random access to the whole dataset).
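To illustrate the round-robin split in `__iter__` on its own: each worker keeps every num_workers-th file, so the subsets are disjoint and together cover all files. A quick standalone sketch (the chunk_*.parquet names are made up for the example):

```python
# Hypothetical chunk file names, just to demonstrate the split.
file_paths = [f"chunk_{i}.parquet" for i in range(10)]
num_workers = 3

# Reproduce the expression from ChunkDataset.__iter__ for each worker id.
splits = [
    [path for ind, path in enumerate(file_paths) if ind % num_workers == worker_id]
    for worker_id in range(num_workers)
]

# Every file lands in exactly one worker's subset, so no row is read twice.
assert sorted(p for split in splits for p in split) == sorted(file_paths)
print(splits[0])  # → ['chunk_0.parquet', 'chunk_3.parquet', 'chunk_6.parquet', 'chunk_9.parquet']
```

Note that the split is by file, not by row, so if the chunks have very different sizes the workers will yield unequal numbers of samples.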
This code works very well for my use case. Is this generally a good way of handling chunked data with multiple workers? Is there a better way? And if this ChunkDataset is a good idea, should I try to open a pull request against the PyTorch project for it?