ChainDataset + random_split oddity

My training data is split across multiple files, so I combine them with ChainDataset and then split the result into train/test/valid sets using random_split.
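
Simplified, the setup looks like this (MySequenceDataset is my own per-file dataset, sketched at the end of this post, and the file names are just placeholders):

from torch.utils.data import ChainDataset, random_split

files = ["part_000.seq", "part_001.seq", "part_002.seq"]
file_datasets = [MySequenceDataset(f) for f in files]

full_ds = ChainDataset(file_datasets)

n = len(full_ds)  # this works, ChainDataset sums up the per-file lengths
n_train = int(0.8 * n)
n_valid = int(0.1 * n)
train_ds, valid_ds, test_ds = random_split(full_ds, [n_train, n_valid, n - n_train - n_valid])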

I don’t think this is a very uncommon situation, but I cannot get it to work: ChainDataset wants an iterable dataset while random_split needs an indexable (map-style) dataset, so I end up with something that works neither way. Both

next(iter(train_ds))

and

train_ds[0]

give me this error:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-22-de92b112278f> in <module>
----> 1 next(iter(train_ds))

~/.local/lib/python3.6/site-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
    255 
    256     def __getitem__(self, idx):
--> 257         return self.dataset[self.indices[idx]]
    258 
    259     def __len__(self):

~/.local/lib/python3.6/site-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     23 
     24     def __getitem__(self, index):
---> 25         raise NotImplementedError
     26 
     27     def __add__(self, other):

NotImplementedError: 

Switching to ConcatDataset gives me this:

AssertionError: ConcatDataset does not support IterableDataset
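
That is, keeping everything else the same and only replacing the combine step:

from torch.utils.data import ConcatDataset

full_ds = ConcatDataset(file_datasets)  # fails right here: ConcatDataset checks its inputs and rejects IterableDataset instances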

What am I doing wrong?

I should add that I am using a custom IterableDataset that also implements indexing.
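
Roughly, each per-file dataset looks like this (simplified; load_sequences is just a placeholder for the actual file parsing):

from torch.utils.data import IterableDataset

class MySequenceDataset(IterableDataset):
    # one raw file worth of sequences; iterable, but indexing also works
    def __init__(self, path):
        self.sequences = load_sequences(path)  # placeholder for the real parsing code

    def __iter__(self):
        return iter(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx]

    def __len__(self):
        return len(self.sequences)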

Unfortunately I cannot fix this by concatenating the raw files and loading them as a single file, since the data is interpreted as sequences and joining the files would create invalid sequences right where two files meet.