Composing `Dataset`s

Abstract question:

Can I compose custom Datasets and expect everything to work fine with no surprises?
Something like:

```python
from torch.utils import data

dat1 = Dat1()
dat2 = Dat2(dat1)
dat3 = Dat3(dat2)
dataloader = data.DataLoader(dat3)
```

where the datasets are defined as:

```python
class Dat3(data.Dataset):
    def __init__(self, dat2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dat2 = dat2

    def __get_item__(self, index):
        var2 = self.dat2.__getitem__(index)
        out = operation(var2)
        return out
```

Particular question:
I am implementing an interactive image segmentation application and I need to load the data in roughly four steps:

  1. Load one image and its masks from a segmentation dataset and return them
  2. Select one target region from among the masks and return it
  3. Sample points from inside the target region and return them
  4. Create a `DataLoader` that loads the image, the target region, and the points

My idea was to apply the “composition over inheritance” principle, which means creating (1) `SegDataset(data.Dataset)`, (2) `RegionDataset(data.Dataset)`, which I initialize as `RegionDataset(seg_dataset, *args, **kwargs)`, and (3) `IISDataset(data.Dataset)`, which is instantiated as `IISDataset(region_dataset, *args, **kwargs)`.
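A minimal sketch of that chain, assuming synthetic data in place of a real segmentation dataset: the image size, the band-shaped masks, the random region choice, and the `num_points` parameter are all hypothetical. Each wrapper forwards `__len__` and indexes the dataset it wraps, and the `DataLoader` is built on the outermost one.

```python
import torch
from torch.utils import data


class SegDataset(data.Dataset):
    """Step 1: one image and its masks (synthetic stand-in for a real dataset)."""

    def __init__(self, num_samples=4, num_masks=3, size=8):
        self.num_samples, self.num_masks, self.size = num_samples, num_masks, size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        image = torch.rand(3, self.size, self.size)
        # Hypothetical masks: mask k covers a horizontal band of rows, so none is empty.
        masks = torch.zeros(self.num_masks, self.size, self.size, dtype=torch.bool)
        band = self.size // self.num_masks
        for k in range(self.num_masks):
            masks[k, k * band:(k + 1) * band] = True
        return image, masks


class RegionDataset(data.Dataset):
    """Step 2: wraps a SegDataset and selects one target region per sample."""

    def __init__(self, seg_dataset):
        self.seg_dataset = seg_dataset

    def __len__(self):
        return len(self.seg_dataset)

    def __getitem__(self, index):
        image, masks = self.seg_dataset[index]
        k = int(torch.randint(len(masks), (1,)))  # pick one mask at random
        return image, masks[k]


class IISDataset(data.Dataset):
    """Step 3: wraps a RegionDataset and samples points inside the target region."""

    def __init__(self, region_dataset, num_points=5):
        self.region_dataset = region_dataset
        self.num_points = num_points

    def __len__(self):
        return len(self.region_dataset)

    def __getitem__(self, index):
        image, region = self.region_dataset[index]
        coords = region.nonzero()  # (N, 2) row/column indices of pixels in the region
        picks = torch.randint(len(coords), (self.num_points,))  # sampled with replacement
        return image, region, coords[picks]


# Step 4: the DataLoader only sees the outermost dataset.
loader = data.DataLoader(IISDataset(RegionDataset(SegDataset())), batch_size=2)
image, region, points = next(iter(loader))
print(image.shape, region.shape, points.shape)
```

Because every level is a `Dataset`, an intermediate level such as `RegionDataset(SegDataset())` can also be passed to a `DataLoader` on its own.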

Will this work well if I use the last one to initialize a `DataLoader`?
The only difference between this and writing custom classes that don't inherit from `data.Dataset` is the inheritance itself, which may be good (I can load any intermediate dataset with a `DataLoader`) but also bad (maybe I need to be careful with distributed training). Any opinion?
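On the distributed-training concern, a small sketch: as far as I can tell, `DistributedSampler` only needs the dataset's length, and the `DataLoader` then indexes the dataset, so a wrapper that forwards `__len__` and `[]` is sharded like any other map-style dataset. `Base` and `Doubled` are hypothetical, and `num_replicas`/`rank` are passed explicitly so the example runs in a single process without initializing a process group.

```python
from torch.utils import data
from torch.utils.data.distributed import DistributedSampler


class Base(data.Dataset):
    """Hypothetical inner dataset: item i is just i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index


class Doubled(data.Dataset):
    """Hypothetical wrapper that composes over another dataset."""

    def __init__(self, inner):
        self.inner = inner

    def __len__(self):
        return len(self.inner)  # the sampler only relies on this

    def __getitem__(self, index):
        return 2 * self.inner[index]


dataset = Doubled(Base(10))
# Simulate two ranks; explicit num_replicas/rank means no process group is required.
samplers = [DistributedSampler(dataset, num_replicas=2, rank=r, shuffle=False) for r in range(2)]
shards = [list(s) for s in samplers]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]

# Each rank would build its own DataLoader with its own sampler.
loader0 = data.DataLoader(dataset, sampler=samplers[0], batch_size=5)
batches0 = [b.tolist() for b in loader0]
print(batches0)  # [[0, 4, 8, 12, 16]]
```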

I found that the `DataLoader` still works if I don’t inherit from `data.Dataset`. What is the advantage of subclassing `data.Dataset`?
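A quick sketch of both observations, using hypothetical `PlainSquares` and `Squares` classes: for a map-style dataset the `DataLoader` relies only on `__len__` and `__getitem__`, so duck typing works, while subclassing `data.Dataset` adds small conveniences such as `Dataset.__add__`, which returns a `ConcatDataset`.

```python
from torch.utils import data


class PlainSquares:
    """Does not inherit from data.Dataset; only defines __len__ and __getitem__."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index * index


class Squares(data.Dataset):
    """Same logic, but subclassing data.Dataset."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return index * index


# Duck typing is enough for the DataLoader's map-style path.
plain_batches = [b.tolist() for b in data.DataLoader(PlainSquares(4), batch_size=2)]
print(plain_batches)  # [[0, 1], [4, 9]]

# Subclassing adds Dataset.__add__, which concatenates datasets.
combined = Squares(2) + Squares(3)
print(type(combined).__name__, len(combined))  # ConcatDataset 5

try:
    PlainSquares(2) + PlainSquares(3)
    plain_add_ok = True
except TypeError:
    plain_add_ok = False  # plain objects have no __add__
```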

Deriving from `Dataset` raises an error if a mandatory method is missing (in your example you define `__get_item__` instead of `__getitem__`, which should fail), so it might be the cleaner approach.
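To illustrate with the misspelled method from the question (class names here are hypothetical): in the subclass, indexing falls through to `Dataset.__getitem__`, which raises `NotImplementedError`, while the plain class fails with a generic `TypeError`. In both cases the error only appears when an item is first requested (e.g. the first batch), not when the object is constructed.

```python
from torch.utils import data


class TypoDataset(data.Dataset):
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __get_item__(self, index):  # misspelled: indexing never calls this
        return index


class PlainTypo:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __get_item__(self, index):  # same typo, no Dataset base class
        return index


def error_name(obj):
    """Return the name of the exception raised by obj[0], or None if it succeeds."""
    try:
        obj[0]
    except Exception as exc:
        return type(exc).__name__
    return None


print(error_name(TypoDataset(3)))  # NotImplementedError
print(error_name(PlainTypo(3)))    # TypeError
```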

Your “composition” approach should also work, but I would index the nested datasets directly instead of calling `__getitem__` explicitly (`self.dat2[index]` instead of `self.dat2.__getitem__(index)`).
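A runnable version of the question's `Dat1`/`Dat2`/`Dat3` chain along these lines; the item values and the `* 10` / `+ 1` operations are placeholders for the real `operation`. Each wrapper forwards `__len__` and uses `[]` on the dataset it wraps, which also lets it wrap PyTorch's own `Subset`.

```python
import torch
from torch.utils import data


class Dat1(data.Dataset):
    """Innermost placeholder dataset: item i is a 2-element tensor filled with i."""

    def __init__(self, n=6):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return torch.full((2,), float(index))


class Dat2(data.Dataset):
    """Wraps another dataset; the placeholder operation multiplies by 10."""

    def __init__(self, dat1):
        self.dat1 = dat1

    def __len__(self):
        return len(self.dat1)  # forward the length so samplers see the right size

    def __getitem__(self, index):
        return self.dat1[index] * 10  # index with [] rather than calling __getitem__


class Dat3(data.Dataset):
    """Wraps another dataset; the placeholder operation adds 1."""

    def __init__(self, dat2):
        self.dat2 = dat2

    def __len__(self):
        return len(self.dat2)

    def __getitem__(self, index):
        return self.dat2[index] + 1


dat1 = Dat1()
dat2 = Dat2(dat1)
dat3 = Dat3(dat2)
# Wrappers also compose with PyTorch's helpers, e.g. a Subset of items 1, 3 and 5.
dat3_subset = Dat3(data.Subset(dat2, [1, 3, 5]))

loader = data.DataLoader(dat3_subset, batch_size=3)
batch = next(iter(loader))
print(len(dat3), batch.tolist())  # 6 [[11.0, 11.0], [31.0, 31.0], [51.0, 51.0]]
```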


You can also try torchdata, but it is still in the prototype stage (soon to be in beta).

The tutorial provides an example for iterable-style datasets (which implement `__iter__`), but it works the same way for map-style datasets (which implement a custom `__getitem__`).
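For comparison, a sketch of the same wrapping idea for iterable-style datasets in plain PyTorch (not torchdata, whose prototype API is not assumed here); `Numbers` and `Squared` are hypothetical. One caveat from the PyTorch documentation: with `num_workers > 0`, each worker iterates its own copy of an `IterableDataset`, so the source has to be sharded (e.g. via `torch.utils.data.get_worker_info()`) to avoid duplicated samples. This example uses the default single-process loading.

```python
from torch.utils import data


class Numbers(data.IterableDataset):
    """Hypothetical source stream yielding 0 .. n-1."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))


class Squared(data.IterableDataset):
    """Wraps another iterable dataset and transforms each element lazily."""

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        for x in self.source:
            yield x * x


loader = data.DataLoader(Squared(Numbers(5)), batch_size=2)
stream_batches = [b.tolist() for b in loader]
print(stream_batches)  # [[0, 1], [4, 9], [16]]
```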