How to have a fluent dataloader?

I found that the DataLoader needs to be reinitialized between two epochs, which costs several seconds. I want a fluent DataLoader that can yield data batches smoothly all the time.

In my case, the training dataset consists of 50 3D CT images. Because loading 3D CT images from disk is time-consuming, the first batch produced by the DataLoader normally takes about 10 seconds. And because the DataLoader is reinitialized before the start of the next epoch, the GPU always has to wait for data at the beginning of each epoch.

I know that setting num_workers can relieve this problem to some extent, but it cannot solve it. I also know that prefetching (the prefetch_factor argument) can relieve it to some extent, but it cannot solve it either. Both methods only keep the data flow smooth within one epoch; what I want is to smooth the data flow between two epochs.
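For reference, this is roughly how those two settings are passed (train_dataset here is a placeholder for the real dataset, and prefetch_factor only takes effect when num_workers > 0):

from torch.utils.data import DataLoader

# Each of the 4 workers pre-fetches 2 batches in the background, but
# the workers are shut down once the epoch's iterator is exhausted.
loader = DataLoader(
    train_dataset,      # placeholder for the actual Dataset
    batch_size=2,
    num_workers=4,
    prefetch_factor=2,  # batches pre-fetched per worker
)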

My idea is to have 2 queues and 2 DataLoaders, and to create 2 processes for the 2 DataLoaders. Each DataLoader generates data and puts it into its own queue in a separate process. In the first epoch, we get the data from queue_1. After the data in queue_1 runs out, we get the data from queue_2, which has already been filled by dataloader_2 in the background. And so on.

But it seems that DataLoader does not support this. My code is:

import multiprocessing
from multiprocessing import Queue


def fluent_generator(dataloader, dataloader2, keys=("image", "label")):
    data_q = Queue(maxsize=10)
    data_q2 = Queue(maxsize=10)

    def start_dataloader(dl, q):
        # Producer: iterate the DataLoader forever and push
        # (image, label) pairs into this producer's own queue.
        while True:
            for data in dl:
                x_pps = data[keys[0]]
                y_pps = data[keys[1]]
                q.put((x_pps, y_pps), timeout=100)

    # Pass each producer its own queue instead of an integer flag.
    p1 = multiprocessing.Process(target=start_dataloader, args=(dataloader, data_q))
    p2 = multiprocessing.Process(target=start_dataloader, args=(dataloader2, data_q2))
    p1.start()
    p2.start()

    use_q2 = False  # which queue the consumer currently drains
    while True:
        if data_q.empty() and data_q2.empty():
            continue  # busy-wait until either producer has data
        # Switch queues when the current one runs dry and the other
        # has data, so the two loaders alternate epoch by epoch.
        if use_q2 and data_q2.empty() and not data_q.empty():
            use_q2 = False
        elif not use_q2 and data_q.empty() and not data_q2.empty():
            use_q2 = True
        q = data_q2 if use_q2 else data_q
        yield q.get(timeout=100)

The error is:

  File "/home/jjia/data/monai/train_mtnet.py", line 305, in start_dataloader
    for data in dl:
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 352, in __iter__
    return self._get_iterator()
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 294, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 814, in __init__
    torch.cuda.current_device(),
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/cuda/__init__.py", line 366, in current_device
    _lazy_init()
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/cuda/__init__.py", line 164, in _lazy_init
    "Cannot re-initialize CUDA in forked subprocess. " + msg)
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x2aab2c21db90>
Traceback (most recent call last):
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1206, in __del__
    self._shutdown_workers()
  File "/home/jjia/.conda/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'
Traceback (most recent call last):
  File "/data/jjia/softwares/pycharm-community-2020.2.1/plugins/python-ce/helpers/pydev/pydevd.py", line 2132, in main
    globals = debugger.run(setup['file'], None, None, is_module)
KeyboardInterrupt

Process finished with exit code 1
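The last line of the RuntimeError itself points at the 'spawn' start method. A rough sketch of that change (untested; note that with 'spawn' the target function must be picklable, so start_dataloader would have to be a module-level function rather than a closure):

import multiprocessing

# A 'spawn' context starts the children with a fresh interpreter,
# so CUDA can be initialized inside them (as the error suggests).
ctx = multiprocessing.get_context("spawn")
data_q = ctx.Queue(maxsize=10)
data_q2 = ctx.Queue(maxsize=10)

p1 = ctx.Process(target=start_dataloader, args=(dataloader, data_q))
p2 = ctx.Process(target=start_dataloader, args=(dataloader2, data_q2))
p1.start()
p2.start()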

Can anyone give me more ideas about how to build a DataLoader whose output is as fluent between epochs as it is within one epoch?

You could use the new persistent_workers argument in the DataLoader construction, which will not shut down the workers between epochs.
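A minimal sketch of that (persistent_workers was added in PyTorch 1.7 and requires num_workers > 0; train_dataset is a placeholder):

from torch.utils.data import DataLoader

# With persistent_workers=True the worker processes survive across
# epochs, so the next epoch's first batches are already being
# prepared while the current epoch finishes.
loader = DataLoader(
    train_dataset,            # placeholder for the actual Dataset
    batch_size=2,
    num_workers=4,
    persistent_workers=True,  # keep workers alive between epochs
)

for epoch in range(100):
    for batch in loader:
        ...  # training step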