Hello, I’m trying to better understand the operation of the persistent_workers option for DataLoader.
My understanding is that with persistent_workers=True the DataLoader does not shut down the worker processes that have been consuming the dataset after you stop consuming from it. To me this implied that it saves the state of the DataLoader instance, so that when you come back to consume more batches it picks up where it left off. The behavior I see when trying to do this is not that: the workers start over from the beginning of the dataset instead of picking up where they left off.
So my question is about the intended use of persistent_workers. Is it to save the state of the dataset so you can come back and pick up where you left off consuming, or is it to keep the workers in memory (reducing process startup and file access costs) while starting fresh each time? Looking at the torch code, the intended behavior seems to be to restart the dataset whenever __iter__ is called while num_workers > 0 and persistent_workers=True (https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py#L435). An interesting observation: if I comment out the self._iterator._reset(self) call, the DataLoader functions as I thought it would and picks up where it left off.
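The restart I'm seeing is consistent with plain Python iterator semantics: every for-loop calls __iter__ on the loader, and under persistent_workers that __iter__ call is exactly where the reset happens; only the worker processes survive across it. A toy stand-in (nothing torch-specific, names are made up) showing that two for-loops mean two __iter__ calls, each starting a fresh pass:

```python
class Stream:
    """Mimics the DataLoader's iterator protocol: every for-loop calls
    __iter__, which starts a fresh pass over the data."""
    def __init__(self, n):
        self.n = n
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1  # count how often a fresh pass is started
        return iter(range(self.n))

s = Stream(10)
first = [x for _, x in zip(range(5), s)]   # for-loop #1 -> __iter__ call #1
second = [x for _, x in zip(range(5), s)]  # for-loop #2 -> __iter__ call #2, restarts
print(first)          # [0, 1, 2, 3, 4]
print(second)         # [0, 1, 2, 3, 4] -- restarted, not resumed
print(s.iter_calls)   # 2
```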
My specific use case is that I have many TBs of netcdf files that I want to stream into model training, but I want to define a custom epoch as every 10 million samples, switch to the validation set for 2 million samples, then start a new epoch using the existing DataLoader object, having it pick up where it left off from the previous epoch. I'd appreciate any insight into the intent of worker persistence and whether I'm going about this incorrectly, thanks!
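Concretely, the behavior I'm after looks like this sketch: call iter() on the DataLoader once, keep that iterator object alive, and draw fixed-size chunks from it with itertools.islice. (RangeStream is a made-up in-memory stand-in for the netcdf stream, and I use num_workers=0 for brevity; the pattern is independent of persistent_workers.)

```python
import itertools
import torch

# Made-up stand-in for the netcdf-backed dataset: yields 0..99 in order.
class RangeStream(torch.utils.data.IterableDataset):
    def __iter__(self):
        return iter(range(100))

loader = torch.utils.data.DataLoader(RangeStream(), batch_size=10)

# Call iter() ONCE and keep the iterator object alive. A new for-loop
# would call iter() again (which is the reset); islice on the SAME
# iterator draws a fixed number of batches and then resumes in place.
stream = iter(loader)
epoch1 = [b.tolist() for b in itertools.islice(stream, 3)]  # batches 0-2
epoch2 = [b.tolist() for b in itertools.islice(stream, 3)]  # batches 3-5
print(epoch1[0])  # [0, 1, ..., 9]
print(epoch2[0])  # [30, 31, ..., 39] -- resumed, not restarted
```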
Here is a minimal example of how I'm using DataLoader, showing that it isn't persisting state in the way I expected.
# imports used throughout the example
import os
from glob import glob
from itertools import islice

import numpy as np
import torch
import xarray as xr

# function to build 10 input netcdf test files where each file is a length-10 array of the same digit 0-9
# ex 1s.nc -> [1,1,1,1,1,1,1,1,1,1]
def build_input_files():
    for x in range(10):
        da = xr.DataArray(np.full(10, x), name="data")
        da.to_netcdf(f"{x}s.nc")
# An iterable dataset that reads a directory of netcdf files
class MixedIter(torch.utils.data.IterableDataset):
    def __init__(self, directory: str):
        super().__init__()
        self.dir = directory
        self.fl = glob(os.path.join(self.dir, "*.nc"))

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        print(f"Worker ID {worker_info.id} started\n")
        for f in self.fl:
            ds = xr.open_dataarray(f)
            for x in torch.from_numpy(ds.data):
                yield x
# worker_init_fn to split the directory of nc files into per-worker shards
def init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    num_workers = worker_info.num_workers
    fl = glob(os.path.join(dataset.dir, "*.nc"))
    # round-robin shard: worker i gets files i, i + num_workers, i + 2*num_workers, ...
    shards = [list(islice(fl, x, len(fl), num_workers))
              for x in range(num_workers)]
    dataset.fl = shards[worker_info.id]
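As a sanity check on the sharding above, the islice stride splits a file list round-robin; a pure-stdlib sketch with hypothetical filenames matching the builder:

```python
from itertools import islice

# hypothetical file list matching build_input_files above
fl = [f"{x}s.nc" for x in range(10)]
num_workers = 2
# worker i gets files i, i + num_workers, i + 2*num_workers, ...
shards = [list(islice(fl, x, len(fl), num_workers)) for x in range(num_workers)]
print(shards[0])  # ['0s.nc', '2s.nc', '4s.nc', '6s.nc', '8s.nc']
print(shards[1])  # ['1s.nc', '3s.nc', '5s.nc', '7s.nc', '9s.nc']
```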
# Let's iterate through the dataset quickly to show it is working
mixed_ds_1 = MixedIter('./')
mixed_dl_1 = torch.utils.data.DataLoader(mixed_ds_1, batch_size=10, num_workers=2, worker_init_fn=init_fn)
for x in mixed_dl_1:
    print(x)
Worker ID 0 started
Worker ID 1 started
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
#Ok, that looks as expected...
#Now iterate through half of the dataset, break out of the loop, then iterate over the same object again.
#This shows it restarts the dataset instead of picking up where it left off: I'd expect the second
#loop through the same object to start at the tensor full of 0s, but it starts over from the beginning.
mixed_ds_3 = MixedIter('./')
mixed_dl_3 = torch.utils.data.DataLoader(mixed_ds_3, batch_size=10, num_workers=2, worker_init_fn=init_fn, persistent_workers=True)
z = 0
for x in mixed_dl_3:
    print(x)
    if z == 5:
        break
    z += 1
print("\nsecond set")
for x in mixed_dl_3:
    print(x)
Worker ID 0 started
Worker ID 1 started
Worker ID 0 started
Worker ID 1 started
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
second set
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
tensor([9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
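For completeness, here is a sketch of the whole train/validate alternation I'm after, again with a made-up in-memory stand-in (RangeStream) and scaled-down epoch sizes: one long-lived iterator over the training loader supplies a fixed chunk per epoch and resumes across the validation passes. The resumption comes from reusing the same iterator object, not from persistent_workers.

```python
import itertools
import torch

class RangeStream(torch.utils.data.IterableDataset):
    """Made-up stand-in for the netcdf stream: yields 0..n-1 in order."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

# One long-lived iterator over the training loader (cf. the 10M-sample epochs).
train_stream = iter(torch.utils.data.DataLoader(RangeStream(90), batch_size=10))

firsts = []
for epoch in range(3):
    # Draw a fixed-size training "epoch" from the same iterator...
    train_chunk = [b.tolist() for b in itertools.islice(train_stream, 3)]
    firsts.append(train_chunk[0][0])
    # ...then run a full validation pass with a separate, ordinary loop
    # (cf. the 2M-sample validation set).
    for vb in torch.utils.data.DataLoader(RangeStream(20), batch_size=10):
        pass

print(firsts)  # [0, 30, 60] -- training resumed across the validation passes
```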