Iterable PyTorch dataset with multiple workers

Thanks for the reply!

That's really good material you linked to; I think I have solved it. Can you double-check my logic? I tested it and it works well so far.

I replaced the dataset with this new definition:

import itertools

import torch
from torch.utils.data import IterableDataset


class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):

        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        #Strip the trailing newline from the label
        return text, label.strip()


    def __iter__(self):
        #get_worker_info() returns None when loading in the main process (num_workers=0)
        worker_info = torch.utils.data.get_worker_info()
        worker_total_num = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0

        #Create an iterator over the lines of the file
        file_itr = open(self.filename)

        #Map each line to a (text, label) pair
        mapped_itr = map(self.line_mapper, file_itr)

        #Add multiworker functionality: worker k keeps every
        #worker_total_num-th sample, starting at offset k
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        return mapped_itr

I made use of your suggestion and access get_worker_info() to find the total number of workers and the current worker's id. I return a sliced version of the dataset iterator, so each worker only yields the samples that correspond to it. Each worker still iterates over the full file; it just skips the samples the other workers are returning.
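For completeness, this is roughly how I feed it to a DataLoader; the file name data.txt and the batch/worker counts are just placeholders:

from torch.utils.data import DataLoader

dataset = CustomIterableDatasetv1('data.txt')

#With num_workers=4, worker k yields lines k, k+4, k+8, ... of the file,
#so the four workers together cover every line exactly once
loader = DataLoader(dataset, batch_size=32, num_workers=4)

for texts, labels in loader:
    #texts and labels are lists of strings (the default collate_fn
    #keeps string samples in plain lists)
    pass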
