Hi, I wanted to know whether it is possible to use a `torch.multiprocessing.SimpleQueue` in a way other than the two examples provided in the `IterableDataset` documentation to split the work among the `num_workers` workers. Let me set up a dummy example. Assume the following map-style `Dataset`:
```python
import torch
from torch.utils.data import Dataset

class MyMapDataset(Dataset):
    def __init__(self, length: int = 100):
        self.length = length

    def __getitem__(self, index) -> torch.Tensor:
        return torch.arange(0, index + 1)

    def __len__(self):
        return self.length
```
Now we have an iterable-style dataset that wraps the `MyMapDataset`:
```python
from torch.multiprocessing import SimpleQueue
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, dataset: MyMapDataset):
        self.dataset = dataset
        self.queue = SimpleQueue()
        # Put the indices to process in a shared queue.
        for index in range(len(self.dataset)):
            self.queue.put(index)

    def __iter__(self):
        while not self.queue.empty():
            yield from self.dataset[self.queue.get()]
```
And finally, to run:

```python
from torch.utils.data import DataLoader

map_dataset = MyMapDataset()
iterable_dataset = MyIterableDataset(dataset=map_dataset)
dl = DataLoader(iterable_dataset, num_workers=4)
for sample in dl:
    pass
```
Will the `SimpleQueue` be created on the main process, with all the workers then having access to it, or will this just create `num_workers` independent instances of the `SimpleQueue`?