Sampling from a ConcatDataset so that each batch comes from a single dataset

I’m working on multi-task learning. Task A and Task B share an encoder but have different decoders.
So in each iteration I need to know whether the batch comes from dataset A or dataset B.

It seems if I simply do

dataset = ConcatDataset([datasetA, datasetB])

then each batch produced during iteration will contain samples from both A and B. But during training I want each batch to come from either A or B, not a mix.

Any suggestions?


You could use two DataLoaders, create an iterator for each via loader_iter = iter(loader), and in each iteration grab the next batch from whichever loader you want via next(loader_iter).
This approach gives you the flexibility to apply complex conditions for when to use which dataset.
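A minimal sketch of the two-DataLoader approach, assuming hypothetical toy TensorDatasets in place of datasetA/datasetB and a hypothetical helper alternating_batches. The next loader is picked uniformly at random among those not yet exhausted (an assumption; any other condition could go there), and each batch is tagged with its task so the matching decoder can be used:

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy datasets standing in for datasetA and datasetB
datasetA = TensorDataset(torch.arange(10).float().unsqueeze(1))        # values 0..9
datasetB = TensorDataset(torch.arange(100, 106).float().unsqueeze(1))  # values 100..105

loaders = {
    "A": DataLoader(datasetA, batch_size=4, shuffle=True),
    "B": DataLoader(datasetB, batch_size=4, shuffle=True),
}


def alternating_batches(loaders, seed=0):
    """Yield (task, batch) pairs where every batch comes from a single loader.

    The next loader is picked uniformly at random among those not yet
    exhausted, so each dataset is traversed exactly once per call.
    """
    rng = random.Random(seed)
    iterators = {name: iter(loader) for name, loader in loaders.items()}
    while iterators:
        name = rng.choice(sorted(iterators))
        try:
            batch = next(iterators[name])
        except StopIteration:
            del iterators[name]  # this loader is done for the epoch
            continue
        yield name, batch


for task, (x,) in alternating_batches(loaders):
    # here x would go through the shared encoder and the decoder for `task`
    print(task, x.squeeze(1).tolist())
```

Choosing uniformly among the remaining loaders means the smaller dataset tends to run out first; weighting the choice by the number of remaining batches would spread both datasets more evenly over the epoch.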
On the other hand, if you want to switch between both datasets in each iteration, you could create a custom sampler and create the indices for the ConcatDataset as you wish.
In the simplest case you could return the indices as: [dataA_idx0, dataA_idx1, dataA_idx2, ... dataB_idx0, ...], keeping in mind that inside the ConcatDataset, dataset B's indices start at len(datasetA).
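A sketch of such a sampler, assuming ConcatDataset's cumulative_sizes attribute to find where each sub-dataset starts and ends in the global index space. PerDatasetBatchSampler is a hypothetical name: it shuffles the indices within each sub-dataset, chunks them into batches, shuffles the order of the batches, and is passed as batch_sampler so each yielded list becomes one batch:

```python
import random
from typing import Iterator, List

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset


class PerDatasetBatchSampler:
    """Yield lists of ConcatDataset indices; each list stays inside one sub-dataset."""

    def __init__(self, concat_dataset: ConcatDataset, batch_size: int, shuffle: bool = True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # cumulative_sizes[i] is the end offset of sub-dataset i in the concatenated index space
        ends = concat_dataset.cumulative_sizes
        starts = [0] + ends[:-1]
        self.ranges = list(zip(starts, ends))

    def __iter__(self) -> Iterator[List[int]]:
        batches = []
        for start, end in self.ranges:
            indices = list(range(start, end))
            if self.shuffle:
                random.shuffle(indices)
            # chunk this sub-dataset's indices; only its last chunk may be smaller
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)  # interleave A-batches and B-batches randomly
        return iter(batches)

    def __len__(self) -> int:
        return sum((end - start + self.batch_size - 1) // self.batch_size
                   for start, end in self.ranges)


# Hypothetical toy data standing in for datasetA and datasetB
datasetA = TensorDataset(torch.arange(10).float().unsqueeze(1))
datasetB = TensorDataset(torch.arange(100, 106).float().unsqueeze(1))
dataset = ConcatDataset([datasetA, datasetB])

loader = DataLoader(dataset, batch_sampler=PerDatasetBatchSampler(dataset, batch_size=4))
for (x,) in loader:
    print(x.squeeze(1).tolist())  # either all values < 100 (A) or all >= 100 (B)
```

To know which decoder to use, compare any index in the batch against cumulative_sizes, or have the sampler yield the dataset id alongside the indices.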


Hi there! I am trying to accomplish the same thing, but with the sampler method. If I want to create batches of 4 from either dataset_A or dataset_B, and iterate entirely through both datasets, what would that sampler look like for torch.utils.data.ConcatDataset? I'm having trouble understanding the format of what the sampler should return. Any help would be greatly appreciated. Thanks!

Take a look at this post to see how a BatchSampler can be used to pass all indices of the current batch to Dataset.__getitem__, and write a custom implementation that samples from your datasets by reusing the BatchSampler implementation.
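A minimal sketch of that pattern (the linked post itself is not reproduced here): pass a BatchSampler as the sampler argument and set batch_size=None, which disables automatic batching, so __getitem__ receives the whole list of indices at once. ListIndexDataset is a hypothetical toy class:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class ListIndexDataset(Dataset):
    """Toy dataset whose __getitem__ receives a whole batch of indices at once."""

    def __init__(self, data: torch.Tensor):
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, indices):
        # `indices` is the list produced by BatchSampler, e.g. [7, 2, 9, 0];
        # a custom sampler could yield (dataset_idx, data_idx) pairs here instead
        return self.data[indices]


dataset = ListIndexDataset(torch.arange(10))
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)

# batch_size=None turns off automatic batching, so each list yielded by the
# BatchSampler is handed to __getitem__ unchanged
loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None)
for batch in loader:
    print(batch)  # a tensor with up to 4 elements
```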


OK, so I have this code for a combination dataset and batch sampler:

import random
from typing import Iterator, List, Tuple

import torch
from torch.utils.data import DataLoader, Dataset

# combined dataset class: each index is a (dataset_idx, data_idx) pair
class CombinationDataset(Dataset):  # a dataset should subclass Dataset, not DataLoader
    def __init__(self, datasets):
        self.datasets = datasets
    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)
    def __getitem__(self, index):
        dataset_idx, data_idx = index
        print(index)  # debug output: the (dataset_idx, data_idx) pair being fetched
        return self.datasets[dataset_idx][data_idx]

# class that will take in multiple samplers and output batches from a single dataset at a time
class ComboBatchSampler:

    def __init__(self, samplers, batch_size, drop_last):
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.samplers = samplers
        self.batch_size = batch_size
        self.drop_last = drop_last

    def _num_batches(self, sampler) -> int:
        # number of batches a single sampler contributes per epoch
        if self.drop_last:
            return len(sampler) // self.batch_size
        return (len(sampler) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[List[Tuple[int, int]]]:
        # fresh iterators every epoch, so each dataset is traversed exactly once
        # (restarting `for idx in sampler` for every batch would redraw indices from scratch)
        iterators = [iter(sampler) for sampler in self.samplers]

        # define which dataset each batch comes from, in random order
        # (no random.seed() here: reseeding would repeat the same order every epoch)
        dataset_idxs = []
        for dataset_idx, sampler in enumerate(self.samplers):
            dataset_idxs.extend([dataset_idx] * self._num_batches(sampler))
        random.shuffle(dataset_idxs)

        # return the batch indices
        for dataset_idx in dataset_idxs:
            batch = []
            for idx in iterators[dataset_idx]:  # continues where this dataset's last batch stopped
                batch.append((dataset_idx, idx))
                if len(batch) == self.batch_size:
                    break
            yield batch  # only a dataset's final batch can be short, and only if drop_last=False

    def __len__(self) -> int:
        return sum(self._num_batches(sampler) for sampler in self.samplers)

I construct my data like this:

def make_train_loader(train_data_0, train_data_1):  # hypothetical wrapper; the snippet lives inside a larger function
    train_data_combined = CombinationDataset([train_data_0, train_data_1])
    sampler = ComboBatchSampler(
        [torch.utils.data.sampler.RandomSampler(dataset) for dataset in [train_data_0, train_data_1]],
        batch_size=4, drop_last=True)
    return DataLoader(train_data_combined, sampler=sampler)

However, I get this output and error when iterating with for batch in train_loader:

[(1, 2383), (1, 2094), (1, 1501), (1, 3252)]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_1022/1854105238.py in <module>
      1 #Try a couple of epochs to make sure it works (or inturrupt after a small amount of time if it does)
      2 for epoch in range(1):
----> 3     for batch in train_loader:
      4         metrics = testTrial.train_batch(batch, epoch, batch_idx)
      5         print(metrics)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    559     def _next_data(self):
    560         index = self._next_index()  # may raise StopIteration
--> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    562         if self._pin_memory:
    563             data = _utils.pin_memory.pin_memory(data)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/run/determined/workdir/combined.py in __getitem__(self, indicies)
     14         data_idx = indicies[1]
     15         print(indicies)
---> 16         return self.datasets[dataset_idx].__getitem__(data_idx)
     17 
     18 # class that will take in multiple samplers and output batches from a single dataset at a time

TypeError: list indices must be integers or slices, not tuple

It looks like it's calling __getitem__() once with the whole list of tuples as input, instead of once for each entry. How should I solve this?

Never mind, I fixed it! I should have passed the sampler as batch_sampler instead of sampler:

return DataLoader(train_data_combined, batch_sampler=sampler)
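Why this works: with sampler=, DataLoader treats every item the sampler yields as one dataset index and groups those items with its own default BatchSampler (batch_size defaults to 1). So __getitem__ received the entire list of four tuples, and indicies[0] was itself a tuple, hence the TypeError. With batch_sampler=, each yielded list is taken as one batch, __getitem__ is called once per element, and the results are collated. A small sketch showing the difference, using a hypothetical RecordingDataset that logs every index it receives:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RecordingDataset(Dataset):
    """Toy dataset that records every index passed to __getitem__."""

    def __init__(self):
        self.seen = []

    def __len__(self):
        return 8

    def __getitem__(self, index):
        self.seen.append(index)
        return torch.tensor(0)


# a hand-written batch sampler: an iterable of batches, each a list of indices
batches = [[0, 1, 2, 3], [4, 5, 6, 7]]

# batch_sampler=: __getitem__ is called once per index inside each batch
ds_ok = RecordingDataset()
for _ in DataLoader(ds_ok, batch_sampler=batches):
    pass
print(ds_ok.seen)  # [0, 1, 2, 3, 4, 5, 6, 7]

# sampler=: each whole batch is treated as ONE index and wrapped in a default
# batch of size 1, so __getitem__ receives the entire list
ds_wrong = RecordingDataset()
for _ in DataLoader(ds_wrong, sampler=batches):
    pass
print(ds_wrong.seen)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```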