Combine / concat dataset instances

What is the recommended approach to combine two instances from torch.utils.data.Dataset?

I came up with two ideas:

  1. Wrapper-Dataset:
import numpy as np
from torch.utils.data import Dataset

class Concat(Dataset):

    def __init__(self, datasets):
        self.datasets = datasets
        self.lengths = [len(d) for d in datasets]
        self.offsets = np.cumsum(self.lengths)
        self.length = np.sum(self.lengths)

    def __getitem__(self, index):
        # locate the dataset this global index falls into, then shift the
        # index by the cumulative length of all preceding datasets
        for i, offset in enumerate(self.offsets):
            if index < offset:
                if i > 0:
                    index -= self.offsets[i-1]
                return self.datasets[i][index]
        raise IndexError(f'{index} exceeds {self.length}')

    def __len__(self):
        return self.length
  2. Using itertools.chain:
import itertools

loader = itertools.chain(*[MyDataset(f'file{i}') for i in range(1, 4)])
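For reference, the wrapper version is then used like any other map-style dataset (the MyDataset file names below are just placeholders):

combined = Concat([MyDataset('file1'), MyDataset('file2')])
print(len(combined))   # sum of the individual lengths
sample = combined[0]   # indexes transparently into the first dataset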

Doesn’t itertools.chain return an iterator that will go over them only once? I think there’s no recommended way, it all comes down to personal taste 🙂
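To illustrate the point: a chained iterator is exhausted after a single pass, so it can only cover one epoch.

import itertools

it = itertools.chain([1, 2], [3, 4])
print(list(it))  # [1, 2, 3, 4]
print(list(it))  # [] -- the chain is already exhausted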

Ah yes, itertools.chain would only do one epoch, so we would be better off with something like:

x = itertools.repeat(itertools.chain.from_iterable([dataset1, dataset2]), times=epochs)

next(next(iter(x)))

# or 

for epoch in x:
    for (inputs, targets) in epoch:
        print(inputs)

Not sure that’s going to work. itertools.repeat yields the same chain iterator object every time, and that iterator is exhausted after the first epoch, so later epochs would be empty (see the quick check after the snippet below). It would be simpler to do this (or use the wrapper dataset):

for epoch in range(num_epochs):
    dset = itertools.chain(...)
    dloader = ...  # build a DataLoader over dset here
    for batch in dloader:
        ...
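For illustration, a quick check of why the repeated chain only produces data on the first pass (toy lists stand in for the datasets):

import itertools

epochs = itertools.repeat(itertools.chain([1, 2], [3, 4]), 2)
for i, epoch in enumerate(epochs):
    print(i, list(epoch))
# 0 [1, 2, 3, 4]
# 1 []   <- repeat hands back the same, now exhausted, chain object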

I cannot pass the itertools.chain instance to the torch.utils.data.DataLoader class, since it does not support len().

See ConcatDataset updated in v0.3: http://pytorch.org/docs/master/data.html#torch.utils.data.ConcatDataset
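A minimal usage sketch (the two datasets below are placeholders; any map-style Dataset instances work):

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

dataset1 = TensorDataset(torch.randn(100, 3))  # placeholder dataset
dataset2 = TensorDataset(torch.randn(50, 3))   # placeholder dataset

combined = ConcatDataset([dataset1, dataset2])
print(len(combined))  # 150

loader = DataLoader(combined, batch_size=32, shuffle=True)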

def return_all_items(dataset):
    # materialize every sample of a map-style dataset into a plain list
    all_items = []
    for i in range(len(dataset)):
        all_items.append(dataset[i])
    return all_items

list1 = return_all_items(original_data)
list2 = return_all_items(transformed_data)
list1.extend(list2)  # list1 now holds the samples of both datasets

Then converting the combined list back into a Dataset object:

class AugmentedDataset(Dataset):

    def __init__(self, combined_list, transform=None):
        self.combined_list = combined_list
        self.transform = transform  # optional transform applied per sample

    def __len__(self):
        return len(self.combined_list)

    def __getitem__(self, idx):
        sample = self.combined_list[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

augmented_dataset = AugmentedDataset(list1)

See also ChainDataset: https://pytorch.org/docs/master/data.html#torch.utils.data.ChainDataset
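Note that ChainDataset chains IterableDataset instances (streams) rather than indexable datasets. A rough sketch with a toy iterable dataset (RangeStream is made up for illustration):

import torch
from torch.utils.data import ChainDataset, DataLoader, IterableDataset

class RangeStream(IterableDataset):
    # toy iterable dataset that streams a range of integers
    def __init__(self, start, stop):
        self.start, self.stop = start, stop

    def __iter__(self):
        return iter(range(self.start, self.stop))

chained = ChainDataset([RangeStream(0, 5), RangeStream(5, 10)])
loader = DataLoader(chained, batch_size=4)
for batch in loader:
    print(batch)  # tensors covering 0..9 across both streams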