How to re-iterate an iterable dataset using multiple workers

I have an iterable dataset (loading TFRecords of images and bounding boxes) that I want to repeat epoch after epoch. Before I noticed a problem, my code was:

for epoch_num in range(epochs):
    data_iterator = iter(data_func(*data_args))
    for batch in data_iterator:
        # training stuff
        ...

The issue with this approach is that my data_func returns a DataLoader whose num_workers is set to a value between 1 and 8. It seems that each time I execute the line

data_iterator = iter(data_func(*data_args))

I am spawning new worker processes (from within my main process), which is not desirable. How can I reuse the worker processes that are already running?

You could pass persistent_workers=True when creating the DataLoader and iterate over it directly:

for batch in loader:
    ...

which keeps the worker processes alive between epochs.
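A minimal sketch of this, assuming a toy IterableDataset in place of the TFRecord pipeline (RangeStream and run_epochs are hypothetical names). The DataLoader is built once, outside the epoch loop, and every sample also carries the PID of the process that produced it, so you can check that the same workers serve every epoch. With an IterableDataset and num_workers > 1, each worker runs its own copy of __iter__, so the stream is sharded via get_worker_info() to avoid duplicated samples.

```python
import os
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeStream(IterableDataset):
    """Hypothetical stand-in for a TFRecord reader: streams the integers 0..n-1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        info = get_worker_info()
        # Shard the stream so each worker yields a disjoint subset
        # (otherwise every worker would replay the full dataset).
        start = 0 if info is None else info.id
        step = 1 if info is None else info.num_workers
        for i in range(start, self.n, step):
            # Also report which process produced the sample.
            yield i, os.getpid()


def run_epochs(num_epochs=3, n=16, num_workers=2):
    # Build the DataLoader ONCE, outside the epoch loop.
    loader = DataLoader(
        RangeStream(n),
        batch_size=4,
        num_workers=num_workers,
        persistent_workers=True,  # keep worker processes alive between epochs
    )
    history = []
    for epoch in range(num_epochs):
        items, pids = [], set()
        for values, worker_pids in loader:  # each epoch calls iter(loader) internally
            items.extend(values.tolist())
            pids.update(worker_pids.tolist())
        history.append((sorted(items), pids))
    return history


if __name__ == "__main__":
    for epoch, (items, pids) in enumerate(run_epochs()):
        print(f"epoch {epoch}: {len(items)} samples from worker PIDs {sorted(pids)}")
```

Running it should print the same two worker PIDs for every epoch; with persistent_workers=False, new worker processes (and therefore new PIDs) would be created for each epoch.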


Assuming I set persistent_workers=True when creating my DataLoader, how do I restart the iterator for each epoch?

Each for loop over the DataLoader will automatically reuse the existing iterator and its workers.
I haven’t looked into the implementation of persistent_workers, but I would assume that manually recreating the iterator wouldn’t work (you could try it nevertheless and check whether the startup time is slower).
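To try the manual-iterator variant, here is a minimal sketch assuming a toy dataset. PidStream, make_loader and worker_pids_per_epoch are hypothetical names, and make_loader stands in for data_func(*data_args). The sketch records which worker PIDs produced the batches and compares two cases: calling iter() on the same DataLoader every epoch, and building a new DataLoader every epoch as the original loop does.

```python
import os
from torch.utils.data import DataLoader, IterableDataset


class PidStream(IterableDataset):
    """Hypothetical dataset: every sample is the PID of the process producing it."""

    def __iter__(self):
        # No sharding needed here: each worker's samples are just its own PID.
        for _ in range(4):
            yield os.getpid()


def make_loader():
    # Hypothetical stand-in for data_func(*data_args) from the question.
    return DataLoader(PidStream(), batch_size=2, num_workers=2, persistent_workers=True)


def worker_pids_per_epoch(rebuild_each_epoch, epochs=2):
    """Return, for each epoch, the set of worker PIDs that produced its batches."""
    loader = make_loader()
    seen = []
    for _ in range(epochs):
        if rebuild_each_epoch:
            # Pattern from the question: a new DataLoader, hence new workers, every epoch.
            loader = make_loader()
        data_iterator = iter(loader)  # manually (re)create the iterator
        pids = set()
        for batch in data_iterator:
            pids.update(batch.tolist())
        seen.append(pids)
    return seen


if __name__ == "__main__":
    print("same loader:   ", worker_pids_per_epoch(rebuild_each_epoch=False))
    print("rebuilt loader:", worker_pids_per_epoch(rebuild_each_epoch=True))
```

In recent PyTorch versions, DataLoader.__iter__ with persistent_workers=True and num_workers > 0 resets and returns the existing iterator rather than creating a new one. So the first call should report the same PIDs in both epochs, while the second should report a different pair each epoch. This suggests that the new processes in the question come from data_func building a new DataLoader every epoch, not from calling iter() itself.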
