Multiprocessing with a Parquet IterDataPipe

Hey,

I created an IterDataPipe that reads .parquet files.
My training data consists of a few Parquet files with more than 100k rows each.
Since I can’t read all the data at once, I iterate over each file and load it in chunks.

The datapipe is used with a Lightning DataModule. Here is the code I wrote, based on the documentation:

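from typing import List

import pyarrow.parquet as parquet
import torchdata.datapipes as dp
from torch.utils.data import IterDataPipe, functional_datapipe


@functional_datapipe("load_parquet_batch")
class ParquetLoaderIterDataPipe(IterDataPipe):
    def __init__(self, source_dp: IterDataPipe, columns: List[str], batch_size: int):
        self.source_dp = source_dp
        self.columns = columns
        self.batch_size = batch_size
        self.n_samples = self._get_num_samples()

    def _get_num_samples(self) -> int:
        # Count rows from each file's Parquet footer metadata (no row data is read)
        return sum(parquet.ParquetFile(path).metadata.num_rows for path in self.source_dp)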

    def __iter__(self):
        for path in self.source_dp:
            parquet_file = parquet.ParquetFile(path)
            # Stream the file in record batches of `batch_size` rows instead of loading it whole
            for batch in parquet_file.iter_batches(batch_size=self.batch_size, columns=self.columns):
                # Yield one row (a pandas Series) at a time
                for _, sample in batch.to_pandas().iterrows():
                    yield sample

    def __len__(self):
        if self.n_samples == 0:
            raise TypeError  # This error will be caught and the training will continue without a progress bar
        return self.n_samples


datapipe = dp.iter.FileLister("/path/to/parquet/files/", masks="*.parquet")
datapipe = datapipe.shuffle()  # shuffling the files
datapipe = datapipe.load_parquet_batch(COLS_LIST, batch_size=BATCH_SIZE)
datapipe = datapipe.shuffle(buffer_size=BUFFER_SIZE)  # shuffling the samples
datapipe = datapipe.sharding_filter()
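In the pipeline above, sharding_filter() is the last step, so every worker runs the whole upstream pipeline (listing, opening, and parsing every file) before the filter discards most of what it read. The pure-Python sketch below is a simplified model of that behavior, assuming the round-robin rule the filter applies (keep element i when i % num_workers == worker_id). The function name is hypothetical and this is not torchdata's actual implementation.

```python
from typing import Dict, List, Tuple


def simulate_sharding_filter(
    samples: List[int], num_workers: int
) -> Tuple[Dict[int, List[int]], Dict[int, int]]:
    """Hypothetical model of a round-robin sharding filter at the END of a pipeline.

    Each worker runs the full upstream pipeline, so it reads every sample,
    then keeps only the samples at positions i where i % num_workers == worker_id.
    """
    kept: Dict[int, List[int]] = {w: [] for w in range(num_workers)}
    reads: Dict[int, int] = {w: 0 for w in range(num_workers)}
    for worker_id in range(num_workers):
        for i, sample in enumerate(samples):
            reads[worker_id] += 1  # this worker had to load the sample
            if i % num_workers == worker_id:
                kept[worker_id].append(sample)
    return kept, reads


kept, reads = simulate_sharding_filter(list(range(10)), num_workers=2)
print(kept)   # {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
print(reads)  # {0: 10, 1: 10}
```

Each worker keeps half of the samples but still reads all ten, which matches the time and memory overhead described below.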

I still have a few problems when using multiprocessing (num_workers > 1):

  1. With sharding_filter, if there are N workers, each of them first loads a full copy of the data, then drops most of it and keeps only a 1/N portion. This is quite inefficient in both time and memory.
    Is there a way to avoid this overhead? Maybe by assigning each file to a different worker?

  2. I’m getting the following warning:
    UserWarning: Your IterableDataset has __len__ defined. In combination with multi-process data loading (when num_workers > 1), __len__ could be inaccurate if each worker is not configured independently to avoid having duplicate data.
    If __len__ is inaccurate, training crashes because my early-stopping callback expects validation metrics, but the Lightning trainer does not start the validation loop when it should.
    Is this related to the first issue? How can I compute the exact length (excluding the dropped incomplete batches)?
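One possible direction for both questions, sketched under assumptions not verified against this exact setup: if sharding_filter() were moved to sit directly after the file-level shuffle(), whole files would be split round-robin between workers, so each worker would only open its own files. All workers would then need to see the same file order (the same shuffle seed), otherwise files could be duplicated or skipped. Under those assumptions, each worker's exact sample count is the sum of its files' row counts, and with drop_last each worker drops its own last incomplete batch, since, as far as I know, every DataLoader worker batches its own samples for iterable-style datasets. The helper names and row counts below are hypothetical; in the real pipeline the counts could come from parquet.ParquetFile(path).metadata.num_rows, as in _get_num_samples.

```python
from typing import Dict, List


def rows_per_worker(file_rows: List[int], num_workers: int) -> Dict[int, int]:
    """Rows each worker yields if whole files are sharded round-robin.

    Assumes file i goes to worker i % num_workers and that every worker
    iterates the files in the same order.
    """
    totals = {w: 0 for w in range(num_workers)}
    for i, n_rows in enumerate(file_rows):
        totals[i % num_workers] += n_rows
    return totals


def batches_per_worker(
    file_rows: List[int], num_workers: int, batch_size: int, drop_last: bool = True
) -> Dict[int, int]:
    """Batches each worker produces when it batches only its own samples."""
    counts = {}
    for w, n in rows_per_worker(file_rows, num_workers).items():
        # Floor division drops the incomplete batch; ceiling division keeps it
        counts[w] = n // batch_size if drop_last else -(-n // batch_size)
    return counts


# Hypothetical row counts for three Parquet files
file_rows = [120_000, 110_000, 105_000]
print(rows_per_worker(file_rows, num_workers=2))                  # {0: 225000, 1: 110000}
per_worker = batches_per_worker(file_rows, num_workers=2, batch_size=64)
print(per_worker)                                                  # {0: 3515, 1: 1718}
print(sum(per_worker.values()))                                    # 5233
```

With only a few files, file-level round-robin can leave the workers unbalanced (225,000 vs. 110,000 rows here), so ideally the number of files would be a multiple of num_workers.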

Thanks in advance.
