Best practice for implementing DataPipe loading with both multiprocessing and Python threading

Hi all,
I’m trying to develop a data loading and training pipeline using DataPipes. Our pipeline consists of three stages:

  1. A CPU-compute-intensive stage that, given a minibatch, computes the indices from which to fetch the necessary features. I think this is best done with multiprocessing, spawning multiple subprocesses to run this stage.
  2. An I/O-bound stage that fetches the features at the given indices and transfers them to the GPU. I would like to put this stage in a separate Python thread in the main process, so that it can issue I/O requests and initiate GPU transfers to maximize overlap (see the transfer sketch after this list).
  3. A GPU-compute-intensive stage that updates the model given the features on the GPU.
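
To make stage 2 concrete, here is roughly the transfer pattern I have in mind (a minimal sketch; the pinned-memory/side-stream setup is my own guess at how to get the overlap, not something DataPipes provide):

```python
import torch

device = torch.device("cuda")
copy_stream = torch.cuda.Stream()  # side stream dedicated to H2D copies


def transfer_to_gpu(batch_cpu: torch.Tensor) -> torch.Tensor:
    """Kick off an asynchronous host-to-device copy on a side stream."""
    pinned = batch_cpu.pin_memory()  # page-locked memory enables truly async copies
    with torch.cuda.stream(copy_stream):
        # non_blocking=True lets the copy overlap with the next fetch and
        # with stage-3 compute running on the default stream.
        return pinned.to(device, non_blocking=True)


# Stage 3 must wait for the copy before touching the tensor, e.g.:
#   torch.cuda.current_stream().wait_stream(copy_stream)
```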

The question is: what is the best practice for implementing such a pipeline? So far I can think of two options:

  • Wrap the first stage as an IterDataPipe whose __iter__ yields elements from a PyTorch multiprocessing DataLoader, then wrap the second stage as an IterDataPipe that yields elements from a thread running in the background (see the sketch after this list). I’m not sure whether letting a DataPipe object spawn processes/threads is a good idea.
  • Write my own ReadingServiceInterface, but it appears to be available only with DataLoader2, while I would like to use the pipeline together with third-party tools like PyTorch Lightning (which only supports DataLoader).
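
To make option 1 concrete, here is a rough sketch of what I have in mind. Everything here is placeholder-level: fetch_fn stands in for our feature-fetching code, and I have glossed over error handling and worker-side sharding (the source pipe would presumably need a sharding_filter so workers don’t duplicate data):

```python
import queue
import threading

from torch.utils.data import DataLoader, IterDataPipe


class MultiprocessStage(IterDataPipe):
    """Stage 1: run the index computation in DataLoader worker processes."""

    def __init__(self, source_dp, num_workers=4):
        self.source_dp = source_dp
        self.num_workers = num_workers

    def __iter__(self):
        # batch_size=None yields items as-is, without an extra batch dimension.
        loader = DataLoader(self.source_dp, batch_size=None,
                            num_workers=self.num_workers)
        yield from loader


class ThreadedFetchStage(IterDataPipe):
    """Stage 2: fetch features and start GPU transfers on a background thread."""

    _DONE = object()  # sentinel marking the end of the stream

    def __init__(self, source_dp, fetch_fn, max_prefetch=4):
        self.source_dp = source_dp
        self.fetch_fn = fetch_fn  # placeholder: indices -> features on GPU
        self.max_prefetch = max_prefetch

    def __iter__(self):
        q = queue.Queue(maxsize=self.max_prefetch)

        def worker():
            for indices in self.source_dp:
                q.put(self.fetch_fn(indices))
            q.put(self._DONE)

        t = threading.Thread(target=worker, daemon=True)
        t.start()
        while (item := q.get()) is not self._DONE:
            yield item
        t.join()
```

The bounded queue is there for backpressure, so the background thread cannot prefetch arbitrarily far ahead of the GPU stage. My worry, as mentioned above, is that every call to __iter__ spins up fresh worker processes and a thread, and I don’t know how well that plays with the DataPipe graph.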

Happy to discuss further.