I need to iterate simultaneously over multiple datasets (say, 2 datasets), keeping the elements of each of them isolated: each batch must contain only elements of one dataset, and at each step I want to work with one batch from each dataset. I think torch.utils.data.TensorDataset can be the right tool for this, for example:
```python
import torch
from torch.utils.data import DataLoader

dataset = torch.utils.data.TensorDataset(dataset1, dataset2)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
for index, (xb1, xb2) in enumerate(dataloader):
    ...
```
where xb1 refers to the input data and target associated with one of the 2 datasets.
My first question is: have I understood the use of torch.utils.data.TensorDataset correctly? Does this approach solve my problem?
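To make the question concrete, here is a runnable version of the snippet above with dummy tensors standing in for my two datasets (the shapes are made up for illustration):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Dummy stand-ins for the two datasets (made-up shapes, same length)
dataset1 = torch.randn(1000, 10)
dataset2 = torch.randn(1000, 5)

dataset = TensorDataset(dataset1, dataset2)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

for index, (xb1, xb2) in enumerate(dataloader):
    # xb1 contains only rows of dataset1, xb2 only rows of dataset2,
    # but both batches are drawn at the same (shuffled) indices
    print(index, xb1.shape, xb2.shape)
    break
```

Note that this requires the two tensors to have the same length, and each step yields both tensors indexed at the same positions.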
My second question is: how do I pass a sampler to the DataLoader in this situation? Can I, for example, define 2 index tensors Idx1 and Idx2 and give DataLoader an option like sampler=(Idx1, Idx2)?
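For reference, this is how I would use a single index tensor with one DataLoader (via torch.utils.data.SubsetRandomSampler); what I don't know is how to combine two of them in one loader:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

# A single dataset with made-up shapes, just to show the sampler usage
data = torch.randn(100, 4)
targets = torch.randn(100, 1)
ds = TensorDataset(data, targets)

# Draw only from the first 50 indices, in random order
Idx1 = torch.arange(0, 50)
loader = DataLoader(ds, batch_size=16, sampler=SubsetRandomSampler(Idx1))

for xb, yb in loader:
    ...
```

(With a sampler, shuffle=True must not be passed, since the two options are mutually exclusive.)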
EDIT:
An alternative approach could be to create a DataLoader for each dataset, each one with its own sampler, and use zip() to iterate simultaneously over the 2 datasets. Is there a cleaner solution than that? (Also because I read (source) that cycle() and zip() might create a memory leak, especially when using image datasets!)
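For completeness, the zip() alternative I have in mind looks like this (a sketch with dummy tensors; Idx1 and Idx2 are my own index tensors for the two datasets):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

# Two independent datasets of different lengths (made-up shapes)
ds1 = TensorDataset(torch.randn(200, 10), torch.randn(200, 1))
ds2 = TensorDataset(torch.randn(300, 10), torch.randn(300, 1))

Idx1 = torch.arange(200)
Idx2 = torch.arange(300)

loader1 = DataLoader(ds1, batch_size=32, sampler=SubsetRandomSampler(Idx1))
loader2 = DataLoader(ds2, batch_size=32, sampler=SubsetRandomSampler(Idx2))

# Each step yields one batch from each dataset, fully isolated;
# zip() stops when the shorter loader is exhausted
for (xb1, yb1), (xb2, yb2) in zip(loader1, loader2):
    ...
```

This keeps the batches isolated per dataset, but iteration ends with the shorter loader, which is part of why I am asking whether a cleaner solution exists.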