Can DataLoader use the same child processes for every epoch?

I use PyTorch with Python 2 for some deep learning tasks.

In my Dataset class, I use Numba's cuda.jit to preprocess 3D point cloud data. When I set num_workers=0 in the DataLoader, everything works, but it is slow.
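For illustration, the Dataset is roughly along these lines (a simplified sketch with placeholder kernel and class names, not the real preprocessing):

import numpy as np
import torch.utils.data
from numba import cuda


@cuda.jit
def scale_points(points, factor):
    # Placeholder kernel: scale every point coordinate in place.
    i = cuda.grid(1)
    if i < points.shape[0]:
        for j in range(points.shape[1]):
            points[i, j] *= factor


class PointCloudDataset(torch.utils.data.Dataset):
    def __init__(self, clouds):
        self.clouds = clouds  # list of (N, 3) float32 arrays

    def __len__(self):
        return len(self.clouds)

    def __getitem__(self, idx):
        points = np.ascontiguousarray(self.clouds[idx], dtype=np.float32)
        d_points = cuda.to_device(points)  # this call initializes CUDA in the calling process
        threads = 128
        blocks = (points.shape[0] + threads - 1) // threads
        scale_points[blocks, threads](d_points, 2.0)
        return d_points.copy_to_host()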

When I set num_workers > 0, I get a CudaSupportError. This is because the workers are created by forking: if CUDA has already been initialized in the main process, the forked subprocesses try to initialize it again, which triggers the error.

According to https://pytorch.org/docs/stable/notes/multiprocessing.html#multiprocessing-cuda-note, the best solution is to set the start method to "spawn". However, my whole project is written in Python 2, which does not support multiprocessing.set_start_method('spawn').
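For reference, on Python 3 the recommended fix would look roughly like this (a sketch, not something I can use here):

import torch.multiprocessing as mp

if __name__ == '__main__':
    # Python 3 only: start DataLoader workers with "spawn" instead of "fork",
    # so they do not inherit the parent's already-initialized CUDA state.
    mp.set_start_method('spawn', force=True)
    # then build the Dataset / DataLoader with num_workers > 0 and train as usual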

I then tried to initialize the CUDA-related tensors only after the DataLoader has created its subprocesses, i.e., in the first batch. This worked for the first epoch, but the same error occurred in the second epoch. I found that the DataLoader creates new subprocesses for every epoch and destroys them when the loop exits.

I wonder if I can reuse the same subprocesses in the DataLoader for every epoch. Then I could initialize CUDA once, when those subprocesses are first created, and avoid the CUDA re-initialization errors.

Hi,

There is already an open issue tracking this: https://github.com/pytorch/pytorch/issues/15849


Thanks! I have implemented this using the method from the link you provided, as below:

import torch
import torch.utils.data


class _RepeatSampler(object):
    """Sampler that repeats forever, so the worker processes are never shut down."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            for item in self.sampler:
                yield item


class DataLoader(torch.utils.data.dataloader.DataLoader):
    """DataLoader that reuses the same worker processes across epochs."""

    def __init__(self, *args, **kwargs):
        super(DataLoader, self).__init__(*args, **kwargs)
        # Wrap the batch sampler so the underlying iterator never exhausts,
        # then create that iterator (and its workers) once, up front.
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self.iterator = super(DataLoader, self).__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Yield exactly one epoch's worth of batches from the persistent iterator.
        for i in range(len(self)):
            yield next(self.iterator)
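
It is used like a normal DataLoader; the worker processes are created once and reused for every epoch. A hypothetical usage sketch (the dataset, clouds, num_epochs and train_step names are placeholders):

# The same worker processes survive across epochs, so CUDA initialized
# in them during the first epoch stays valid in later epochs.
dataset = PointCloudDataset(clouds)
loader = DataLoader(dataset, batch_size=4, num_workers=4, shuffle=True)

for epoch in range(num_epochs):
    for batch in loader:
        train_step(batch)  # hypothetical training step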