Multi-process data loading with iterable dataset

Hi,

I’m implementing the multi-process data loading logic for my own IterableDataset. However, I observed some strange behavior while playing with the second example in the docs, which implements worker_init_fn.

The following is a code snippet that reproduces the behavior. MyIterableDataset and worker_init_fn are copied from the docs without modification.

import torch
import math
from torch.utils.data import DataLoader


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        return iter(range(self.start, self.end))


# Define a `worker_init_fn` that configures each dataset copy differently
def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)


if __name__ == '__main__':
    ds = MyIterableDataset(start=0, end=500)

    dl = DataLoader(
        dataset=ds, batch_size=100, num_workers=2, worker_init_fn=worker_init_fn,
    )

    for e in dl:
        print(e.shape)

Running this snippet gives the following output:

torch.Size([100])
torch.Size([100])
torch.Size([100])
torch.Size([100])
torch.Size([50])
torch.Size([50])

We can see that the last 100 examples are split into two batches of 50 examples.

Similarly, when I changed to num_workers=3, I got:

torch.Size([100])
torch.Size([100])
torch.Size([100])
torch.Size([67])
torch.Size([67])
torch.Size([66])

Is this a bug or expected behavior?

Thanks

Based on the code snippet, the shapes are expected.
If you print the start and end indices from each worker, you will see:

# num_workers=2
worker 0, start 0, end 250
worker 1, start 250, end 500

# num_workers=3
worker 0, start 0, end 167
worker 1, start 167, end 334
worker 2, start 334, end 500

Each worker will thus yield full batches until its last batch contains the remaining samples, since you’ve explicitly defined this splitting logic in worker_init_fn and __iter__. The boundaries follow from per_worker = ceil((overall_end - overall_start) / num_workers), i.e. ceil(500 / 2) = 250 and ceil(500 / 3) = 167, with the last worker’s end clamped to overall_end.
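For reference, the ranges above can be printed by adding a single line at the end of worker_init_fn (this print is the only addition to the version in the snippet):

# add at the end of worker_init_fn, after dataset.start/dataset.end are set:
print(f"worker {worker_id}, start {dataset.start}, end {dataset.end}")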

Hi @ptrblck,

Thanks for your help!

With drop_last=False in the DataLoader, each worker yields full batches until none of them can fill one, and then each yields its own incomplete batch, even though those leftovers could in principle be combined into full batches.
Is this what you mean?

My expectation was that the DataLoader would take care of combining examples from different workers into complete batches whenever there are enough examples.

I think this feature request is still not implemented; currently each worker creates its batches itself, so your explanation sounds right.
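If you need full batches across workers, one workaround (just a sketch, not built-in behavior) is to disable automatic batching with batch_size=None, so each worker yields individual samples, and then re-batch them in the main process:

dl = DataLoader(
    dataset=ds, batch_size=None, num_workers=2, worker_init_fn=worker_init_fn,
)

buffer = []
for sample in dl:  # single samples, interleaved across the workers
    buffer.append(sample)
    if len(buffer) == 100:  # emit a batch once 100 samples have arrived
        print(torch.tensor(buffer).shape)  # torch.Size([100])
        buffer = []
if buffer:  # at most one incomplete batch remains overall
    print(torch.tensor(buffer).shape)

With 500 samples this yields five full batches of 100 and no partial batch, at the cost of moving collation into the main process.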

I see. Thanks for the help!