IMHO it is possible, but we have to load the data two or more times.
First, if our dataset returns pandas.DataFrame we can use:
import torch.utils.data as td
class SplitDataSet(td.IterableDataset):
    """Randomly partition each DataFrame yielded by ``dataset`` into named splits.

    ``ranges`` maps a split name to a half-open interval ``[lo, hi)`` of
    ``[0, 1)``.  Every row gets one uniform draw and is placed in each split
    whose interval contains it — overlapping intervals duplicate rows, gaps
    drop rows.

    Iterating yields, per source DataFrame, a dict ``{split_name: sub_frame}``.
    """

    def __init__(self, dataset, ranges: "dict[str, tuple[float, float]]",
                 generator: "torch.Generator | None" = None):
        self.dataset = dataset
        self.ranges = ranges
        # A fresh per-instance Generator avoids the shared-mutable-default
        # pitfall of putting `generator=torch.Generator()` in the signature,
        # where one module-level Generator would be shared by every instance.
        self.generator = torch.Generator() if generator is None else generator

    def __iter__(self):
        self.iter = iter(self.dataset)
        # Copy the generator's state so every epoch replays the same split
        # without advancing `self.generator` itself.
        self.rng = torch.Generator()
        self.rng.set_state(self.generator.get_state())
        return self

    def __next__(self):
        df = next(self.iter)  # StopIteration propagates and ends the epoch
        rs = torch.rand(len(df), generator=self.rng)
        res = {}
        for key, (lo, hi) in self.ranges.items():
            mask = (lo <= rs) & (rs < hi)
            # Boolean-array indexing selects the rows assigned to this split.
            res[key] = df.loc[mask.numpy()]
        return res
NB: We must create the Dataloader by:
ranges = {'train': (0.0, 0.8), 'val': (0.8, 0.9), 'test': (0.9, 1.0)}
dl = td.DataLoader (SplitDataSet (ds, ranges), batch_size = None)
I.e. our batches are the DataFrames themselves; we must transform them into tensors.
Alternatively, we can split each item by itself:
class SplitItemDataSet(td.IterableDataset):
    """Assign each item of ``dataset`` to the split(s) whose ``[lo, hi)`` range
    contains one per-item uniform draw from ``[0, 1)``.

    Yields a dict mapping every split name to either the item or ``None``,
    so a downstream filter can keep only the split it wants.
    """

    def __init__(self, dataset, ranges: "dict[str, tuple[float, float]]",
                 generator: "torch.Generator | None" = None):
        self.dataset = dataset
        self.ranges = ranges
        # Fresh per-instance Generator instead of a shared mutable default.
        self.generator = torch.Generator() if generator is None else generator

    def __iter__(self):
        self.iter = iter(self.dataset)
        # Copy the state so each epoch replays the identical assignment.
        self.rng = torch.Generator()
        self.rng.set_state(self.generator.get_state())
        return self

    def __next__(self):
        x = next(self.iter)
        # BUG FIX: the original `int(torch.rand((1,1), ...))` truncated every
        # draw to 0, so all items fell into whichever split contains 0.0 and
        # the other splits were always empty.  Keep the draw as a float.
        num = torch.rand((), generator=self.rng).item()
        return {key: (x if lo <= num < hi else None)
                for key, (lo, hi) in self.ranges.items()}
class FilterDataSet(td.IterableDataset):
    """Iterable dataset yielding only the items of ``dataset`` for which
    ``pred`` returns a truthy value."""

    def __init__(self, dataset: td.IterableDataset, pred: ty.Callable):
        self.dataset = dataset
        self.pred = pred

    def __iter__(self):
        # Built-in `filter` applies the predicate lazily, in the same order
        # a hand-written next/while loop would.
        self.iter = filter(self.pred, iter(self.dataset))
        return self

    def __next__(self):
        return next(self.iter)
# BUG FIX: the original call was missing the DataLoader's closing parenthesis
# (SyntaxError).  batch_size=None disables automatic collation, matching the
# NB above — the filtered items are plain dicts whose non-'train' entries may
# be None, which the default collate function cannot batch.
dl = td.DataLoader(
    FilterDataSet(SplitItemDataSet(dataset, ranges),
                  lambda d: d['train'] is not None),
    batch_size=None,
)