Train simultaneously on two datasets

If you are looking to use multiple dataloaders at the same time, something like this should work:


class cat_dataloaders():
    """Class to concatenate multiple dataloaders"""

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        # create a fresh iterator for every wrapped dataloader
        self.loader_iter = []
        for data_loader in self.dataloaders:
            self.loader_iter.append(iter(data_loader))
        return self

    def __next__(self):
        # draw one batch from each dataloader per step
        out = []
        for data_iter in self.loader_iter:
            out.append(next(data_iter)) # may raise StopIteration
        return tuple(out)
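Note that iteration stops as soon as the shortest dataloader is exhausted, because the first StopIteration it raises ends the loop. If you would rather keep going until the longest dataloader is finished, one option is to restart any iterator that runs out. This is just a sketch of that idea; the class name cat_dataloaders_cycle and the restart logic are my own additions, not from the original post:

class cat_dataloaders_cycle():
    """Like cat_dataloaders, but restarts shorter dataloaders so one pass
    covers as many steps as the longest dataloader."""

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        self.loader_iter = [iter(dl) for dl in self.dataloaders]
        # one full pass = number of batches in the longest dataloader
        self.remaining = max(len(dl) for dl in self.dataloaders)
        return self

    def __next__(self):
        if self.remaining == 0:
            raise StopIteration
        self.remaining -= 1
        out = []
        for i, data_iter in enumerate(self.loader_iter):
            try:
                out.append(next(data_iter))
            except StopIteration:
                # this dataloader is exhausted: restart it and draw again
                self.loader_iter[i] = iter(self.dataloaders[i])
                out.append(next(self.loader_iter[i]))
        return tuple(out)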

Here is a quick example:

import torch
from torch.utils.data import Dataset, DataLoader

class DEBUG_dataset(Dataset):
    """Toy dataset: 20 scaled integers, just to show the batching."""
    def __init__(self, alpha):
        self.d = (torch.arange(20) + 1) * alpha
    def __len__(self):
        return self.d.shape[0]
    def __getitem__(self, index):
        return self.d[index]

train_dl1 = DataLoader(DEBUG_dataset(10), batch_size=4, num_workers=0, shuffle=True)
train_dl2 = DataLoader(DEBUG_dataset(1), batch_size=4, num_workers=0, shuffle=True)

tmp = cat_dataloaders([train_dl1, train_dl2])
for x in tmp:
    print(x)

The output is:

(tensor([140, 160, 130,  90]), tensor([ 5, 10,  8,  9]))
(tensor([120,  30, 170,  70]), tensor([15, 17, 18,  7]))
(tensor([180,  50, 190,  80]), tensor([ 6, 14,  3,  2]))
(tensor([ 10,  40, 150, 100]), tensor([11, 13,  4,  1]))
(tensor([ 60, 200, 110,  20]), tensor([19, 12, 20, 16]))
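To actually train with it, unpack the tuple returned at each step into one batch per dataloader. Below is a minimal sketch that runs on the toy loaders above; the linear model and MSE loss are placeholders of mine, only there to show how both batches are consumed in a single optimization step:

import torch.nn as nn

# placeholder model and optimizer, chosen so the snippet runs on the toy data
model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(2):
    for batch1, batch2 in cat_dataloaders([train_dl1, train_dl2]):
        # purely for illustration: use the batch from the second loader
        # as the "target" for the batch from the first loader
        pred = model(batch1.float().unsqueeze(1))
        loss = nn.functional.mse_loss(pred, batch2.float().unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()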