Two DataLoaders from two different datasets within the same loop

So I am trying to have two data loaders emit a batch of data each within the training loop. Like so:


data_loader1 = torch.utils.data.DataLoader(train_set1, batch_size=run.batch, shuffle=run.shuf, drop_last=True)

data_loader2 = torch.utils.data.DataLoader(train_set2, batch_size=run.batch, shuffle=run.shuf, drop_last=True)

for image_batch, labels in data_loader1:
    image_batch2, labels2 = next(iter(data_loader2))
    # code within training loop

This works right up until the second data loader runs out of images; apparently next(iter()) will not go back to the beginning of the dataset.
This post explains: https://stackoverflow.com/questions/48180209/stop-iteration-error-when-using-next

So the problem is to have two data loaders each emit a batch within the same loop, but without using next(iter()) or nested for loops (for computational reasons).

Any ideas???

You may have already thought of this (assuming both datasets are the same size):

for item1, item2 in zip(dataloader1, dataloader2):
    image_batch1, labels1 = item1
    image_batch2, labels2 = item2
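One caveat with zip: it stops as soon as the shorter iterable is exhausted, so if the datasets differ in size, the remaining batches of the longer loader are silently skipped for that epoch. A quick pure-Python check, with lists of hypothetical batches standing in for the two DataLoaders:

```python
# Hypothetical "batches": plain lists stand in for two DataLoaders of different lengths.
loader_a = [[0, 1], [2, 3], [4, 5]]                          # 3 batches
loader_b = [[10, 11], [12, 13], [14, 15], [16, 17], [18]]    # 5 batches

pairs = list(zip(loader_a, loader_b))
print(len(pairs))  # 3: zip stops when loader_a runs out
print(pairs[-1])   # ([4, 5], [14, 15]); loader_b's last two batches are never used
```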

The nature of your question suggests that the size of the two datasets is different. Am I right?

Ah yes, I forgot to mention that, thanks. The two datasets are indeed different sizes.

Correct me if I've misunderstood: you want the inner dataloader to go back to the beginning once next() has exhausted it. In that case, have you tried itertools.cycle()?


I hadn't, but having looked into it, I think that will work. Thanks, I'll confirm once I have it working.

I believe this will work. But yes, please let me know if it doesn’t:

from itertools import cycle
for item1, item2 in zip(dataloader1, cycle(dataloader2)):
    image_batch1, labels1 = item1
    image_batch2, labels2 = item2

This will work, but it won't reshuffle the dataloader2 samples when it starts over: itertools.cycle saves the batches from the first pass and replays those saved copies.

import torch
from torch.utils.data import Dataset, DataLoader
from itertools import cycle

dataset1 = torch.tensor([0, 1, 2, 3, 4, 5])
dataset2 = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17,23, 123, 25, 34, 56, 78, 89])

dataloader1 = DataLoader(dataset1, batch_size=2, shuffle=True, num_workers=3)
dataloader2 = DataLoader(dataset2, batch_size=2, shuffle=True, num_workers=3)

for data1, data2 in zip(cycle(dataloader1), dataloader2):
    print(data1, data2)

# tensor([1, 3]) tensor([17, 11])
# tensor([4, 0]) tensor([34, 89])
# tensor([5, 2]) tensor([13, 14])
# tensor([1, 3]) tensor([10, 25])
# tensor([4, 0]) tensor([ 12, 123])
# tensor([5, 2]) tensor([56, 78])
# tensor([1, 3]) tensor([16, 15])
# tensor([4, 0]) tensor([23])
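In the output above, dataloader1's batches repeat in exactly the same order ([1, 3], [4, 0], [5, 2]) on every pass. The pure-Python sketch below isolates that behaviour without torch. `ReshufflingBatches` is a hypothetical stand-in for a DataLoader with shuffle=True (a new order every time `iter()` is called); `itertools.cycle` calls `iter()` only once, stores every batch, and replays the stored copies.

```python
import random
from itertools import cycle, islice


class ReshufflingBatches:
    """Hypothetical stand-in for DataLoader(..., shuffle=True): every call to
    iter() yields the items in a freshly shuffled order, in fixed-size batches."""

    def __init__(self, items, batch_size):
        self.items = list(items)
        self.batch_size = batch_size

    def __iter__(self):
        order = self.items[:]
        random.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            yield order[start:start + self.batch_size]

    def __len__(self):
        # Number of batches per pass (the last batch may be smaller).
        return (len(self.items) + self.batch_size - 1) // self.batch_size


loader = ReshufflingBatches(range(6), batch_size=2)
n = len(loader)  # 3 batches per pass

# cycle() iterates the loader once, saves each batch, then replays the saved copies,
# so passes 2 and 3 repeat pass 1 exactly instead of being reshuffled.
cycled = list(islice(cycle(loader), 3 * n))
first_pass = cycled[:n]
print(cycled[n:2 * n] == first_pass)  # True
print(cycled[2 * n:] == first_pass)   # True
```

A side effect of this caching is that cycle keeps every batch from the first pass in memory, which can be significant for image batches.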

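It's better to call iter(dataloader) again when the shorter dataloader runs out of data: each fresh iterator reshuffles the data (with shuffle=True) instead of replaying the first pass.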

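iterations = 8
iterloader1 = iter(dataloader1)  # shorter loader: 3 batches per pass
iterloader2 = iter(dataloader2)  # 8 batches, enough for all 8 iterations

for i in range(iterations):
    data2 = next(iterloader2)
    try:
        data1 = next(iterloader1)
    except StopIteration:
        # dataloader1 is exhausted: a fresh iterator starts a new, reshuffled pass
        iterloader1 = iter(dataloader1)
        data1 = next(iterloader1)

    print(data1, data2)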
# tensor([2, 1]) tensor([34, 12])
# tensor([0, 5]) tensor([15, 16])
# tensor([3, 4]) tensor([14, 56])
# tensor([0, 2]) tensor([25, 13])
# tensor([5, 1]) tensor([23, 11])
# tensor([3, 4]) tensor([89, 78])
# tensor([2, 5]) tensor([ 17, 123])
# tensor([3, 0]) tensor([10])

See this issue for more details on iterating over a DataLoader.
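The restart logic can also be wrapped in a small generator so the training loop stays a single for loop. This is a minimal sketch: `infinite_batches` is a hypothetical helper name, and plain lists stand in for `data_loader1`/`data_loader2` so it runs without torch. With a real DataLoader, every new pass calls `iter()` on the loader again, so shuffle=True should give a new order each pass, and unlike itertools.cycle nothing is cached.

```python
def infinite_batches(loader):
    """Yield batches from `loader` forever (hypothetical helper).

    Each pass starts a fresh iteration over `loader`, so a DataLoader with
    shuffle=True reshuffles on every pass; no batches are stored.
    """
    while True:
        yielded = False
        for batch in loader:  # implicitly calls iter(loader) for a new pass
            yielded = True
            yield batch
        if not yielded:  # an empty loader would otherwise loop forever
            return


# Lists stand in for DataLoaders here so the sketch runs without torch.
short_loader = [[0, 1], [2, 3], [4, 5]]
long_loader = [[10, 11], [12, 13], [14, 15], [16, 17], [18]]

short_iter = infinite_batches(short_loader)
pairs = []
for long_batch in long_loader:
    short_batch = next(short_iter)
    pairs.append((short_batch, long_batch))
    print(short_batch, long_batch)
```

In the original loop this would mean creating `loader2_batches = infinite_batches(data_loader2)` once before the loop and calling `image_batch2, labels2 = next(loader2_batches)` inside it.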


import torch
from torch.utils.data import DataLoader

# The __main__ guard is needed with num_workers > 0 on platforms that start
# worker processes with "spawn" (e.g. Windows and macOS).
if __name__ == '__main__':

    dataset1 = torch.tensor([0, 1, 2, 3, 4, 5])
    dataset2 = torch.tensor([10, 11, 12, 13, 14, 15, 16, 17, 23, 123, 25, 34, 56, 78, 89, 555, 556, 557, 558, 559, 560, 561, 562, 563])

    dataloader1 = DataLoader(dataset1, batch_size=2, shuffle=True, num_workers=3)
    dataloader2 = DataLoader(dataset2, batch_size=2, shuffle=True, num_workers=3)

    dataloader_iterator = iter(dataloader1)
    for i, batch in enumerate(dataloader2):

        try:
            batch2 = next(dataloader_iterator)
        except StopIteration:
            # dataloader1 is exhausted: start a new (reshuffled) pass over it
            dataloader_iterator = iter(dataloader1)
            batch2 = next(dataloader_iterator)

        print(i, batch, batch2)
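If you don't want to hard-code which loader is the shorter one, the same try/except restart can be applied to both loaders, running for as many steps as the longer one has batches. A minimal sketch with hypothetical helper names (`paired_epoch`, `_next_restarting`); it assumes `len()` is defined for both loaders, which holds for a DataLoader over a map-style dataset such as the tensors above. Lists stand in for the loaders so it runs without torch.

```python
def _next_restarting(loader, it):
    """Return (batch, iterator), re-creating the iterator if it is exhausted."""
    try:
        return next(it), it
    except StopIteration:
        it = iter(loader)  # fresh pass (reshuffled for a DataLoader with shuffle=True)
        return next(it), it


def paired_epoch(loader_a, loader_b):
    """Yield (batch_a, batch_b) for max(len(loader_a), len(loader_b)) steps,
    restarting whichever loader runs out first (hypothetical helper)."""
    if len(loader_a) == 0 or len(loader_b) == 0:
        return
    steps = max(len(loader_a), len(loader_b))
    it_a, it_b = iter(loader_a), iter(loader_b)
    for _ in range(steps):
        batch_a, it_a = _next_restarting(loader_a, it_a)
        batch_b, it_b = _next_restarting(loader_b, it_b)
        yield batch_a, batch_b


# Lists stand in for DataLoaders; the argument order does not matter.
loader_a = [[0, 1], [2, 3], [4, 5]]
loader_b = [[10, 11], [12, 13], [14, 15], [16, 17], [18]]

for batch_a, batch_b in paired_epoch(loader_a, loader_b):
    print(batch_a, batch_b)
```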