How do I split an iterable dataset into training and test datasets?

I have an IterableDataset object that streams all of my data files. How can I split it into a train and a validation set? I have seen a few solutions for custom map-style datasets, but an iterable dataset does not support the len() operator, so torch.utils.data.random_split() and torch.utils.data.SubsetRandomSampler() don't work.

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        # set up the list of data files
        ...

    def __iter__(self):
        # read the files and produce batches
        ...
        yield batch

I don't think you can split the samples inside an IterableDataset directly, as it's used for e.g. streams of data.
In case your dataset is predefined, you could check if the stream source is able to split and shuffle the data.
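For example, if the stream source is a list of files, you could split that list up front and build a separate IterableDataset per split. A minimal sketch, assuming each sample is one line of a text file (the class name and the data/*.txt pattern are just placeholders):

import glob
import random
import torch.utils.data as td

class LineStreamDataset(td.IterableDataset):
    # Streams one sample per line from a fixed list of text files.
    def __init__(self, files):
        self.files = files

    def __iter__(self):
        for path in self.files:
            with open(path) as f:
                for line in f:
                    yield line.rstrip("\n")

# Split the file list itself, then build one dataset per split.
files = sorted(glob.glob("data/*.txt"))
random.Random(0).shuffle(files)          # deterministic shuffle
n_train = int(0.8 * len(files))
train_ds = LineStreamDataset(files[:n_train])
val_ds = LineStreamDataset(files[n_train:])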


IMHO it is possible, but we have to load the data 2 or more times.

First, if our dataset yields pandas.DataFrame batches, we can use:

import typing as ty

import torch
import torch.utils.data as td

class SplitDataSet(td.IterableDataset):
    def __init__(self, dataset, ranges: ty.Dict[str, ty.Tuple[float, float]],
                 generator: torch.Generator = torch.Generator()):
        self.dataset = dataset
        self.ranges = ranges
        self.generator = generator

    def __iter__(self):
        self.iter = iter(self.dataset)
        # Clone the generator state so every pass draws the same random numbers
        # and the train/val/test assignment stays consistent across epochs.
        self.rng = self.generator.clone_state()
        return self

    def __next__(self):
        df = next(self.iter)                      # one pandas.DataFrame per item
        rs = torch.rand(len(df), generator=self.rng)
        res = {}
        for key, (lo, hi) in self.ranges.items():
            mask = (lo <= rs) & (rs < hi)
            res[key] = df.loc[mask.numpy()]       # boolean row mask per split
        return res

NB: We must create the DataLoader with batch_size=None:

ranges = {'train': (0.0, 0.8), 'val': (0.8, 0.9), 'test': (0.9, 1.0)}
dl = td.DataLoader(SplitDataSet(ds, ranges), batch_size=None)

I.e. our batches are the DataFrames themselves, and we still have to convert them into tensors.
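A minimal sketch of that conversion, assuming each DataFrame has a 'label' column plus numeric feature columns (the column names are hypothetical):

for batch in dl:
    train_df = batch['train']
    # Column names below are hypothetical; adapt them to the actual DataFrame layout.
    features = torch.tensor(train_df.drop(columns=['label']).values, dtype=torch.float32)
    labels = torch.tensor(train_df['label'].values, dtype=torch.long)
    # ... forward/backward pass on (features, labels) ...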

Alternatively, we can split item by item:

class SplitItemDataSet(td.IterableDataset):
    def __init__(self, dataset, ranges: ty.Dict[str, ty.Tuple[float, float]],
                 generator: torch.Generator = torch.Generator()):
        self.dataset = dataset
        self.ranges = ranges
        self.generator = generator

    def __iter__(self):
        self.iter = iter(self.dataset)
        self.rng = self.generator.clone_state()
        return self

    def __next__(self):
        x = next(self.iter)
        # Draw one uniform number in [0, 1) per item; note that int(torch.rand(...))
        # would always be 0, so keep it as a float.
        num = torch.rand(1, generator=self.rng).item()
        res = {}
        for key, (lo, hi) in self.ranges.items():
            res[key] = x if lo <= num < hi else None
        return res

class FilterDataSet(td.IterableDataset):
    def __init__(self, dataset: td.IterableDataset, pred: ty.Callable):
        self.dataset = dataset
        self.pred = pred

    def __iter__(self):
        self.iter = iter(self.dataset)
        return self

    def __next__(self):
        # Skip items until the predicate matches; the underlying StopIteration
        # ends the stream.
        while True:
            x = next(self.iter)
            if self.pred(x):
                return x

dl_train = td.DataLoader(
    FilterDataSet(SplitItemDataSet(dataset, ranges),
                  lambda d: d['train'] is not None),
    batch_size=None)
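As a usage sketch of my own (not from the original post): a validation loader is built the same way with a different predicate, and each returned dict still has to be unwrapped to get the sample. As written, both SplitItemDataSet instances share the default generator and clone its state in __iter__, so they draw the same random sequence and the per-item assignments stay complementary; the price is that the underlying dataset is iterated once per loader, which is the "load the data 2 or more times" cost mentioned above.

dl_val = td.DataLoader(
    FilterDataSet(SplitItemDataSet(dataset, ranges),
                  lambda d: d['val'] is not None),
    batch_size=None)

for item in dl_train:
    sample = item['train']   # the other split keys are None for this item
    # ... training step on sample ...

for item in dl_val:
    sample = item['val']
    # ... validation step on sample ...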