If you are looking to use multiple dataloaders at the same time, this should work:
class cat_dataloaders():
    """Class to concatenate multiple dataloaders."""

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders

    def __len__(self):
        # the joint iterator stops at the shortest loader
        return min(len(dl) for dl in self.dataloaders)

    def __iter__(self):
        # fresh iterators for every loader at the start of each pass
        self.loader_iter = []
        for data_loader in self.dataloaders:
            self.loader_iter.append(iter(data_loader))
        return self

    def __next__(self):
        out = []
        for data_iter in self.loader_iter:
            out.append(next(data_iter))  # raises StopIteration when a loader runs out
        return tuple(out)
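Note that iteration ends as soon as the shortest dataloader is exhausted: __next__ lets the StopIteration from the first empty iterator propagate, so one joint pass has min(len(dl)) batches. For this lock-step case the class behaves like Python's built-in zip over the loaders.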
Here is a quick example:
import torch
from torch.utils.data import Dataset, DataLoader

class DEBUG_dataset(Dataset):
    def __init__(self, alpha):
        self.d = (torch.arange(20) + 1) * alpha

    def __len__(self):
        return self.d.shape[0]

    def __getitem__(self, index):
        return self.d[index]

train_dl1 = DataLoader(DEBUG_dataset(10), batch_size=4, num_workers=0, shuffle=True)
train_dl2 = DataLoader(DEBUG_dataset(1), batch_size=4, num_workers=0, shuffle=True)

tmp = cat_dataloaders([train_dl1, train_dl2])
for x in tmp:
    print(x)
The output (batch order varies run to run because of shuffling) is:
(tensor([140, 160, 130, 90]), tensor([ 5, 10, 8, 9]))
(tensor([120, 30, 170, 70]), tensor([15, 17, 18, 7]))
(tensor([180, 50, 190, 80]), tensor([ 6, 14, 3, 2]))
(tensor([ 10, 40, 150, 100]), tensor([11, 13, 4, 1]))
(tensor([ 60, 200, 110, 20]), tensor([19, 12, 20, 16]))
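If your loaders have different lengths and you would rather keep drawing batches until the longest one is done, a small variant that restarts exhausted iterators also works. This is only a sketch, not part of the original answer: the names cycle_dataloaders and batches_served are made up for illustration, and note that calling iter() on a DataLoader again reshuffles when shuffle=True.

class cycle_dataloaders():
    """Like cat_dataloaders, but restarts shorter loaders so one pass
    lasts as long as the longest dataloader."""

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders

    def __len__(self):
        # one pass is governed by the longest loader
        return max(len(dl) for dl in self.dataloaders)

    def __iter__(self):
        self.loader_iter = [iter(dl) for dl in self.dataloaders]
        self.batches_served = 0
        return self

    def __next__(self):
        if self.batches_served >= len(self):
            raise StopIteration  # longest loader fully consumed
        out = []
        for i, data_iter in enumerate(self.loader_iter):
            try:
                out.append(next(data_iter))
            except StopIteration:
                # this loader ran out early: restart it and take its first batch
                self.loader_iter[i] = iter(self.dataloaders[i])
                out.append(next(self.loader_iter[i]))
        self.batches_served += 1
        return tuple(out)

Dropped in place of cat_dataloaders above, the same loop prints five tuples here too (both loaders happen to have five batches), but with loaders of unequal length it keeps going until the longest one is exhausted.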