Does ConcatDataset preserve class labels and indices?

Let’s say I have 2 image folder datasets and I want to concatenate them.

The first dataset has 100 images with 2 equal classes: “Dog” and “Cat” with class indices 0 and 1
The second dataset has 120 images with 3 equal classes: “Dog”, “Cat” and “Pig” with class indices 0, 1 and 2

When I concatenate the two datasets with torch.utils.data.ConcatDataset(), will I get a dataset with 90 dog, 90 cat, and 40 pig images? Or does it treat the dog and cat labels from the two datasets as different classes, so I actually end up with 5 classes? I didn’t find an easy way to check, because the ConcatDataset class doesn’t have a classes or class_to_idx attribute.

And what about the case where a third dataset has only the classes “dog” and “pig”, and I concatenate it with the first and second? Does ConcatDataset() figure out that “pig” should have class index 2 instead of 1?

ConcatDataset will not create a mapping, but just index the passed Datasets.
Each Dataset should make sure to yield the “right” labels.
E.g. in your second use case, you should make sure that dataset1 only yields samples with the class labels 0 and 1 (dog and cat), while dataset2 should only yield 0 and 2 (dog and pig).

The mapping is defined by your Dataset implementation.

Note that if you are using e.g. ImageFolder, the mapping will be created based on the folders, so I would not recommend using this approach if your dataset folders do not contain the same classes.
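For illustration, this is roughly how ImageFolder would build the mappings from hypothetical folder layouts (note that ImageFolder sorts the class folders alphabetically):

from torchvision import datasets

# hypothetical layouts:
# ./data1 contains the subfolders "cat" and "dog"
# ./data2 contains the subfolders "cat", "dog" and "pig"
ds1 = datasets.ImageFolder("./data1")
ds2 = datasets.ImageFolder("./data2")

print(ds1.class_to_idx)  # {'cat': 0, 'dog': 1}
print(ds2.class_to_idx)  # {'cat': 0, 'dog': 1, 'pig': 2}

# a third folder containing only "dog" and "pig" would yield
# {'dog': 0, 'pig': 1}, so its labels would clash with ds1's after concatenation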


I see, thanks. As a quick hack, I ended up creating empty folders for all required classes in my image dataset root.

How do I get the classes if I have three folders with the same class inside each folder?

If you are using ImageFolder, you can access its class_to_idx attribute to see the mapping between the folders and the class indices.
I’m not sure if I misunderstand the question, but do all three folders contain images from the same class (one class only)?

No, there are two classes…
Thank you for responding :smile:

Hi, what would be the right approach to concatenate 2 datasets having different classes/labels? For example, dataset D1 has folders for “cat” and “dog”, whereas dataset D2 has folders like “elephant” and “lion”. Right now I created empty folders named “elephant” and “lion” in dataset D1 to preserve the labels, and vice versa, before using ImageFolder and ConcatDataset as below:
ds = torch.utils.data.ConcatDataset(
    [datasets.ImageFolder('./data/D1/train', transform),
     datasets.ImageFolder('./data/D2/train', transform)]
)

Your approach sounds alright, if you want to use the ImageFolder datasets.
Alternatively, you could write a custom Dataset and return the corresponding targets for both datasets, which wouldn’t rely on creating empty folders.

Thanks! It would be great if there were a pointer to some examples of how to do this, as I am new to PyTorch.

This tutorial shows how to write a custom Dataset and might be helpful. :wink:
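For instance, a minimal sketch of such a custom wrapper could look like this (RemappedDataset and the offset value are hypothetical, assuming D1 contains “cat”/“dog”, D2 contains “elephant”/“lion”, and reusing the transform from your snippet):

import torch
from torchvision import datasets

class RemappedDataset(torch.utils.data.Dataset):
    """Hypothetical wrapper that shifts every target by a fixed offset."""

    def __init__(self, dataset, label_offset):
        self.dataset = dataset
        self.label_offset = label_offset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, target = self.dataset[index]
        return img, target + self.label_offset

# D1 maps {'cat': 0, 'dog': 1}; shifting D2 by 2 maps 'elephant'/'lion' to 2/3
ds1 = datasets.ImageFolder('./data/D1/train', transform)
ds2 = RemappedDataset(datasets.ImageFolder('./data/D2/train', transform), label_offset=2)
ds = torch.utils.data.ConcatDataset([ds1, ds2])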

OK, so if I want to:

I’d like to take the union of the labels and have them relabeled from scratch, from 0 to num_classes_dataset1 + num_classes_dataset2 - 1.

Then I have to implement my own “version” of a union/concat/merge dataset that takes this into account, likely re-implementing the __getitem__(idx: int) function to something like this:

def __getitem__(self, index: int):
    # leave the sampled labels of data set 1 as they are
    img, target = self.mnist[index], int(self.mnist.targets[index])

    # to the sampled labels of data set 2, add the number of classes in data set 1
    img, target = self.cifar10[index], int(self.cifar10.targets[index]) + len(self.mnist.classes)
    return ...

Darn, this isn’t quite right… based on the index, the dataset I implement should know what the right mapping is… also, this doesn’t work for an arbitrary union of datasets, of course… then one likely needs a bisect function to find which interval the idx falls into, so you know how many dataset lengths you need to add…
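Something like the following sketch would implement that bisect idea (UnionDataset and num_classes_per_dataset are hypothetical names; ConcatDataset does the same cumulative-size bookkeeping internally for indexing, just without a label offset):

import bisect
from torch.utils.data import Dataset

class UnionDataset(Dataset):
    """Hypothetical union dataset: concatenates samples and remaps each
    dataset's labels into a consecutive, non-overlapping range."""

    def __init__(self, datasets, num_classes_per_dataset):
        self.datasets = list(datasets)
        # cumulative sample counts, used to locate the source dataset of a global index
        self.cumulative_sizes = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative_sizes.append(total)
        # label offset for each dataset, e.g. [0, 10] for MNIST + CIFAR10
        self.label_offsets = [0]
        for n in num_classes_per_dataset[:-1]:
            self.label_offsets.append(self.label_offsets[-1] + n)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, index):
        # bisect to find which dataset the global index falls into
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        if dataset_idx > 0:
            index = index - self.cumulative_sizes[dataset_idx - 1]
        img, target = self.datasets[dataset_idx][index]
        return img, int(target) + self.label_offsets[dataset_idx]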

I think the easiest is to wrap your normal datasets into learn2learn’s MetaDataset and then pass them to their UnionMetaDataset.

import learn2learn as l2l
from learn2learn.data import UnionMetaDataset

train = l2l.vision.datasets.CIFARFS(root="/tmp/mnist", mode="train")
train = l2l.data.MetaDataset(train)
valid = l2l.vision.datasets.CIFARFS(root="/tmp/mnist", mode="validation")
valid = l2l.data.MetaDataset(valid)
test = l2l.vision.datasets.CIFARFS(root="/tmp/mnist", mode="test")
test = l2l.data.MetaDataset(test)
union = UnionMetaDataset([train, valid, test])
assert len(union.labels) == 100  # CIFAR-FS has 100 classes across all splits
class UnionMetaDataset(MetaDataset):
    """
    **Description**

    Takes multiple MetaDatasets and constructs their union.

    Note: The labels of all datasets are remapped to be in consecutive order.
    (i.e. the same label in two datasets will be mapped to two different labels in the union)

    **Arguments**

    * **datasets** (list of Dataset) - A list of torch Datasets.
    """
link: learn2learn.data - learn2learn

Actually, it’s likely easier to preprocess the data point indices to map to the required label (as you loop through each dataset you’d know this value easily and can keep a single counter), instead of bisecting.
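A minimal sketch of that preprocessing idea, assuming each dataset exposes .targets and .classes (as MNIST and CIFAR10 do):

# precompute the union label for every global sample index in one pass,
# keeping a running class offset instead of bisecting at lookup time
datasets_list = [mnist, cifar10]  # hypothetical list of datasets with .targets/.classes
remapped_targets = []
class_offset = 0
for ds in datasets_list:
    remapped_targets.extend(int(t) + class_offset for t in ds.targets)
    class_offset += len(ds.classes)
# remapped_targets[i] is now the union label of global sample index i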


Hi @ptrblck, I concatenated 3 datasets for data augmentation. The images were taken from the same path, so the three datasets have the same four labels. Is there a method or attribute of ConcatDataset to view the labels of the concatenated dataset, like the ones for a plain Dataset? Further, I can use a Counter to check how many images are in the train and val folders.

For example:

from collections import Counter
print(f'Classes: {train_loader.dataset.class_to_idx}')
num_classes_train = Counter(train_loader.dataset.targets)
num_classes_val = Counter(val_loader.dataset.targets)
print(f'Train Loader: {num_classes_train}')
print(f'Val Loader: {num_classes_val}')

Thanks in advance!
Pablo

You can access the internal datasets via the .datasets attribute as seen here:

dataset1 = datasets.MNIST(root="/data", download=False)
dataset1.targets  # tensor of MNIST labels

dataset2 = datasets.CIFAR10(root="/data", download=False)
dataset2.targets  # list of CIFAR10 labels

dataset = torch.utils.data.ConcatDataset((dataset1, dataset2))
dataset.datasets[0].targets  # targets of the first wrapped dataset
dataset.datasets[1].targets  # targets of the second wrapped dataset

OK, so I have to build a Counter for each of the three datasets and sum them up. As far as I have seen, ConcatDataset does not have the same methods and attributes as Dataset. Thanks!
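For completeness, summing the per-dataset Counters could look like this (concat_dataset is a hypothetical ConcatDataset whose wrapped ImageFolder datasets expose .targets):

from collections import Counter

# sum the label counts over all datasets wrapped by the ConcatDataset
total_counts = Counter()
for ds in concat_dataset.datasets:
    total_counts += Counter(ds.targets)
print(f'Concatenated dataset: {total_counts}')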