ChainDataset + random_split oddity

My training data is split into multiple files, so I have to use ChainDataset; then I split the result into train/test/valid sets using random_split.

I don’t think this is a very uncommon situation, but I cannot get it to work: ChainDataset wants iterable datasets while random_split wants an indexable dataset, so I end up with something that works in neither situation. Both

next(iter(train_ds))

and

train_ds[0]

give me this error:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-22-de92b112278f> in <module>
----> 1 next(iter(train_ds))

~/.local/lib/python3.6/site-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
    255 
    256     def __getitem__(self, idx):
--> 257         return self.dataset[self.indices[idx]]
    258 
    259     def __len__(self):

~/.local/lib/python3.6/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     23 
     24     def __getitem__(self, index):
---> 25         raise NotImplementedError
     26 
     27     def __add__(self, other):

NotImplementedError: 

Switching to ConcatDataset gives me this:

AssertionError: ConcatDataset does not support IterableDataset

What am I doing wrong?

I should add that I am using a custom IterableDataset that also implements indexing.

Unfortunately I cannot fix this by joining the raw files and loading them as one since the data is interpreted as sequences and joining them would create invalid sequences right where two files meet.
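A minimal reproduction of the problem, as a sketch: the hypothetical `SizedStream` class stands in for the custom dataset and is assumed to be iterable, sized, and indexable, as described above. `random_split` only needs `len()` to build the splits, but each split is a `Subset` whose `__getitem__` forwards to `ChainDataset`, which does not implement indexing even when every chained dataset does.

```python
import torch
from torch.utils.data import ChainDataset, IterableDataset, random_split


class SizedStream(IterableDataset):
    """Hypothetical stand-in for one data file: iterable, sized and indexable."""

    def __init__(self, items):
        self.items = list(items)

    def __iter__(self):
        return iter(self.items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


chained = ChainDataset([SizedStream(range(0, 5)), SizedStream(range(5, 10))])
print(list(chained))  # iterating the chain works

# random_split only calls len(chained), so building the splits succeeds...
train_ds, valid_ds = random_split(chained, [8, 2], generator=torch.Generator().manual_seed(0))

# ...but indexing a split calls chained[idx], and ChainDataset has no __getitem__.
try:
    train_ds[0]
except NotImplementedError:
    print("indexing the split raises NotImplementedError")
```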

It’s only 9 months later :wink:
But in any case: could you describe how IterableDatasets should be concatenated? They don’t have a length, since they are used for e.g. streaming data, so I would assume that you could just switch to the next stream once the current one is exhausted?
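Switching streams on exhaustion needs no length at all. A minimal sketch with hypothetical `Stream` and `SwitchingStream` classes (not library APIs); `torch.utils.data.ChainDataset` does essentially the same thing in its `__iter__`.

```python
from torch.utils.data import IterableDataset


class Stream(IterableDataset):
    """Hypothetical streaming source that yields the integers in [start, end)."""

    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))


class SwitchingStream(IterableDataset):
    """Yields every sample of the first stream, then moves on to the next one."""

    def __init__(self, streams):
        super().__init__()
        self.streams = list(streams)

    def __iter__(self):
        for stream in self.streams:
            # No length needed: continue as soon as the current stream is exhausted.
            yield from stream


ds = SwitchingStream([Stream(3, 7), Stream(7, 10)])
print(list(ds))  # [3, 4, 5, 6, 7, 8, 9]
```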

@ptrblck I want to use two IterableDatasets in each epoch. As you said, once the iterator reaches the end of the first IterableDataset, I want to continue with the second IterableDataset. How do I achieve this?

You can just continue with the second IterableDataset as seen here:

import math

import torch
from torch.utils.data import DataLoader


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split the workload evenly across the workers
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))


ds1 = MyIterableDataset(start=3, end=7)
loader1 = DataLoader(ds1, num_workers=1)
ds2 = MyIterableDataset(start=7, end=10)
loader2 = DataLoader(ds2, num_workers=1)

for epoch in range(2):
    for data in loader1:
        print(data)
        # train with the first loader

    for data in loader2:
        print(data)
        # then continue training with the second loader

@ptrblck How is this different from ChainDataset? If they are the same, why was ChainDataset not your first recommendation?

ChainDataset would also work and you can pick any approach that fits your needs.
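For comparison, a sketch of the ChainDataset approach with a single DataLoader. `RangeStream` is a hypothetical, simplified stand-in for `MyIterableDataset` without worker sharding, and `num_workers=0` keeps it single-process. One visible difference from the two-loader version: a batch can mix samples from both datasets (here `[6, 7, 8]`), whereas separate loaders never do.

```python
from torch.utils.data import ChainDataset, DataLoader, IterableDataset


class RangeStream(IterableDataset):
    """Hypothetical simplified stand-in for MyIterableDataset (no worker sharding)."""

    def __init__(self, start, end):
        super().__init__()
        self.start, self.end = start, end

    def __iter__(self):
        return iter(range(self.start, self.end))


# One DataLoader over both datasets: the first is exhausted, then the second is used.
chained = ChainDataset([RangeStream(3, 7), RangeStream(7, 10)])
loader = DataLoader(chained, batch_size=3, num_workers=0)

for epoch in range(2):
    for batch in loader:
        print(epoch, batch)  # tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])
```

With `num_workers > 0`, each worker should run the chain's `__iter__` on its own, so the chained datasets would still need to shard by `get_worker_info()` (as `MyIterableDataset` does) to avoid duplicated samples.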
I don’t know how the initial request would work with IterableDatasets, as they apparently stream data on one hand but also support indexing on the other. The latter would be needed for the splits, I would assume, but the author also didn’t follow up, so I don’t know how it was solved.
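One possible route for the original question, not confirmed by the author: since the custom dataset already supports indexing, it could be written as a map-style `Dataset` per file. `ConcatDataset` and `random_split` then both accept it, and because each index maps to a complete sequence from a single file, no sequence crosses a file boundary. `FileSequences` and the in-memory "files" below are hypothetical placeholders for the real per-file parsing.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset, random_split


class FileSequences(Dataset):
    """Hypothetical map-style dataset holding the sequences of ONE file.

    Each index returns a complete sequence from this file only, so
    concatenating several of these never joins sequences across files.
    """

    def __init__(self, sequences):
        self.sequences = list(sequences)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]


# Stand-ins for two files; in practice each would be parsed from disk.
file_a = FileSequences([[1, 2, 3], [4, 5, 6]])
file_b = FileSequences([[7, 8], [9, 10], [11, 12]])

full_ds = ConcatDataset([file_a, file_b])  # map-style, so ConcatDataset accepts it
train_ds, valid_ds, test_ds = random_split(
    full_ds, [3, 1, 1], generator=torch.Generator().manual_seed(0)
)

print(len(train_ds), len(valid_ds), len(test_ds))  # 3 1 1
print(train_ds[0])            # indexing a split works
print(next(iter(train_ds)))   # and so does plain iteration
```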

@ptrblck Thanks for your quick answer!