DataLoader persistent_workers Usage

Hello, I’m trying to better understand the operation of the persistent_workers option for DataLoader.

My understanding is that the dataloader will not stop the worker processes that have been consuming the dataset after you stop consuming from it. To me this implies that it will save the state of the Dataloader instance and when you come back to consume more batches it will pick up where it left off. The behavior I see when trying to do this is not that. The workers are starting over from the beginning of the dataset instead of picking up where they left off.

So my question is: what is the intended use of persistent_workers? Is it to save the state of the dataset so you can come back and pick up where you left off, or is it to keep the workers in memory to reduce file access operations but start fresh each time? Looking at the torch code, it seems the intended behavior is to reset the dataset iterator when __iter__ is called with num_workers > 0 and persistent_workers = True. (https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py#L435) An interesting observation is that if I comment out self._iterator._reset(self), the DataLoader behaves as I thought it would and picks up where it left off.

My specific use case is that I have many TBs of NetCDF files that I want to stream into model training. I want to define a custom epoch as every 10 million samples, switch to the validation set for 2 million samples, and then start a new epoch with the existing DataLoader object, having it pick up where it left off in the previous epoch. I'd appreciate any insight into the intention of worker persistence and whether I'm going about this incorrectly. Thanks!

Here is a minimal example of how I’m using DataLoader and that it isn’t persisting in the way I consider persistence.

import os
from glob import glob
from itertools import islice

import numpy as np
import torch
import xarray as xr

# Build 10 NetCDF test files, where each file holds a length-10 array of one digit 0-9
# e.g. 1s.nc -> [1,1,1,1,1,1,1,1,1,1]
def build_input_files():
    for x in range(10):
        da = xr.DataArray(np.full(10, x), name="data")
        da.to_netcdf(f"{x}s.nc")

# An iterable dataset that streams values from a directory of NetCDF files
class MixedIter(torch.utils.data.IterableDataset):
    def __init__(self, directory: str):
        super().__init__()
        self.dir = directory
        self.fl = glob(os.path.join(self.dir, "*.nc"))

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:  # None when num_workers=0
            print(f"Worker ID {worker_info.id} started\n")
        for f in self.fl:
            with xr.open_dataarray(f) as da:
                values = torch.from_numpy(da.values)  # load before the file is closed
            for x in values:
                yield x

# worker_init_fn: split the .nc files round-robin so each worker reads a disjoint shard
def init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # this worker's own copy of the dataset
    num_workers = worker_info.num_workers
    fl = glob(os.path.join(dataset.dir, "*.nc"))
    shards = [list(islice(fl, x, len(fl), num_workers))
              for x in range(num_workers)]
    dataset.fl = shards[worker_id]

# Let's iterate through the whole dataset once to show it is working
build_input_files()
mixed_ds_1 = MixedIter('./')
mixed_dl_1 = torch.utils.data.DataLoader(mixed_ds_1, batch_size=10, num_workers=2, worker_init_fn=init_fn)
for x in mixed_dl_1:
    print(x)


Worker ID 0 started

Worker ID 1 started

tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

# Ok, that looks as expected...
# Now iterate through part of the dataset (6 of the 10 batches), break out of the loop, then iterate over it again.
# This shows it is restarting the dataset instead of picking up where it left off. I'd expect the second loop
# through the same object to start at the tensor full of 0s

mixed_ds_3 = MixedIter('./')
mixed_dl_3 = torch.utils.data.DataLoader(mixed_ds_3, batch_size=10, num_workers=2, worker_init_fn=init_fn, persistent_workers=True)
for i, x in enumerate(mixed_dl_3):
    print(x)
    if i == 5:
        break
print("\nsecond set")
for x in mixed_dl_3:
    print(x)

Worker ID 0 started
Worker ID 1 started


Worker ID 0 started
Worker ID 1 started


tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])

second set
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
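A side note on the sharding in init_fn: the same round-robin split can also be done inside __iter__ using get_worker_info(), which removes the need for worker_init_fn and also works with num_workers=0 (where get_worker_info() returns None). Below is a minimal sketch of that variant, not the poster's code; ShardedFiles, shard_for_worker and fake_load are hypothetical names, with fake_load standing in for opening a NetCDF file. The slice files[k::n] selects the same elements as islice(fl, k, len(fl), n).

```python
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


def shard_for_worker(items, worker_id, num_workers):
    # Round-robin split: worker k gets items k, k + n, k + 2n, ... (n = num_workers)
    return items[worker_id::num_workers]


def fake_load(path):
    # Hypothetical stand-in for opening one NetCDF file: 3 samples equal to the "path"
    return [path] * 3


class ShardedFiles(IterableDataset):
    def __init__(self, files, load):
        super().__init__()
        self.files = list(files)
        self.load = load  # callable that turns one file into an iterable of samples

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # num_workers=0: the main process reads every file
            files = self.files
        else:  # inside a worker: read only this worker's shard
            files = shard_for_worker(self.files, info.id, info.num_workers)
        for f in files:
            yield from self.load(f)


if __name__ == "__main__":
    loader = DataLoader(ShardedFiles(range(4), fake_load), batch_size=3, num_workers=2)
    for batch in loader:
        print(batch)
```

Because the split happens in __iter__, it is recomputed at the start of every pass, which is cheap and behaves the same whether or not the workers are persistent.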

Using persistent_workers=True should avoid deleting the workers once the DataLoader is exhausted. Instead, iterating multiple epochs will reuse the same workers and will continue to prefetch the next batch(es) without recreating the workers and thus re-initializing the Dataset.
Your code shows two separate loops; each one calls iter() on the DataLoader, which resets the persistent workers and calls into the Dataset's __iter__ method again, so I assume the output is expected. However, no new worker processes should be created.
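To make the distinction concrete, here is a minimal sketch (hypothetical CountingStream dataset and two_passes helper) in which each worker yields its process ID alongside its position in the stream. Following the behavior described above, with persistent_workers=True both passes should report the same worker PIDs yet restart at the same positions, while with persistent_workers=False the second pass should run in newly created processes. It uses num_workers=2, so run it as a script; the __main__ guard is there for platforms that start workers with spawn.

```python
import os

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class CountingStream(IterableDataset):
    """Toy stream: worker k yields (pid, position) for positions k*n .. k*n + n - 1."""

    def __init__(self, n_per_worker=4):
        super().__init__()
        self.n_per_worker = n_per_worker

    def __iter__(self):
        # __iter__ runs at the start of every pass, even when the worker process is reused
        info = get_worker_info()
        wid = 0 if info is None else info.id
        start = wid * self.n_per_worker
        for pos in range(start, start + self.n_per_worker):
            yield os.getpid(), pos


def two_passes(persistent, batches_per_pass=2):
    """Start two passes over the loader, leaving each one early like the question does."""
    loader = DataLoader(CountingStream(), batch_size=2, num_workers=2,
                        persistent_workers=persistent)
    passes = []
    for _ in range(2):
        seen = []
        for i, (pids, positions) in enumerate(loader):
            # each batch comes from a single worker, so pids holds one distinct value
            seen.append((set(pids.tolist()), positions.tolist()))
            if i + 1 == batches_per_pass:
                break
        passes.append(seen)
    return passes


if __name__ == "__main__":
    for persistent in (True, False):
        print(f"persistent_workers={persistent}:", two_passes(persistent))
```

In both modes the second pass starts from the beginning of the stream; persistence only changes whether the processes (and the per-worker state set up in worker_init_fn) are reused.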

Thanks for the quick response. After reading the source and discussing it with someone else, that's what we thought the intention of persistent_workers was, rather than how I initially thought/hoped it would work. The latter does seem possible with some small changes to torch.utils.data.DataLoader. For now, I can architect my training/validation loops to take that into account with a non-repeating dataset.
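For the use case in the question (custom epochs of N samples that resume where the previous one stopped), one option that needs no changes to DataLoader is to call iter() on the training loader once, keep that iterator, and pull a fixed number of batches per "epoch" with itertools.islice. Only a new iter(loader) call triggers the reset, so the kept iterator continues from its current position; with num_workers > 0 its workers also stay alive for as long as the iterator is referenced, whether or not persistent_workers is set. A minimal sketch, with a hypothetical Stream dataset standing in for the NetCDF stream and 2 batches standing in for 10 million samples:

```python
from itertools import islice

from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    """Stand-in for the NetCDF stream: yields the integers 0 .. n-1 once."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


if __name__ == "__main__":
    train_loader = DataLoader(Stream(100), batch_size=10)
    val_loader = DataLoader(Stream(40), batch_size=10)

    train_iter = iter(train_loader)  # create once; do NOT call iter(train_loader) per epoch
    for epoch in range(3):
        for batch in islice(train_iter, 2):  # custom "epoch" of 2 batches
            print(f"epoch {epoch}: train batch starting at {batch[0].item()}")
        for batch in islice(iter(val_loader), 1):  # validation starts from the top each time
            print(f"epoch {epoch}: val batch starting at {batch[0].item()}")
```

When the underlying stream runs out, islice simply yields fewer batches; at that point a new iter(train_loader) call starts a genuine new pass over the files. The validation loader is deliberately restarted here; if the validation samples should rotate as well, a second long-lived iterator can be kept for it in the same way.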

Would you suggest setting persistent_workers = True for a DataLoader that uses a DistributedSampler in a DDP application?

Using persistent workers could speed up your training if initializing the workers takes a lot of time, and this is independent of the sampler you use. If you see a benefit from it in your DDP use case, it sounds like a valid usage.
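On the DistributedSampler part: for a map-style dataset the DataLoader draws indices from the sampler in the main process and sends them to the workers, so worker persistence does not interfere with per-epoch shuffling as long as sampler.set_epoch(epoch) is called before each pass. A minimal single-process sketch, assuming a toy TensorDataset and a hypothetical indices_per_epoch helper; num_replicas and rank are passed explicitly only so it runs without a process group:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def indices_per_epoch(rank, world_size=2, num_epochs=2):
    dataset = TensorDataset(torch.arange(16))  # sample i is just the value i
    # num_replicas/rank are passed explicitly so this runs without init_process_group;
    # in a real DDP job they come from the initialized process group.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=0)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler,
                        num_workers=2, persistent_workers=True)
    seen = []
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # call before iterating so this epoch gets its own shuffle
        seen.append(torch.cat([values for (values,) in loader]).tolist())
    return seen


if __name__ == "__main__":
    for rank in range(2):
        print(f"rank {rank}:", indices_per_epoch(rank))
```

Each rank sees a disjoint half of the indices in every epoch, and the split changes from epoch to epoch, while the two worker processes per rank are created only once.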
