What is the recommended approach to combining two instances of torch.utils.data.Dataset?
I came up with two ideas:
Wrapper-Dataset:
import numpy as np
from torch.utils.data import Dataset

class Concat(Dataset):
    def __init__(self, datasets):
        self.datasets = datasets
        self.lengths = [len(d) for d in datasets]
        self.offsets = np.cumsum(self.lengths)
        self.length = np.sum(self.lengths)

    def __getitem__(self, index):
        # Walk the cumulative offsets to find the dataset that owns this
        # index, then shift the index into that dataset's local range.
        for i, offset in enumerate(self.offsets):
            if index < offset:
                if i > 0:
                    index -= self.offsets[i - 1]
                return self.datasets[i][index]
        raise IndexError(f'{index} exceeds {self.length}')

    def __len__(self):
        return self.length
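A quick usage sketch of the wrapper, with two toy TensorDatasets standing in for real data:

import torch
from torch.utils.data import TensorDataset

a = TensorDataset(torch.arange(0, 5))
b = TensorDataset(torch.arange(5, 8))
combined = Concat([a, b])
print(len(combined))  # 8
print(combined[6])    # (tensor(6),) -- served by the second dataset

For what it's worth, torch.utils.data also provides a ConcatDataset class that implements essentially this pattern.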
Using itertools.chain:
import itertools

loader = itertools.chain(*[MyDataset(f'file{i}') for i in range(1, 4)])
Ah yes, itertools.chain would only do one epoch, so we would be better off with something like:
x = itertools.repeat(itertools.chain.from_iterable([dataset1, dataset2]), times=epochs)
next(next(iter(x)))
# or
for epoch in x:
    for (inputs, targets) in epoch:
        print(inputs)
Not sure that's going to work: itertools.repeat hands back the same chain iterator object every time, and chain iterators are stateful, so once the first epoch exhausts it the remaining epochs see nothing.
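A quick demonstration of that failure mode, with plain lists standing in for the datasets:

import itertools

chained = itertools.chain([1, 2], [3, 4])
for epoch in itertools.repeat(chained, times=2):
    print(list(epoch))
# prints [1, 2, 3, 4], then [] -- repeat yields the
# same, already exhausted, iterator the second time.

It would be simpler to recreate the chain each epoch (or use the wrapper dataset):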
for epoch in range(num_epochs):
    dset = itertools.chain(...)
    dloader = ...  # create a DataLoader from dset
    for batch in dloader:
        ...
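For completeness, a runnable sketch of the wrapper-dataset route; the dataset shapes here are made up for illustration. Since Concat is a map-style Dataset, the DataLoader only needs to be built once and can be re-iterated every epoch:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real datasets
dataset1 = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
dataset2 = TensorDataset(torch.randn(6, 3), torch.randint(0, 2, (6,)))

dloader = DataLoader(Concat([dataset1, dataset2]), batch_size=4, shuffle=True)
for epoch in range(3):
    # nothing to rebuild per epoch: map-style datasets re-iterate cleanly
    for inputs, targets in dloader:
        print(epoch, inputs.shape, targets.shape)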