I created an IterDataPipe over .parquet files.
My training data consists of a few Parquet files with more than 100k rows each.
Since I can't read all the data into memory at once, I iterate over the files and load each one in chunks.
The datapipe is used with a Lightning DataModule. Here is the code I wrote, based on the documentation:
```python
from typing import List

import pyarrow.parquet as parquet
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("load_parquet_batch")
class ParquetLoaderIterDataPipe(IterDataPipe):
    def __init__(self, source_dp: IterDataPipe, columns: List[str], batch_size: int):
        self.source_dp = source_dp
        self.columns = columns
        self.batch_size = batch_size
        self.n_samples = self._get_num_samples()

    def _get_num_samples(self) -> int:
        # Read the row count from each file's Parquet metadata (no data is loaded).
        return sum(parquet.ParquetFile(path).metadata.num_rows for path in self.source_dp)

    def __iter__(self):
        for path in self.source_dp:
            parquet_file = parquet.ParquetFile(path)
            for batch in parquet_file.iter_batches(batch_size=self.batch_size, columns=self.columns):
                for _, sample in batch.to_pandas().iterrows():
                    yield sample

    def __len__(self):
        if self.n_samples == 0:
            # This error is caught and training continues without a progress bar.
            raise TypeError
        return self.n_samples


datapipe = dp.iter.FileLister("/path/to/parquet/files/", masks="*.parquet")
datapipe = datapipe.shuffle()  # shuffle the files
datapipe = datapipe.load_parquet_batch(COLS_LIST, batch_size=BATCH_SIZE)
datapipe = datapipe.shuffle(buffer_size=BUFFER_SIZE)  # shuffle the samples
datapipe = datapipe.sharding_filter()
```
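For context, the pipe is consumed roughly like this (a simplified sketch, not my actual DataModule; `NUM_WORKERS` is a placeholder, and the real code converts each sample to tensors before collation):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class ParquetDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        # An IterDataPipe is a valid IterableDataset, so it can be passed
        # directly to a DataLoader.
        return DataLoader(datapipe, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
```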
I still have a few problems when trying to use multiprocessing (`num_workers` > 1):
`sharding_filter` means that if there are N workers, each of them first loads a copy of the entire dataset, then drops most of it and keeps only a 1/N share. This process is pretty inefficient in both time and memory.
Is there a way to avoid that overhead? Maybe by allocating each file to a different worker?
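For example, would something like this be a valid way to shard at the file level, so that each worker only opens the files assigned to it? (An untested sketch, just to illustrate what I mean.)

```python
# Untested sketch: shard the file *paths* across workers instead of the samples,
# so each worker only reads its own 1/N of the files.
datapipe = dp.iter.FileLister("/path/to/parquet/files/", masks="*.parquet")
datapipe = datapipe.shuffle()          # shuffle the files
datapipe = datapipe.sharding_filter()  # each worker keeps 1/N of the paths
datapipe = datapipe.load_parquet_batch(COLS_LIST, batch_size=BATCH_SIZE)
datapipe = datapipe.shuffle(buffer_size=BUFFER_SIZE)  # shuffle samples within a worker
```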
I’m getting the following warning:
```
UserWarning: Your IterableDataset has __len__ defined. In combination with multi-process data loading (when num_workers > 1), __len__ could be inaccurate if each worker is not configured independently to avoid having duplicate data.
```
If `__len__` is inaccurate, training crashes because of my early stopping, which expects the validation metrics to exist: the Lightning trainer doesn't start validation when it should.
Is this related to the first issue? And how can I compute the exact effective length (excluding the dropped incomplete batches)?
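The Parquet metadata already gives me exact per-file row counts, so I could compute it by hand. A hypothetical helper (assuming the file paths end up sharded round-robin across workers, which may not match what `sharding_filter` actually does) would look like:

```python
import pyarrow.parquet as parquet

# Hypothetical helper, assuming file paths are assigned round-robin across
# workers (not guaranteed to match sharding_filter's actual assignment).
def samples_per_worker(paths, num_workers):
    counts = [0] * num_workers
    for i, path in enumerate(sorted(paths)):
        counts[i % num_workers] += parquet.ParquetFile(path).metadata.num_rows
    return counts

def total_full_batches(paths, num_workers, batch_size):
    # Each worker drops its own incomplete final batch (drop_last=True),
    # so the per-worker counts must be floored individually before summing.
    return sum(n // batch_size for n in samples_per_worker(paths, num_workers))
```

Is this the right way to compute `__len__`, or is there a built-in mechanism for it?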
Thanks in advance.