Sample equally from ConcatDataset

Let’s say I have two datasets, such as MNIST and SVHN. I can use ConcatDataset to concatenate them; however, on each iteration I need half of the batch to be MNIST images and the other half SVHN.

At the moment I am able to do this with the hack linked here.

So I have to artificially blow up the size of the dataset. Instead, if I could sample from the ConcatDataset such that each batch is 50% from one dataset and 50% from the other, that would make life a lot simpler. This kind of operation is heavily used in visual domain adaptation problems.

Thanks

Alternatively, you can create separate data loaders for the two datasets.

loader1 = torch.utils.data.DataLoader(...)
loader2 = torch.utils.data.DataLoader(...)

# zip yields one (data, labels) pair from each loader per iteration
for (data1, labels1), (data2, labels2) in zip(loader1, loader2):
    # you can also shuffle the combined batch here if you want
    data = torch.cat([data1, data2], dim=0)

Thank you Kushaj. I am doing the same at the moment.
The problem is that the for loop ends once batch_idx reaches min(len(loader1), len(loader2)).
Say dataset1 has 1000 images and dataset2 has 500. With this approach we iterate over only 500 images from each dataset; the remaining 500 images from dataset1 are never used.
Hence, through the hack linked in my question, I artificially increase the size of dataset2 to 1000 images. I wish to avoid this.
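One workaround I have considered is cycling the shorter loader with itertools.cycle, so that zip runs for the full length of loader1. A rough sketch, reusing the loader names from above (note that cycle caches the batches from its first pass, so loader2 repeats the same shuffle order and those cached batches stay in memory):

import torch
from itertools import cycle

# cycle(loader2) never raises StopIteration, so zip stops only when
# loader1 (the larger dataset) is exhausted -- all 1000 images get used
for (data1, labels1), (data2, labels2) in zip(loader1, cycle(loader2)):
    data = torch.cat([data1, data2], dim=0)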

Use this example as a reference. By creating the iterator outside the for loop, you can restart it manually whenever it is exhausted:

import torch

a = torch.tensor([1., 2., 3.])

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, x):
        self.x = x

    def __len__(self):
        return self.x.size(0)

    def __getitem__(self, idx):
        return self.x[idx]

dataset = MyDataset(a)
loader = torch.utils.data.DataLoader(dataset)

i_loader = iter(loader)
for i in range(10):
    try:
        batch = next(i_loader)
    except StopIteration:
        # the iterator is exhausted: recreate it and fetch again,
        # so no iteration is skipped
        i_loader = iter(loader)
        batch = next(i_loader)
    print(batch)

tensor([1.])
tensor([2.])
tensor([3.])
tensor([1.])
tensor([2.])
tensor([3.])
tensor([1.])
tensor([2.])
tensor([3.])
tensor([1.])
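Applied to your two-loader setup, the same pattern might look like this (a sketch; assuming loader1 wraps the larger dataset and drives the loop, while loader2 is restarted whenever it runs out):

iter2 = iter(loader2)
for data1, labels1 in loader1:
    try:
        data2, labels2 = next(iter2)
    except StopIteration:
        # restart the shorter loader; this re-shuffles if shuffle=True
        iter2 = iter(loader2)
        data2, labels2 = next(iter2)
    data = torch.cat([data1, data2], dim=0)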

Thanks. I was looking for a solution that uses ConcatDataset followed by some sampler such as SubsetRandomSampler. I have a feeling that such a solution would be cleaner and faster.
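Something along those lines would be a custom batch sampler over the ConcatDataset that draws half of each batch from each dataset. A minimal sketch of what I have in mind (HalfHalfBatchSampler is a made-up name, mnist and svhn are placeholder datasets, and it samples indices with replacement rather than guaranteeing exact epoch coverage):

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler

class HalfHalfBatchSampler(Sampler):
    """Yield index batches into a ConcatDataset of two datasets,
    drawing half of each batch from each dataset."""
    def __init__(self, len_a, len_b, batch_size):
        self.len_a, self.len_b = len_a, len_b
        self.half = batch_size // 2

    def __len__(self):
        # one "epoch" covers the larger dataset once (in expectation)
        return max(self.len_a, self.len_b) // self.half

    def __iter__(self):
        for _ in range(len(self)):
            idx_a = torch.randint(0, self.len_a, (self.half,))
            # dataset B's samples sit at offset len_a in the ConcatDataset
            idx_b = torch.randint(0, self.len_b, (self.half,)) + self.len_a
            yield torch.cat([idx_a, idx_b]).tolist()

dataset = ConcatDataset([mnist, svhn])
loader = DataLoader(dataset,
                    batch_sampler=HalfHalfBatchSampler(len(mnist), len(svhn),
                                                       batch_size=64))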

@ptrblck - any suggestions? I was using your hack that I linked in my OP.