I’m having a hard time understanding the example iterable dataset in the docs, and how to apply it to my own code. The example code is:
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Example iterable dataset that yields the integers in [start, end).

    When used with a multi-worker DataLoader, each worker yields a
    disjoint contiguous chunk of the range, so no element is duplicated.
    """

    def __init__(self, start, end):
        # FIX: the docs example wrote `super(MyIterableDataset).__init__()`,
        # which builds an unbound super proxy and initializes the proxy
        # itself -- the parent class is never initialized. `super()` is correct.
        super().__init__()
        # FIX: message now matches the condition (was "end >= start").
        # NOTE: asserts are stripped under `python -O`; real code should
        # raise ValueError instead.
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: yield the full range.
            iter_start = self.start
            iter_end = self.end
        else:
            # In a worker process: carve the range into num_workers
            # contiguous chunks; this worker takes chunk `worker_info.id`.
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            # min() so the last worker does not run past `end`.
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
In the IterableDataset I have created, the `__iter__` method loops through a set of files (the data is already randomly shuffled, so that is not an issue) and yields one row at a time:
def __iter__(self) -> Iterator[dict]:
    """Yield one example per JSON line, streaming across all input files.

    Each line is parsed as a JSON object with 'words' and 'label' keys
    (presumably produced by the preprocessing pipeline -- confirm upstream).
    Tokens are truncated to `self.max_length`, shuffled in place, and
    converted to indices before being emitted as a dict.
    """
    for path in self.files:
        # Choose the opener up front: bz2-compressed vs plain text.
        opener = BZ2File(filename=path, mode='r') if self.compressed else open(file=path, mode='r')
        with opener as handle:
            for line in handle:
                record = json.loads(line)
                # Slicing copies, so the shuffle below does not mutate
                # the parsed record's 'words' list.
                tokens = record['words'][:self.max_length]
                random.shuffle(tokens)
                indices, mask = self.tokens_to_indices(tokens)
                yield {
                    'src': indices,
                    'mask': mask,
                    'label': record['label'],
                }
I’m not sure how it would make sense to return an iterator in this scenario. Any advice would be appreciated!