Work with different datasets in parallel

Hi everyone,
I have the following problem:
I have 2 different datasets of (image, target) pairs; in principle the 2 datasets may have different numbers of samples. I need to:

  • Keep the elements of the 2 datasets separate, i.e. any given batch must contain elements from only one dataset

  • Iterate over both datasets at the same time, i.e. at each step select one batch from each dataset, perform some operations, and then move on to the next pair of batches

The solution I have found so far is the following:
use torch.utils.data.TensorDataset to wrap both datasets and torch.utils.data.DataLoader to iterate through them

        import torch

        # Collect every sample of each dataset into lists
        x0, y0 = [], []
        for x, y in train_data0:
            x0.append(x)
            y0.append(y)
        x1, y1 = [], []
        for x, y in train_data1:
            x1.append(x)
            y1.append(y)

        # Wrap all four tensors in a single dataset
        totdataset = torch.utils.data.TensorDataset(torch.stack(x0), torch.tensor(y0), torch.stack(x1), torch.tensor(y1))
        dataloader = torch.utils.data.DataLoader(totdataset, batch_size=10, shuffle=True, num_workers=1)
        for index, (a, b, c, d) in enumerate(dataloader):
            pass  # here some operations on the 2 datasets' batches

The problem is that if train_data0 and train_data1 have different numbers of elements, then, since the batch size is the same for each tensor (10 in the example above), I’ll get a different number of batches for the 2 datasets, which ends in a size mismatch between the tensors entering the for loop (last line).
How can I solve this problem?
Is there a way to fix the number of batches, instead of the batch size, for the dataloader's input tensors?
Is there a different (easy) way to solve my initial problem?

You will soon realise that the available dataset classes only cover traditional CV problems.
Most of the time you need to write your own dataset by subclassing torch.utils.data.Dataset.
In fact, you can write a dataset that contains two other datasets, like:

import torch
from torch.utils.data import Dataset, TensorDataset

class MyDataset(Dataset):
    def __init__(self, tensors0, tensors1):
        super().__init__()
        # Each argument is a tuple of tensors, e.g. (images, targets), for one dataset
        self.dataset0 = TensorDataset(*tensors0)
        self.dataset1 = TensorDataset(*tensors1)

    def __len__(self):
        # Be creative here. One option: the largest of both
        return max(len(self.dataset0), len(self.dataset1))

    def __getitem__(self, idx):
        # One option: wrap the index so the shorter dataset repeats.
        # But here you can randomly sample both, or only the shorter; just code anything you need.
        return self.dataset0[idx % len(self.dataset0)], self.dataset1[idx % len(self.dataset1)]

Basically, you can code anything you want using any Python code or libraries. You can randomly sample both subsets, or you can make the code deterministic and then random for the shorter one… everything is up to you.
Inside __init__ you define any tools you need, and in __getitem__ you carry out the workload (opening files and so on).
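As a rough illustration of the "deterministic for one, random for the shorter" idea, here is a minimal sketch. The class name PairedDataset and the dummy tensors are made up for the example, and it assumes the first dataset is the longer one. Note that every batch then holds the same number of samples from each dataset, and the shorter dataset is not guaranteed to be fully covered in one epoch.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class PairedDataset(Dataset):
    """Hypothetical wrapper: pairs sample idx of the longer dataset
    with a randomly chosen sample of the shorter one."""

    def __init__(self, long_ds, short_ds):
        assert len(long_ds) >= len(short_ds)
        self.long_ds = long_ds
        self.short_ds = short_ds

    def __len__(self):
        # One epoch walks through the longer dataset exactly once.
        return len(self.long_ds)

    def __getitem__(self, idx):
        # Deterministic index for the longer dataset, random index for the shorter.
        j = random.randrange(len(self.short_ds))
        return self.long_ds[idx], self.short_ds[j]


# Dummy data: 12 samples in the first dataset, 5 in the second.
ds0 = TensorDataset(torch.randn(12, 3), torch.arange(12))
ds1 = TensorDataset(torch.randn(5, 3), torch.arange(5))
loader = DataLoader(PairedDataset(ds0, ds1), batch_size=4, shuffle=True)

seen = []
for (x0, y0), (x1, y1) in loader:
    # x0/y0 come only from ds0, x1/y1 only from ds1.
    seen.append((tuple(x0.shape), tuple(x1.shape)))
```

The two datasets never mix inside a tensor: the default collation stacks the ds0 samples and the ds1 samples separately.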


Hi Juan, thank you for your reply. I didn’t understand how to proceed in practice to solve my specific problem:

Given the tensors of the 2 datasets as input, during initialization I can create one single dataset in which the 2 “subdatasets” are kept separate: this step is fine, I get it.

Now let’s say that I have 800 samples in the first “subdataset” and 400 in the second one.
How should I write the __getitem__ method so that each call from the dataloader returns B images (and the corresponding targets) from the first “subdataset” and B/2 from the second one (so that I go through all the elements of the 2 “subdatasets” in the same number of batches)?

Here, __getitem__ loads a single sample.
When you iterate over the dataloader, it loads the N samples you asked for by calling __getitem__ N times and stacking the resulting tensors in a smart way. This is what they call automatic batching.
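To see automatic batching in action, here is a small sketch with a toy dataset that logs every index passed to __getitem__. CountingDataset is a made-up name for the example, not something from PyTorch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CountingDataset(Dataset):
    """Toy dataset that records which indices __getitem__ is asked for."""

    def __init__(self, n):
        self.data = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # shape (n, 1)
        self.calls = []

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        self.calls.append(idx)
        return self.data[idx]  # a single sample of shape (1,)


dataset = CountingDataset(10)
# num_workers=0 keeps loading in the main process, so the call log is visible here.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)   # [(4, 1), (4, 1), (2, 1)]
print(dataset.calls)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Each batch is built from one __getitem__ call per sample, and the last batch is smaller because 10 is not a multiple of 4.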

You can disable that by passing batch_size=None to the dataloader. That is, __getitem__ will no longer load a sample but a whole batch (of course, you have to code the stacking and so on yourself). The dataloader will then just convert NumPy arrays into PyTorch tensors, but that’s all. More info here: torch.utils.data — PyTorch 1.10.0 documentation
Scroll down until you see “Loading Batched and Non-Batched Data” and “Disable automatic batching”.
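For the 800 / 400 case from the question, a minimal sketch with automatic batching disabled could look like this. TwoSourceBatches and its num_batches argument are hypothetical names, and the data are random placeholders. The idea is to fix the number of batches and let each dataset get its own batch size.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TwoSourceBatches(Dataset):
    """Hypothetical helper: index i returns the i-th batch of each source."""

    def __init__(self, x0, y0, x1, y1, num_batches):
        self.sources = [(x0, y0), (x1, y1)]
        self.num_batches = num_batches
        # Per-source batch size, rounded up so that every sample is covered.
        # If a length is not a multiple of num_batches, the last batches
        # of that source can be smaller (or even empty).
        self.sizes = [-(-len(x) // num_batches) for x, _ in self.sources]

    def __len__(self):
        return self.num_batches

    def __getitem__(self, i):
        out = []
        for (x, y), b in zip(self.sources, self.sizes):
            out += [x[i * b:(i + 1) * b], y[i * b:(i + 1) * b]]
        return tuple(out)


# Dummy data with the sizes from the question: 800 and 400 samples.
x0, y0 = torch.randn(800, 3, 8, 8), torch.randint(0, 10, (800,))
x1, y1 = torch.randn(400, 3, 8, 8), torch.randint(0, 10, (400,))

dataset = TwoSourceBatches(x0, y0, x1, y1, num_batches=50)
# batch_size=None disables automatic batching, so shuffle=True only
# shuffles the order of the batches, not the samples inside them.
loader = DataLoader(dataset, batch_size=None, shuffle=True)

shapes = [(a.shape[0], c.shape[0]) for a, b, c, d in loader]
```

With 50 batches, every step yields 16 samples from the first dataset and 8 from the second. If you also want the samples inside each batch to change between epochs, one option is to permute each source yourself (e.g. with torch.randperm) before slicing.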