Concat image datasets with different size and number of channels

Good evening everyone,

I have a problem with loading multiple datasets and haven't found a solution so far.
To load multiple datasets into one DataLoader I can use the ConcatDataset class, but how do I concatenate e.g. CIFAR10 and MNIST? For this I would have to resize CIFAR10 and convert it to grayscale without applying these transformations to MNIST.

Any ideas on how to do that?


You could specify the transformations for each dataset.
This code should work:

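from torch.utils.data import ConcatDataset
from torchvision import datasets, transforms

# root='./data' is an assumed download location
mnist = datasets.MNIST(root='./data', download=True, transform=transforms.ToTensor())

cifar = datasets.CIFAR10(root='./data', download=True, transform=transforms.Compose([
    transforms.Resize(28), transforms.Grayscale(), transforms.ToTensor()]))

dataset = ConcatDataset((mnist, cifar))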

As you can see, transforms.Resize and transforms.Grayscale are only applied to the CIFAR10 dataset.
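A minimal sketch of the same idea that runs without downloading anything, assuming random tensors as stand-ins for MNIST (1x28x28) and CIFAR10 (3x32x32) images. `TransformedDataset` and `to_gray_28` are hypothetical helpers, and the plain channel average is a simpler stand-in for `transforms.Grayscale`, which weights the RGB channels. The point is that each dataset gets its own transform before concatenation, so every sample ends up with the same shape and the default collate function can stack a mixed batch.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Dataset, TensorDataset


class TransformedDataset(Dataset):
    """Hypothetical wrapper: applies `transform` to the input x of each (x, y) sample."""

    def __init__(self, base: Dataset, transform):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        return self.transform(x), y


def to_gray_28(x: torch.Tensor) -> torch.Tensor:
    # Channel average as a simple stand-in for transforms.Grayscale: (3, H, W) -> (1, H, W)
    gray = x.mean(dim=0, keepdim=True)
    # interpolate expects a batch dim: (1, 1, H, W) -> (1, 1, 28, 28), then drop it again
    return F.interpolate(gray.unsqueeze(0), size=(28, 28), mode="bilinear",
                         align_corners=False).squeeze(0)


# Random tensors standing in for MNIST-like and CIFAR-like images with labels 0..9
mnist_like = TensorDataset(torch.rand(4, 1, 28, 28), torch.randint(0, 10, (4,)))
cifar_like = TransformedDataset(
    TensorDataset(torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,))), to_gray_28
)
dataset = ConcatDataset((mnist_like, cifar_like))

# Only the CIFAR-like part was transformed, yet all samples now share one shape
loader = DataLoader(dataset, batch_size=8)
images, labels = next(iter(loader))  # images: (8, 1, 28, 28), labels: (8,)
```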


Oh yes of course, I am really a blockhead. :grin:

Can you concatenate the targets using this approach? If not, how do we go about it?

ConcatDataset will concatenate all passed datasets, keeping each sample paired with its target.

May I ask what this means? I'd like to take the union of the labels and have them relabeled from scratch, from 0 to len_dataset1+len_dataset2.

The easiest solution for what I want is this: Does Concatenate Datasets preserve class labels and indices - #12 by Brando_Miranda, i.e. using learn2learn's union of datasets.

ConcatDataset will just iterate over both datasets, one after the other, as seen here.
Here is a small example:

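import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

dataset1 = TensorDataset(torch.arange(5).float().view(-1, 1))
dataset2 = TensorDataset(torch.arange(5).float().view(-1, 1) + 0.1)

dataset = ConcatDataset((dataset1, dataset2))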

for d in dataset:
    print(d)
# (tensor([0.]),)
# (tensor([1.]),)
# (tensor([2.]),)
# (tensor([3.]),)
# (tensor([4.]),)
# (tensor([0.1000]),)
# (tensor([1.1000]),)
# (tensor([2.1000]),)
# (tensor([3.1000]),)
# (tensor([4.1000]),)

loader = DataLoader(dataset, batch_size=2)
for d in loader:
    print(d)
# [tensor([[0.],
#         [1.]])]
# [tensor([[2.],
#         [3.]])]
# [tensor([[4.0000],
#         [0.1000]])]
# [tensor([[1.1000],
#         [2.1000]])]
# [tensor([[3.1000],
#         [4.1000]])]

loader = DataLoader(dataset, batch_size=2, shuffle=True)
for d in loader:
    print(d)

# [tensor([[0.1000],
#         [1.1000]])]
# [tensor([[3.1000],
#         [2.0000]])]
# [tensor([[1.],
#         [0.]])]
# [tensor([[4.],
#         [3.]])]
# [tensor([[4.1000],
#         [2.1000]])]

Is this what you are referring to as "union", or are there samples duplicated across both datasets that you want to remove?

No, I just mean concatenating and re-indexing the classes. This basically assumes mutually exclusive labels.

In that case my code should work and shows this behavior.

But you still need to produce the right labels. It's not hard, I suppose, but if the DataLoader shuffles, we want the classes to be labeled consistently and correctly.
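To make the relabeling concrete: assuming each dataset uses 0-based integer labels and the label sets are mutually exclusive, every label can be shifted by the number of classes in the datasets that come before it. A pure-Python sketch (`relabel_offsets` is a hypothetical helper, not a library function); because the new label is computed per sample rather than per batch, shuffling in the DataLoader cannot mix it up.

```python
from typing import Sequence


def relabel_offsets(label_lists: Sequence[Sequence[int]]) -> list[int]:
    """Map each dataset's 0-based labels into one shared label space.

    Assumes disjoint label sets and that dataset k uses labels 0..max_k,
    so dataset k is shifted by the total class count of the datasets before it.
    """
    new_labels: list[int] = []
    offset = 0
    for labels in label_lists:
        new_labels.extend(y + offset for y in labels)
        offset += max(labels) + 1  # class count = largest label + 1
    return new_labels


# Two toy "datasets" with 3 and 2 classes respectively
print(relabel_offsets([[0, 2, 1, 2], [1, 0]]))  # [0, 2, 1, 2, 4, 3]
```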

The datasets are indexed separately, as seen in the linked code, so they cannot be mixed up.
You can add targets to my code snippet and will see that the correspondence is not broken (otherwise this class would be quite useless):

dataset1 = TensorDataset(torch.arange(5).float().view(-1, 1), torch.zeros(5, 1))
dataset2 = TensorDataset(torch.arange(5).float().view(-1, 1) + 0.1, torch.ones(5, 1))

dataset = ConcatDataset((dataset1, dataset2))

loader = DataLoader(dataset, batch_size=2, shuffle=True)
for d in loader:
    print(d)

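The reason samples and targets cannot get mixed up is that a concatenated dataset only translates a global index into a (dataset, local index) pair and then calls that dataset's own `__getitem__`, which returns x and y together. torch's ConcatDataset keeps a list of cumulative sizes and uses `bisect` for this lookup; the pure-Python sketch below (`locate` is a hypothetical name) mimics that idea and is not the library code.

```python
import bisect
from itertools import accumulate


def locate(sizes: list[int], idx: int) -> tuple[int, int]:
    """Map a global index to (dataset index, local index) for datasets of the given sizes."""
    cumulative = list(accumulate(sizes))  # e.g. [5, 10] for two datasets of size 5
    if not 0 <= idx < cumulative[-1]:
        raise IndexError(idx)
    # First dataset whose cumulative size is strictly greater than idx
    dataset_idx = bisect.bisect_right(cumulative, idx)
    local_idx = idx if dataset_idx == 0 else idx - cumulative[dataset_idx - 1]
    return dataset_idx, local_idx


print([locate([5, 5], i) for i in (0, 4, 5, 9)])
# [(0, 0), (0, 4), (1, 0), (1, 4)]
```

Shuffling only permutes the global indices handed to this lookup, so each returned (x, y) still comes from a single `dataset[local_idx]` call.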
Apologies if I am not being clear. I think this might be what I'm looking for, where the labels are redone assuming mutually exclusive labels for each dataset:

from collections import defaultdict

import torch
from torch import Tensor
from torch.utils.data import Dataset


class ConcatDataset(Dataset):

    def __init__(self, datasets: list[Dataset]):
        # Wrap torch's own ConcatDataset instead of copying the data: it fetches samples through each dataset's __getitem__, so data that isn't loaded into memory stays lazy
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label.
        self.indices_to_labels = {}
        # - do the relabeling
        offset: int = 0
        new_idx: int = 0
        for dataset_idx, dataset in enumerate(datasets):
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            for x, y in dataset:
                y = int(y)
                new_label = y + offset
                self.indices_to_labels[new_idx] = new_label
                _x, _y = self.concat_datasets[new_idx]
                _y = int(_y)
                assert y == _y
                assert torch.equal(x, _x)
                self.labels_to_indices[new_label].append(new_idx)
            # labels are 0-based, so the class count is the largest label + 1
            num_labels_for_current_dataset: int = max(int(y) for _, y in dataset) + 1
            offset += num_labels_for_current_dataset
            new_idx += 1
        assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
        # contains the list of labels from 0 - total num labels after concat
        self.labels = range(offset)
        self.target_transform = lambda data: torch.tensor(data, dtype=torch.long)  # dtype assumed: integer class labels

    def __len__(self):
        return len(self.concat_datasets)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        x, _ = self.concat_datasets[idx]  # drop the original label; y below is the relabeled one
        y = self.indices_to_labels[idx]
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y

Does this look right to you? I still need to think about how to test it.

The assertion fails:

    assert torch.equal(x, _x)


def check_xs_align_cifar100():
    from pathlib import Path

    import torchvision

    root = Path("~/data/").expanduser()
    # root = Path(".").expanduser()
    train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
    test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)

    concat = ConcatDataset([train, test])

CIFAR100 doesn't test the relabeling of a disjoint union of datasets, but the x's should have aligned… and they didn't.

After converting the images to tensors, the comparison still fails, @ptrblck:

    img2tensor: Callable = torchvision.transforms.ToTensor()
    x, _x = img2tensor(x), img2tensor(_x)
    assert torch.equal(x, _x), f'Error for some reason, got: {data_idx=}, {x.norm()=}, {_x.norm()=}, {x=}, {_x=}'

I think the answer is here: python - Why don't the images align when concatenating two data sets in pytorch using - Stack Overflow
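Independently of the linked answer, one bug is visible in the class above: `new_idx += 1` sits in the outer loop, so it advances once per dataset rather than once per sample, and `self.concat_datasets[new_idx]` keeps returning the same sample while `x, y` walk through the dataset. That alone is enough to make the alignment asserts fail as soon as the second sample is reached. Below is a minimal pure-Python sketch of the intended relabeling, assuming 0-based integer labels and disjoint label sets; `RelabeledConcat` is a hypothetical name, not a library class. It only relies on `len()` and integer indexing, so it should also be able to wrap torch datasets and be passed to a DataLoader.

```python
from collections import defaultdict
from typing import Any, Sequence


class RelabeledConcat:
    """Hypothetical map-style dataset: concatenates datasets of (x, y) pairs and shifts
    each dataset's labels by the class count of the datasets before it."""

    def __init__(self, datasets: Sequence[Any]):
        self.datasets = list(datasets)
        self.index: list[tuple[int, int]] = []  # global index -> (dataset index, local index)
        self.new_labels: list[int] = []  # global index -> relabeled target
        self.labels_to_indices: defaultdict[int, list[int]] = defaultdict(list)
        offset = 0
        for d_idx, dataset in enumerate(self.datasets):
            labels = [int(dataset[i][1]) for i in range(len(dataset))]
            for local_idx, y in enumerate(labels):
                global_idx = len(self.index)  # advances once per sample, not per dataset
                self.index.append((d_idx, local_idx))
                self.new_labels.append(y + offset)
                self.labels_to_indices[y + offset].append(global_idx)
            offset += max(labels) + 1  # class count of this dataset
        self.num_classes = offset

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int):
        d_idx, local_idx = self.index[idx]
        x, _ = self.datasets[d_idx][local_idx]  # the original label is replaced
        return x, self.new_labels[idx]


a = [("a0", 0), ("a1", 1), ("a2", 1)]  # toy dataset with 2 classes
b = [("b0", 1), ("b1", 0)]  # toy dataset with 2 classes
ds = RelabeledConcat([a, b])
print([ds[i] for i in range(len(ds))])
# [('a0', 0), ('a1', 1), ('a2', 1), ('b0', 3), ('b1', 2)]
```

Because each sample is looked up directly through its stored (dataset, local index) pair, no second ConcatDataset and no alignment asserts are needed. For torchvision datasets that expose a `targets` attribute (CIFAR10/100 and MNIST do), reading labels from it instead of `dataset[i][1]` would avoid decoding every image during construction.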