Hello Everyone,
I am trying to train a multi-task model with a separate dataset for each task. I split each task's training data into train (80%) and validation (20%) sets. Details of the datasets I have created:
Task 1 (each image/label pair is packed inside a dictionary with two keys):
- Total: 5255
- Train: 4204
- Validation: 1051

Task 2 (data is in the form of a list):
- Total: 1108
- Train: 886
- Validation: 222
I have concatenated the two training sets using:

    combined_training_dataset = torch.utils.data.ConcatDataset([task_1_train_set, task_2_train_set])

Combined train dataset size: 5090 (4204 + 886).
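Just to confirm my understanding of what ConcatDataset does: it concatenates the index spaces, so indices 0-4203 should come from task 1 and 4204-5089 from task 2. A toy sketch of that (stand-in tensor datasets, not my real data):

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset

# Toy stand-ins for the two task datasets, same sizes as my real splits
task_1_train_set = TensorDataset(torch.zeros(4204, 1))
task_2_train_set = TensorDataset(torch.ones(886, 1))

combined = ConcatDataset([task_1_train_set, task_2_train_set])
print(len(combined))             # 5090
print(combined[0][0].item())     # 0.0 -> comes from task 1
print(combined[4204][0].item())  # 1.0 -> first element of task 2
```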
To train the model, I want each batch to contain samples from only one of the datasets, with consecutive batches alternating between the two datasets.
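In other words, something like this rough batch-sampler sketch is what I am after (my own attempt with toy sizes, not the blog's code, so the names here are made up):

```python
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader, Sampler

class AlternatingBatchSampler(Sampler):
    """Yields batches of indices, each batch drawn entirely from one
    sub-dataset of a ConcatDataset, alternating between the sub-datasets."""
    def __init__(self, concat_dataset, batch_size):
        self.batch_size = batch_size
        # Shuffled per-dataset index lists, offset into the concatenated index space
        self.index_lists = []
        start = 0
        for d in concat_dataset.datasets:
            self.index_lists.append((torch.randperm(len(d)) + start).tolist())
            start += len(d)

    def __iter__(self):
        # Chop each dataset's indices into batches, then interleave the batches
        batched = [
            [idx[i:i + self.batch_size] for i in range(0, len(idx), self.batch_size)]
            for idx in self.index_lists
        ]
        for round_ in zip(*batched):  # stops when the smaller dataset runs out
            for batch in round_:
                yield batch

    def __len__(self):
        return len(self.index_lists) * min(
            len(idx) // self.batch_size for idx in self.index_lists)

# Tiny demo: 40 "task 1" items (zeros) and 20 "task 2" items (ones)
task_1 = TensorDataset(torch.zeros(40, 1))
task_2 = TensorDataset(torch.ones(20, 1))
combined = ConcatDataset([task_1, task_2])
loader = DataLoader(combined,
                    batch_sampler=AlternatingBatchSampler(combined, batch_size=10))
# Each batch is homogeneous and the source dataset alternates
labels = [batch[0][0].item() for batch in loader]
print(labels)  # [0.0, 1.0, 0.0, 1.0]
```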
I have been following this blog to fulfil my requirement. This is how I am creating the loader:

    combined_train_loader = torch.utils.data.DataLoader(
        dataset=combined_training_dataset,
        sampler=BalancedBatchSchedulerSampler(dataset=combined_training_dataset,
                                              batch_size=10),
        batch_size=10,
        shuffle=False,
        pin_memory=True)
I am using the code from the imbalanced-dataset part of the blog, but I am getting this error:
<ipython-input-14-7948015c2ea8> in print_dataloader(loader, num_of_image, avDev)
6 def print_dataloader(loader, num_of_image = 5, avDev = torch.device("cuda")):
7 print("\n*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* Training Image Samples *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*")
----> 8 for num_batch, batch in enumerate(loader):
9 if num_batch < 2:
10 if batch.__class__.__name__ == "dict" :
~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 index = self._next_index() # may raise StopIteration
346 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
347 if self._pin_memory:
~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_index(self)
316
317 def _next_index(self):
--> 318 return next(self._sampler_iter) # may raise StopIteration
319
320 def __next__(self):
~/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
198 def __iter__(self):
199 batch = []
--> 200 for idx in self.sampler:
201 batch.append(idx)
202 if len(batch) == self.batch_size:
<ipython-input-26-e18eba551d35> in __iter__(self)
89 else:
90 # the second unbalanced dataset is changed
---> 91 sampler = ExampleImbalancedDatasetSampler(cur_dataset)
92 samplers_list.append(sampler)
93 cur_sampler_iterator = sampler.__iter__()
<ipython-input-26-e18eba551d35> in __init__(self, dataset, indices, num_samples, callback_get_label)
28 label_to_count = {}
29 for idx in self.indices:
---> 30 label = self._get_label(dataset, idx)
31 if label in label_to_count:
32 label_to_count[label] += 1
<ipython-input-26-e18eba551d35> in _get_label(self, dataset, idx)
63 """
64 def _get_label(self, dataset, idx):
---> 65 return dataset.samples[idx].item()
66
67
AttributeError: 'Subset' object has no attribute 'samples'
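My guess is that `_get_label` assumes the dataset object exposes a `samples` attribute, but my 80/20 split produces `Subset` wrappers, which only store the wrapped dataset in `.dataset` and the chosen indices in `.indices`. One fix I am considering is to unwrap the `Subset` before looking up the label (a sketch, assuming the underlying dataset really does have a `samples` container; `Toy` below is just a stand-in for my real dataset):

```python
import torch
from torch.utils.data import Subset

def get_label(dataset, idx):
    """Translate a Subset index back to the wrapped dataset before
    reading its label container (assumed here to be called `samples`)."""
    if isinstance(dataset, Subset):
        return dataset.dataset.samples[dataset.indices[idx]].item()
    return dataset.samples[idx].item()

# Toy dataset with a `samples` tensor of labels, wrapped in a Subset
class Toy:
    def __init__(self):
        self.samples = torch.tensor([10, 11, 12, 13])
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, i):
        return self.samples[i]

sub = Subset(Toy(), indices=[2, 0])
print(get_label(sub, 0))  # 12 -> label of original index 2
```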
Can you suggest a solution or any other approach to achieve what I want?
TIA