Dataloader for multi-task Model with unbalanced data

Hello Everyone,

I am trying to train a multi-task model using a separate dataset for each task. I split each dataset into training (80%) and validation (20%) sets. Details of the datasets I have created:

  • For task 1, each image/label pair is packed inside a dictionary with two keys
    Total : 5255
    Train Data : 4204
    Validation Data : 1051

  • For task 2, the data is in the form of a list
    Total : 1108
    Train Data : 886
    Validation Data : 222

I have concatenated the two training sets using

combined_training_dataset = torch.utils.data.ConcatDataset([task_1_train_set, task_2_train_set])

Combined Train Dataset : 5090

To train the model, I want each batch to contain data from only one of the datasets, with consecutive batches alternating between the two datasets.
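
To make the requirement concrete, here is a minimal sketch of the batch order I am after, written with plain index ranges rather than real datasets. The helper name alternating_batches is hypothetical and is not taken from the blog. It assumes the ConcatDataset index layout, where the task 1 indices come first and the task 2 indices follow.

```python
import random


def alternating_batches(size_1, size_2, batch_size, seed=0):
    """Yield lists of indices into a concatenated dataset.

    Indices 0 .. size_1 - 1 are assumed to belong to task 1 and
    size_1 .. size_1 + size_2 - 1 to task 2. Each batch draws from one
    task only, and the tasks take turns until both are exhausted.
    """
    rng = random.Random(seed)
    idx_1 = list(range(size_1))
    idx_2 = list(range(size_1, size_1 + size_2))
    rng.shuffle(idx_1)
    rng.shuffle(idx_2)

    # Cut each shuffled index list into batch-sized chunks.
    chunks_1 = [idx_1[i:i + batch_size] for i in range(0, size_1, batch_size)]
    chunks_2 = [idx_2[i:i + batch_size] for i in range(0, size_2, batch_size)]

    # Interleave the chunks: task 1, task 2, task 1, task 2, ...
    for i in range(max(len(chunks_1), len(chunks_2))):
        if i < len(chunks_1):
            yield chunks_1[i]
        if i < len(chunks_2):
            yield chunks_2[i]


# Sizes taken from the training splits listed above.
batches = list(alternating_batches(4204, 886, batch_size=10))
```

In this sketch nothing is oversampled, so once the smaller task runs out, the remaining batches all come from task 1. As far as I understand, a list like this could be passed to DataLoader through its batch_sampler argument (leaving out batch_size, shuffle and sampler), but I have not verified that against my datasets.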

I have been following this blog to fulfill my requirement. This is how I am creating the loader:

combined_train_loader = torch.utils.data.DataLoader(
    dataset=combined_training_dataset,
    sampler=BalancedBatchSchedulerSampler(dataset=combined_training_dataset,
                                          batch_size=10),
    batch_size=10,
    shuffle=False,
    pin_memory=True)

I am using the code from the unbalanced loader part of the blog, but I am getting this error:

<ipython-input-14-7948015c2ea8> in print_dataloader(loader, num_of_image, avDev)
      6 def print_dataloader(loader, num_of_image = 5, avDev = torch.device("cuda")):
      7     print("\n*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* Training Image Samples *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*")
----> 8     for num_batch, batch in enumerate(loader):
      9         if num_batch < 2:
     10             if batch.__class__.__name__ == "dict" :

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         index = self._next_index()  # may raise StopIteration
    346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_index(self)
    316 
    317     def _next_index(self):
--> 318         return next(self._sampler_iter)  # may raise StopIteration
    319 
    320     def __next__(self):

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
    198     def __iter__(self):
    199         batch = []
--> 200         for idx in self.sampler:
    201             batch.append(idx)
    202             if len(batch) == self.batch_size:

<ipython-input-26-e18eba551d35> in __iter__(self)
     89             else:
     90                 # the second unbalanced dataset is changed
---> 91                 sampler = ExampleImbalancedDatasetSampler(cur_dataset)
     92             samplers_list.append(sampler)
     93             cur_sampler_iterator = sampler.__iter__()

<ipython-input-26-e18eba551d35> in __init__(self, dataset, indices, num_samples, callback_get_label)
     28         label_to_count = {}
     29         for idx in self.indices:
---> 30             label = self._get_label(dataset, idx)
     31             if label in label_to_count:
     32                 label_to_count[label] += 1

<ipython-input-26-e18eba551d35> in _get_label(self, dataset, idx)
     63     """
     64     def _get_label(self, dataset, idx):
---> 65         return dataset.samples[idx].item()
     66 
     67 

AttributeError: 'Subset' object has no attribute 'samples'

Can you suggest a solution or any other approach to achieve what I want?

TIA

You would most likely have to change the call in _get_label here:

class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler):
    """
    ImbalancedDatasetSampler is taken from https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/sampler.py
    In order to be able to show the usage of ImbalancedDatasetSampler in this example I am editing the _get_label
    to fit my datasets
    """
    def _get_label(self, dataset, idx):
        return dataset.samples[idx].item()

I haven’t checked the reference implementation, but based on the function name I assume you are supposed to return the target tensor here.

@ptrblck I am not sure what exactly you mean by target tensor. I have two datasets with different types of elements in them (dict and list). Do I need to check which dataset is being used and return the tensor accordingly?

Based on the blog post you shared, it seems you are dealing with two datasets.
One is balanced and uses the RandomSampler, while the other is imbalanced and calls into ExampleImbalancedDatasetSampler.

Inside the ExampleImbalancedDatasetSampler, dataset.samples is used, which is raising the error: according to the stack trace, the dataset passed in is a Subset, and a Subset has no samples attribute.
I’m not sure what kind of datasets you are using, but you should change the dataset.samples call to whatever fits your dataset.
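
As a rough starting point, here is a minimal sketch of a label lookup that relies only on dataset[idx], which a Subset does support, instead of dataset.samples. The function name get_label, the key name "label", and the (image, label) ordering of the list elements are all assumptions on my side, so adapt them to your data.

```python
def get_label(dataset, idx, label_key="label", label_pos=1):
    """Return a hashable class label for dataset[idx].

    Only dataset[idx] is used, so this works for any indexable dataset
    (a plain list, a Subset, ...) and does not need dataset.samples.
    """
    sample = dataset[idx]
    if isinstance(sample, dict):
        # Assumption: the label is stored under the key "label".
        label = sample[label_key]
    else:
        # Assumption: list/tuple samples are ordered as (image, label).
        label = sample[label_pos]
    # Tensors and NumPy scalars expose .item(); plain Python ints do not.
    return label.item() if hasattr(label, "item") else label


# Tiny stand-ins for the two sample layouts (not real data).
task_1_like = [{"image": "img0", "label": 3}, {"image": "img1", "label": 0}]
task_2_like = [["img0", 1], ["img1", 1]]

print(get_label(task_1_like, 0))
print(get_label(task_2_like, 1))
```

The _get_label method of ExampleImbalancedDatasetSampler could then return get_label(dataset, idx). Note that indexing the dataset loads the full sample, image included, so this can be slow for large datasets, and it only makes sense if each sample has a single class label that can be counted.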