Hi,
I’m implementing multi-process data loading for my own IterableDataset. However, I observed some strange behavior while playing with the second example in the doc here, which implements worker_init_fn.
The following is a code snippet to reproduce it. MyIterableDataset and worker_init_fn are copied from the doc without any modification.
import math

import torch
from torch.utils.data import DataLoader


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super(MyIterableDataset).__init__()
        assert end > start, "this example code only works with end >= start"
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))


# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)


if __name__ == '__main__':
    ds = MyIterableDataset(start=0, end=500)
    dl = DataLoader(
        dataset=ds, batch_size=100, num_workers=2, worker_init_fn=worker_init_fn,
    )
    for e in dl:
        print(e.shape)
Running this snippet gives the result:
torch.Size([100])
torch.Size([100])
torch.Size([100])
torch.Size([100])
torch.Size([50])
torch.Size([50])
We can see that the last 100 examples are split into two batches of 50 examples each.
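My guess is that each worker collates batches from its own shard independently, rather than the DataLoader batching across workers. If I trace the split manually under that assumption (just a sketch of my reasoning, not necessarily how DataLoader works internally), the per-worker batch sizes seem to line up with the output:

import math

# same numbers as in the snippet above
start, end, num_workers, batch_size = 0, 500, 2, 100
per_worker = int(math.ceil((end - start) / float(num_workers)))  # 250

for worker_id in range(num_workers):
    # range assigned to this worker by worker_init_fn
    w_start = start + worker_id * per_worker
    w_end = min(w_start + per_worker, end)
    n = w_end - w_start
    # batch sizes this worker would emit if it batches its shard on its own
    sizes = [batch_size] * (n // batch_size) + ([n % batch_size] if n % batch_size else [])
    print(f"worker {worker_id}: [{w_start}, {w_end}) -> batch sizes {sizes}")
# worker 0: [0, 250) -> batch sizes [100, 100, 50]
# worker 1: [250, 500) -> batch sizes [100, 100, 50]

Interleaving the two workers' batches (100, 100, 50 from each) would then give exactly the 100, 100, 100, 100, 50, 50 sequence above.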
Similarly, when I changed num_workers=3, I got:
torch.Size([100])
torch.Size([100])
torch.Size([100])
torch.Size([67])
torch.Size([67])
torch.Size([66])
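If the same reasoning holds, per_worker = ceil(500 / 3) = 167, so the three workers get 167, 167, and 166 examples respectively, and each worker's trailing partial batch has 67, 67, and 66 examples, which matches this output.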
Is this a bug, or is it expected behavior?
Thanks