Iterable pytorch dataset with multiple workers

Hi!

I have a text file larger than my RAM, and I would like to create a PyTorch dataset that reads it line by line so I don’t have to load it all into memory at once. I found PyTorch’s IterableDataset as a potential solution. It works as expected with 1 worker, but with more than one worker it produces duplicate records. Let me show you an example:

Having a testfile.txt containing:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

Defining an IterableDataset:

from torch.utils.data import DataLoader, IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):

        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        return mapped_itr

We can now test it:

base_dataset = CustomIterableDatasetv1("testfile.txt")
#Wrap it around a dataloader
dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 1)
for X, y in dataloader:
    print(X,y)

It outputs:

('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

That is correct. But if I change the number of workers to 2, the output becomes:

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

This is incorrect: each worker in the data loader produces its own copy of every sample.

Is there a way to solve this in PyTorch, so that a dataloader can support multiple workers without loading the whole file into memory?

The docs explain this behavior and suggest using the worker information:

When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset’s __iter__() method or the DataLoader’s worker_init_fn option to modify each copy’s behavior.

Thanks for the reply!

Really good material you linked to. I think I have solved it. Can you double-check my logic? I tested it and it works well so far.

I replaced the dataset with this new definition:

import itertools
import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):
        #get_worker_info() returns None in the main process (num_workers=0)
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            worker_id, worker_total_num = 0, 1
        else:
            worker_id, worker_total_num = worker_info.id, worker_info.num_workers

        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        #Add multiworker functionality
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        return mapped_itr

I made use of your suggestion and call get_worker_info() to get the total number of workers and the current worker’s id. I return a sliced version of the iterator, so each worker only returns the samples that correspond to it. Each worker will still iterate over the full dataset; it just won’t return the samples that other workers are returning.
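
To make the cost of this approach concrete, here is a small pure-Python sketch (no DataLoader involved) that simulates one worker and counts how often the mapper runs. The helper name count_mapper_calls and the slice_first flag are made up for this illustration. It compares mapping and then slicing, as in the code above, with slicing the raw lines first and mapping afterwards.

```python
import itertools


def count_mapper_calls(lines, worker_id, num_workers, slice_first):
    """Return (items, number of mapper calls) for one simulated worker."""
    calls = 0

    def mapper(line):
        nonlocal calls
        calls += 1
        return line.strip().lower()

    if slice_first:
        # Slice the raw lines, then map: only this worker's lines are processed.
        sliced = itertools.islice(lines, worker_id, None, num_workers)
        items = list(map(mapper, sliced))
    else:
        # Map, then slice (as in the post above): islice still pulls every
        # element out of the map iterator, so the mapper runs on every line.
        mapped = map(mapper, lines)
        items = list(itertools.islice(mapped, worker_id, None, num_workers))
    return items, calls


lines = [f"{i} - Dummy line\n" for i in range(10)]
map_then_slice = count_mapper_calls(iter(lines), 1, 2, slice_first=False)
slice_then_map = count_mapper_calls(iter(lines), 1, 2, slice_first=True)
print(map_then_slice[1], slice_then_map[1])  # prints: 10 5
```

Both orderings return the same five items for this worker, but slicing first skips the preprocessing of lines that belong to other workers. Either way, every worker still reads through the whole file.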

Good to hear it’s working and your explanation sounds also reasonable! :)

I have been looking for this functionality to load very large datasets, and it really helped me, thanks. One issue I found is that while the modified code resolves the duplicated output, there is no difference in execution speed. I assume this is because all K workers do the exact same job (i.e., in your case each one iterates through the whole dataset and reads the text and label). So increasing the number of workers simply adds redundant processing; the workers just don’t return the redundant samples. Let me know if I understood this right, and if so, is there updated code that fixes this?

Did you try creating the iterators contiguously, so that each worker receives a smaller iterator over a contiguous sequence of observations, as described here?
By the way, the proposed solution creates a separate iterator for each worker, which yields every num_workers-th element of the original dataset, starting from the worker_id-th observation. Note, however, that because islice wraps the mapped iterator, each worker still reads every line and runs line_mapper on it. Slicing the file iterator before mapping avoids the redundant preprocessing, but each worker still reads through the whole file.
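
For a text file, one way to get contiguous per-worker chunks is to split by byte offsets, so that no worker has to read the whole file. The following is only a sketch under that assumption; byte_range_for_worker and iter_lines_in_range are hypothetical helper names, not part of PyTorch. A line belongs to the worker whose byte range contains the line's first byte.

```python
import os
import tempfile


def byte_range_for_worker(file_size, worker_id, num_workers):
    """Contiguous [start, end) byte range assigned to one worker."""
    per_worker = -(-file_size // num_workers)  # ceiling division
    start = min(worker_id * per_worker, file_size)
    end = min(start + per_worker, file_size)
    return start, end


def iter_lines_in_range(path, start, end):
    """Yield the lines whose first byte lies in [start, end)."""
    with open(path, "rb") as f:
        if start > 0:
            # Step back one byte and read to the next newline. If `start` is
            # the beginning of a line, this consumes only the previous "\n";
            # otherwise it skips a partial line owned by the previous worker.
            f.seek(start - 1)
            f.readline()
        while f.tell() < end:
            line = f.readline()
            if not line:
                break
            yield line.decode("utf-8")


# Demo on a temporary copy of the 10-line test file from the first post.
lines = [f"{i} - Dummy line\n" for i in range(10)]
with tempfile.NamedTemporaryFile("wb", suffix=".txt", delete=False) as tmp:
    tmp.write("".join(lines).encode("utf-8"))
    path = tmp.name

num_workers = 3
size = os.path.getsize(path)
shards = [
    list(iter_lines_in_range(path, *byte_range_for_worker(size, w, num_workers)))
    for w in range(num_workers)
]
for w, shard in enumerate(shards):
    print(w, [line.split(" - ")[0] for line in shard])
os.remove(path)
```

Inside an IterableDataset, __iter__ could take worker_id and num_workers from get_worker_info() (falling back to 0 and 1 when it returns None), compute the byte range, and map line_mapper over iter_lines_in_range. Chunks are balanced by bytes rather than by line count, so workers may yield slightly different numbers of samples.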

I have a related question about using multiple workers with a PyTorch dataset. I am using an example similar to the one in the docs:

import math
import torch
from torch.utils.data import IterableDataset, DataLoader

class MyDataset(IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            iter_start = self.start
            iter_end = self.end
        else:
            per_worker = math.ceil((self.end - self.start) / worker_info.num_workers)
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        for i in range(iter_start, iter_end):
            yield i

da = MyDataset(0, 16)
dataloader = DataLoader(da, batch_size=2, num_workers=1)

for i, data in enumerate(dataloader):
    print(i, data)

When I run it, the process exits unexpectedly and here is the corresponding traceback:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/spawn.py", line 125, in _main
    prepare(preparation_data)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/spawn.py", line 236, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/spawn.py", line 287, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/runpy.py", line 288, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/arun/proj/dfs/chemberta3/distributed/td/datasets/two.py", line 28, in <module>
    for i, data in enumerate(dataloader):
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 441, in __iter__
    return self._get_iterator()
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1042, in __init__
    w.start()
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/spawn.py", line 154, in get_preparation_data
    _check_not_importing_main()
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/spawn.py", line 134, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
Traceback (most recent call last):
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3260) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/arun/proj/dfs/chemberta3/distributed/td/datasets/two.py", line 28, in <module>
    for i, data in enumerate(dataloader):
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
    success, data = self._try_get_data()
  File "/Users/arun/Applications/miniconda3/envs/dc/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1145, in _try_get_data
    raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
RuntimeError: DataLoader worker (pid(s) 3260) exited unexpectedly

I am not sure what is going on here. Is there anything wrong with the above code snippet? It works when num_workers=0; the issue occurs only when num_workers > 0.
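
Judging from the traceback, a likely cause is the missing main guard that the first RuntimeError message itself mentions. The frames go through multiprocessing/spawn.py, so the workers are started with the spawn method, and each worker re-imports the main script. Without a guard, that import reaches the DataLoader loop again and tries to start workers while the process is still bootstrapping. Below is a minimal sketch of the guarded version; RangeDataset and run are hypothetical names that mirror the snippet above.

```python
import math

from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class RangeDataset(IterableDataset):  # hypothetical name, mirrors MyDataset
    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process loading: yield the full range
            iter_start, iter_end = self.start, self.end
        else:  # split the range into contiguous per-worker chunks
            per_worker = math.ceil((self.end - self.start) / info.num_workers)
            iter_start = self.start + info.id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))


def run(num_workers):
    loader = DataLoader(RangeDataset(0, 16), batch_size=2, num_workers=num_workers)
    return [batch.tolist() for batch in loader]


# Only the process that runs this file as a script enters this block;
# spawned workers import the module without executing it.
if __name__ == "__main__":
    for i, batch in enumerate(run(num_workers=1)):
        print(i, batch)
```

The class and function definitions stay at module level so that worker processes can import them; only the code that creates and iterates the DataLoader moves under the guard.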