I have an iterable dataset (and a dataloader for it) which works fine. But after splitting it using `torch.utils.data.random_split` and making the split datasets into dataloaders, I'm getting a `NotImplementedError`.
Here’s a small example to reproduce the error:
```python
import torch

class TestDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()
        self.n = 10

    def __len__(self):
        return self.n

    def __iter__(self):
        for i in range(self.n):
            yield i

test_dataset = TestDataset()
test_dataset1, test_dataset2 = torch.utils.data.random_split(test_dataset, lengths=[8, 2])
test_dataloader1 = torch.utils.data.DataLoader(test_dataset1, batch_size=2)
next(iter(test_dataloader1))  # NotImplementedError
```
Why am I getting this error, and how can I split an iterable-style dataset?
The `Subset` tries to index the dataset, which doesn't work for `IterableDataset`s:

```python
test_dataset = TestDataset()
test_dataset[0]
# NotImplementedError
```
Since you already seem to define a `__len__` method, I would suggest using the map-style `torch.utils.data.Dataset` and defining the `__getitem__` method, which would work with `random_split`.
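A minimal sketch of that suggestion, assuming the samples can be produced from an index (the class name and the trivial index-as-sample body are illustrative, not from the thread):

```python
import torch

# Map-style variant of the toy dataset: because __len__ is known, adding a
# __getitem__ makes it indexable and therefore compatible with random_split.
class MapStyleTestDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        self.n = 10

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        # Generate (or look up) the sample belonging to this index.
        return index

dataset = MapStyleTestDataset()
train_set, val_set = torch.utils.data.random_split(dataset, lengths=[8, 2])
loader = torch.utils.data.DataLoader(train_set, batch_size=2)
batch = next(iter(loader))  # works: the Subset can now index the dataset
```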
Unfortunately that's not possible. I can calculate the `len`, but I can't keep all the data in memory to make it indexable. Is there no way to split this kind of dataset? For example, in TensorFlow you can use `.take` and `.skip` to generate a train and validation set.
You don't have to preload the data and can lazily load it, as described in this tutorial.
I don’t have a traditional dataset so unfortunately I can’t do that. It has to be generated on the fly and can’t be indexed into.
I'm just going to hack something together manually since I really need to get this done fast. I was hoping there would be some kind of `take`/`skip` API, since that should work regardless of whether the dataset is iterable or map-style.
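PyTorch has no built-in `take`/`skip`, but a sketch of that TensorFlow-style API can be built on `itertools.islice`; the `TakeDataset`/`SkipDataset` wrapper names below are hypothetical, not part of `torch.utils.data`:

```python
import itertools
import torch

class TestDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()
        self.n = 10

    def __iter__(self):
        for i in range(self.n):
            yield i

# Hypothetical take/skip wrappers: both stay IterableDatasets, so no
# indexing is required, only iteration.
class TakeDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, n):
        self.dataset, self.n = dataset, n

    def __iter__(self):
        # Yield only the first n samples of the wrapped stream.
        return itertools.islice(iter(self.dataset), self.n)

class SkipDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, n):
        self.dataset, self.n = dataset, n

    def __iter__(self):
        # Skip the first n samples and yield the rest.
        return itertools.islice(iter(self.dataset), self.n, None)

source = TestDataset()
train_set = TakeDataset(source, 8)  # first 8 samples
val_set = SkipDataset(source, 8)    # remaining samples

train_batches = list(torch.utils.data.DataLoader(train_set, batch_size=2))
val_batches = list(torch.utils.data.DataLoader(val_set, batch_size=2))
```

Note this split is sequential rather than random: unlike `random_split`, it cannot shuffle a stream it can only consume once.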
In that case I don't really understand the overall use case. `IterableDataset` is used for data "streams", i.e. `Dataset`s which cannot directly be indexed. The map-style datasets define a `__getitem__` as well as a length, so that indexing via `dataset[index]` loads the corresponding sample. `random_split` uses this property to split the indices into the desired ranges and returns the `Subset`s.
However, based on your description you cannot calculate the length and don't have a `__getitem__` (so the `IterableDataset` seems to be the right choice). Nevertheless, you would like to "split" this dataset using indices, which doesn't fit your described use case, since you can't use indices to load samples.
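One common workaround for stream data (not suggested in the thread, so treat it as an assumption): assign each sample to train or validation with a seeded random draw as it streams past, which gives a reproducible random split without ever using indices. The `RandomSplitIterable` name below is hypothetical:

```python
import random
import torch

# Sketch: split a stream per sample at iteration time. The same seed produces
# the same train/val assignment for each position on every pass.
class RandomSplitIterable(torch.utils.data.IterableDataset):
    def __init__(self, dataset, keep_fraction, train, seed=0):
        self.dataset = dataset
        self.keep_fraction = keep_fraction
        self.train = train
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)  # fresh, seeded RNG per pass
        for sample in self.dataset:
            in_train = rng.random() < self.keep_fraction
            if in_train == self.train:
                yield sample

stream = range(10)  # stands in for any on-the-fly data source
train_set = RandomSplitIterable(stream, keep_fraction=0.8, train=True)
val_set = RandomSplitIterable(stream, keep_fraction=0.8, train=False)

train_ids = list(train_set)
val_ids = list(val_set)
# Every sample lands in exactly one split because both wrappers share a seed.
```

The trade-off versus `random_split` is that split sizes are only approximate (each sample independently has an 80% chance of landing in the train set).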