NotImplementedError after torch.utils.data.random_split

I have an iterable dataset (and a dataloader for it) that works fine. But after splitting it with torch.utils.data.random_split and wrapping the split datasets in dataloaders, I’m getting a NotImplementedError.

Here’s a small example to reproduce the error:

import torch

class TestDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()
        self.n = 10
        
    def __len__(self):
        return self.n
    
    def __iter__(self):
        for i in range(self.n):
            yield i
            
test_dataset = TestDataset()
test_dataset1, test_dataset2 = torch.utils.data.random_split(test_dataset, lengths=[8, 2])
test_dataloader1 = torch.utils.data.DataLoader(test_dataset1, batch_size=2)
next(iter(test_dataloader1))  # NotImplementedError

Why am I getting this error, and how can I split an iterable-style dataset then?

random_split returns Subset objects, and a Subset tries to index the dataset, which doesn’t work for an IterableDataset:

test_dataset = TestDataset()
test_dataset[0]
> NotImplementedError
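
For context, here is roughly what random_split gives you back (a simplified sketch, not the exact torch.utils.data.Subset source): each split holds the original dataset plus a list of indices, and indexing the split indexes the underlying dataset.

class SimplifiedSubset:
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # This lookup is what raises NotImplementedError when the
        # wrapped dataset is an IterableDataset without __getitem__.
        return self.dataset[self.indices[idx]]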

Since you already define a __len__ method, I would suggest using the map-style torch.utils.data.Dataset and defining a __getitem__ method, which would work with random_split.
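
For reference, a map-style version of your toy dataset could look like this (a minimal sketch, assuming each sample can be produced from its index):

import torch

class MapStyleTestDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        self.n = 10

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Produce the sample for a given index instead of streaming it.
        return idx

test_dataset = MapStyleTestDataset()
train_dataset, val_dataset = torch.utils.data.random_split(test_dataset, lengths=[8, 2])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2)
next(iter(train_dataloader))  # works, e.g. tensor([3, 7])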

Unfortunately that’s not possible. I can calculate the length, but I can’t keep all the data in memory to make it indexable. Is there no way to split this kind of dataset? For example, in TensorFlow you can use .take and .skip to generate a train and a validation set.

You don’t have to preload the data and can load it lazily, as described in this tutorial.
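
The idea behind that approach, roughly: __getitem__ reads only the one requested sample, so nothing is held in memory up front. A minimal sketch, where the per-sample files and storage format are placeholder assumptions:

import torch

class LazyDataset(torch.utils.data.Dataset):
    def __init__(self, file_paths):
        # e.g. a list of per-sample .pt files on disk (assumption)
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Only this one sample is read from disk when indexed.
        return torch.load(self.file_paths[idx])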

I don’t have a traditional dataset, so unfortunately I can’t do that. The data has to be generated on the fly and can’t be indexed into.

I’m just going to hack something together manually, since I really need to get this done fast. I was hoping there would be some kind of take/skip API, since that should work regardless of whether the dataset is iterable- or map-style.
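
For what it’s worth, a take/skip-like split can be approximated for an IterableDataset with itertools.islice. A sketch (reusing the TestDataset from above); note the split is sequential, not random:

import itertools
import torch

class SliceDataset(torch.utils.data.IterableDataset):
    # Yields items [start, stop) of the wrapped iterable dataset.
    def __init__(self, dataset, start, stop):
        super().__init__()
        self.dataset = dataset
        self.start = start
        self.stop = stop

    def __iter__(self):
        return itertools.islice(iter(self.dataset), self.start, self.stop)

test_dataset = TestDataset()
train_dataset = SliceDataset(test_dataset, 0, 8)   # like .take(8)
val_dataset = SliceDataset(test_dataset, 8, 10)    # like .skip(8)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2)
next(iter(train_dataloader))  # tensor([0, 1])

Note that the validation split still has to iterate over (and discard) the first 8 samples each epoch, which is the usual cost of a skip on a stream.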

In that case I don’t really understand the overall use case.
IterableDataset is meant for data “streams”, i.e. datasets that cannot be indexed directly.
Map-style datasets define a __getitem__ as well as a __len__, so that indexing via dataset[index] loads the corresponding sample. random_split uses this property to split the indices into the desired ranges and returns the corresponding Subsets.
However, based on your description you can calculate the length but don’t have a __getitem__ (so the IterableDataset seems to be the right choice). Nevertheless, you would like to “split” this dataset using indices, which doesn’t fit your described use case, since you can’t use indices to load samples.
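
That said, if per-epoch randomness isn’t required, one workaround (not a PyTorch API, just a sketch under that assumption) is to wrap the stream and route every k-th sample to validation:

import torch

class EveryKthSplit(torch.utils.data.IterableDataset):
    # Yields either every k-th sample (validation=True)
    # or the remaining samples (validation=False) from the stream.
    def __init__(self, dataset, k, validation):
        super().__init__()
        self.dataset = dataset
        self.k = k
        self.validation = validation

    def __iter__(self):
        for i, sample in enumerate(self.dataset):
            if (i % self.k == 0) == self.validation:
                yield sample

test_dataset = TestDataset()
train_dataset = EveryKthSplit(test_dataset, k=5, validation=False)  # 8 samples
val_dataset = EveryKthSplit(test_dataset, k=5, validation=True)     # 2 samples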