PyTorch: indexing data batches in a DataLoader

Hi all,
We have a question.
To train our model we use two datasets, dataset_a and dataset_b.
At each step we feed the model a single batch, taken either from dataset_a or from dataset_b.
How can we do this with a DataLoader in PyTorch?
The code looks like this:
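for step in range(10):
    # Alternate datasets: steps 1, 3, 5, ... use dataset_a; steps 0, 2, 4, ... use dataset_b.
    if (step + 1) % 2 == 0:
        # Note: indexing a Dataset directly returns one sample, not a batch.
        batch_data = dataset_a[step]
    else:
        batch_data = dataset_b[step]
    output = model(batch_data)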

Thank you very much.

Hi,

I think the built-in ConcatDataset does exactly what you want.
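One caveat, offered as a hedged sketch: a DataLoader over ConcatDataset([dataset_a, dataset_b]) with shuffle=True draws each batch from the whole concatenated index range, so a single batch can mix samples from both datasets. If every batch should come from only one dataset, one option is to give the DataLoader a custom batch_sampler that groups indices per dataset. PerDatasetBatchSampler below is a hypothetical helper, not a PyTorch class; it relies on ConcatDataset placing dataset k directly after dataset k-1 in its index space.

```python
import random


class PerDatasetBatchSampler:
    """Hypothetical batch sampler for a ConcatDataset (not part of PyTorch).

    Yields lists of indices in the ConcatDataset's index space, where dataset k
    occupies [offsets[k], offsets[k] + lengths[k]). Each batch contains indices
    from a single dataset, and every index appears exactly once per epoch.
    """

    def __init__(self, lengths, batch_size, shuffle=True, seed=None):
        self.lengths = list(lengths)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)
        # Start offset of each dataset inside the concatenated index space.
        self.offsets = [sum(self.lengths[:k]) for k in range(len(self.lengths))]

    def _batches(self):
        batches = []
        for offset, length in zip(self.offsets, self.lengths):
            indices = list(range(offset, offset + length))
            if self.shuffle:
                self.rng.shuffle(indices)  # shuffle within this dataset only
            # Chop this dataset's indices into batches (the last one may be smaller).
            batches.extend(
                indices[start:start + self.batch_size]
                for start in range(0, length, self.batch_size)
            )
        return batches

    def __iter__(self):
        batches = self._batches()
        if self.shuffle:
            self.rng.shuffle(batches)  # interleave whole batches from different datasets
        return iter(batches)

    def __len__(self):
        # Number of batches: ceil(length / batch_size) summed over datasets.
        return sum(-(-length // self.batch_size) for length in self.lengths)
```

Usage, assuming dataset_a and dataset_b are map-style datasets: DataLoader(ConcatDataset([dataset_a, dataset_b]), batch_sampler=PerDatasetBatchSampler([len(dataset_a), len(dataset_b)], batch_size=32)). When batch_sampler is given, batch_size, shuffle, sampler and drop_last must stay at their defaults. Batches from the two datasets are interleaved at random here rather than strictly alternated as in the question's loop.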

Hi,
Thanks.
Our model needs two batches in each forward and backward step:
one batch from dataset_a and one batch from dataset_b.
How can we get two batches when using ConcatDataset?

Sorry, in your original question you said that you need a batch from either dataset, which you can do with ConcatDataset.
If you need a batch from each dataset, you can do that directly in your for loop: for sample_a, sample_b in zip(loader_a, loader_b):
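A minimal sketch of that loop, assuming loader_a and loader_b are any non-empty iterables of batches (for example two DataLoaders). paired_batches, _next_or_restart and the restart_shorter flag are hypothetical names used for illustration. With plain zip() the step ends as soon as the shorter loader runs out; the optional variant instead restarts the shorter loader so every batch of the longer one is used.

```python
def _next_or_restart(loader, iterator):
    """Return (batch, iterator), starting a new pass over loader if it ran out."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)  # a DataLoader with shuffle=True draws a new order on each pass
        return next(iterator), iterator


def paired_batches(loader_a, loader_b, restart_shorter=False):
    """Yield (batch_a, batch_b) pairs, one batch from each loader per step."""
    if not restart_shorter:
        # zip() stops as soon as the shorter loader is exhausted.
        yield from zip(loader_a, loader_b)
        return
    # Hypothetical variant: run for as many steps as the longer loader has
    # batches, restarting the shorter loader whenever it runs out.
    it_a, it_b = iter(loader_a), iter(loader_b)
    for _ in range(max(len(loader_a), len(loader_b))):
        batch_a, it_a = _next_or_restart(loader_a, it_a)
        batch_b, it_b = _next_or_restart(loader_b, it_b)
        yield batch_a, batch_b
```

In a training loop this would read: for batch_a, batch_b in paired_batches(loader_a, loader_b): followed by one forward and backward pass on both batches.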

I am very sorry I did not express my question clearly.
In fact, we have four datasets. In each forward and backward step, the model needs two batches, coming from two of the four datasets. Moreover, the model must go through all four datasets in an epoch, and no batch may be repeated within an epoch.
The pseudocode is as follows:

for step in range(100):
    randomly pick two of the four datasets (e.g. dataset A and dataset B)
    load one batch from each of them (e.g. batch a and batch b)
    # If A and B are picked again later, the new batches a' and b' must differ from a and b;
    # by the end of the epoch, every sample in all four datasets must have been used.
    feed both batches into the model

I guess a simple solution would be to draw one batch from each dataset with zip at each step and then randomly discard two of them.
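A minimal sketch of the zip-and-discard idea, assuming loaders is a list of iterables of batches (for example four DataLoaders). zip_and_keep_two is a hypothetical name. The trade-off is that the discarded batches are loaded but never used, and zip() stops at the shortest loader, so this does not guarantee that every sample is seen in an epoch.

```python
import random


def zip_and_keep_two(loaders, rng=None):
    """Draw one batch from every loader per step, keep two at random, drop the rest."""
    if rng is None:
        rng = random.Random()
    for batches in zip(*loaders):  # stops at the shortest loader
        i, j = rng.sample(range(len(batches)), 2)  # two distinct loader indices
        yield (i, batches[i]), (j, batches[j])
```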
Otherwise, you can write a for loop that counts the steps and manually manage an iterator for each of the 4 datasets, calling next() only on the two that you selected. But I am not sure how you would handle the fact that some datasets will “finish” before others.

How about providing a default when calling next(), then using it to check for the “finish”?
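Combining the two suggestions above (one iterator per dataset, and next() with a default to detect the “finish”), a minimal sketch might look like this. It assumes loaders is a list of iterables of batches, such as four DataLoaders; random_pair_batches and the _EXHAUSTED sentinel are hypothetical names. A sentinel object is used as the default instead of None so that a batch can never be mistaken for the end marker. Each batch is yielded at most once per call, and the epoch ends when fewer than two datasets still have batches, so the leftover batches of at most one dataset go unused.

```python
import random

_EXHAUSTED = object()  # sentinel passed to next() as its default


def random_pair_batches(loaders, rng=None):
    """Yield ((i, batch_i), (j, batch_j)) from two randomly chosen loaders per step."""
    if rng is None:
        rng = random.Random()
    iters = [iter(loader) for loader in loaders]
    # Look one batch ahead so we know which loaders still have batches left.
    pending = [next(it, _EXHAUSTED) for it in iters]
    while True:
        active = [k for k, batch in enumerate(pending) if batch is not _EXHAUSTED]
        if len(active) < 2:
            return  # at most one loader has batches left; end the epoch
        i, j = rng.sample(active, 2)
        yield (i, pending[i]), (j, pending[j])
        # Advance only the two loaders we just consumed from.
        pending[i] = next(iters[i], _EXHAUSTED)
        pending[j] = next(iters[j], _EXHAUSTED)
```

Calling the generator once per epoch creates fresh iterators, so a DataLoader with shuffle=True produces a new batch order each epoch.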

https://docs.python.org/3.6/library/functions.html#next