Thanks for the reply!
The material you linked to is really good, and I think I've solved it. Can you double-check my logic? I tested it, and it works well so far.
I replaced the dataset with this new definition:
import itertools

import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        ### Do something with text here
        text_pp = text.lower().strip()
        ###
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label, and apply preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() returns None in single-process loading (num_workers=0)
        if worker_info is None:
            worker_id, worker_total_num = 0, 1
        else:
            worker_id = worker_info.id
            worker_total_num = worker_info.num_workers

        # Create an iterator over the lines of the file
        file_itr = open(self.filename)
        # Map each line to a (text, label) pair using line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        # Add multi-worker functionality: each worker takes every
        # worker_total_num-th sample, starting at its own worker_id
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)
        return mapped_itr
Following your suggestion, I access get_worker_info() to find the total number of workers and the current worker's id, and I return a sliced version of the iterator so that each worker only yields the samples that correspond to it. Each worker still iterates over the full dataset; it just won't yield the samples that the other workers are yielding.
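For reference, here is a minimal sketch of how I drive it with a multi-worker DataLoader (the filename data.txt, the batch size, and num_workers=2 are just illustrative assumptions; each line of the file is assumed to look like "some text-label"):

from torch.utils.data import DataLoader

# data.txt is a hypothetical file where each line looks like "some text-label"
dataset = CustomIterableDatasetv1('data.txt')

# With num_workers=2, worker 0 yields lines 0, 2, 4, ... and worker 1
# yields lines 1, 3, 5, ..., so no sample is returned twice
loader = DataLoader(dataset, batch_size=4, num_workers=2)

for texts, labels in loader:
    print(texts, labels)

One trade-off I noticed: since islice has to consume the skipped elements of the mapped iterator, every worker still reads and preprocesses every line and merely discards the results it skips; slicing the raw file iterator before the map call would avoid that duplicated preprocessing.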