Concat image datasets with different size and number of channels

apologies if I am not being clear. But I think this might be what I’m looking for, where the labeling is re-done assuming mutually exclusive labels for each data set:

class ConcatDataset(Dataset):
    """

    ref: https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
    """

    def __init__(self, datasets: list[Dataset]):
        """
        """
        # I think concat is better than passing data to a self.data = x obj since concat likely using the getitem method of the passed dataset and thus if the passed dataset doesnt put all the data in memory concat won't either
        self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
        # maps a class label to a list of sample indices with that label.
        self.labels_to_indices = defaultdict(list)
        # maps a sample index to its corresponding class label.
        self.indices_to_labels = defaultdict(None)
        # - do the relabeling
        offset: int = 0
        new_idx: int = 0
        for dataset_idx, dataset in enumerate(datasets):
            assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
            assert dataset == self.concat_datasets.datasets[dataset_idx]
            for x, y in dataset:
                y = int(y)
                new_label = y + offset
                self.indices_to_labels[new_idx] = new_label
                _x, _y = self.concat_datasets[new_idx]
                _y = int(_y)
                assert y == _y
                assert torch.equal(x, _x)
                self.labels_to_indices[new_label] = new_idx
            num_labels_for_current_dataset: int = max([y for _, y in dataset])
            offset += num_labels_for_current_dataset
            new_idx += 1
        assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
        # contains the list of labels from 0 - total num labels after concat
        self.labels = range(offset)
        self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)

    def __len__(self):
        return len(self.concat_datasets)

    def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
        x = self.concat_datasets[idx]
        y = self.indices_to_labels[idx]
        if self.target_transform is not None:
            y = self.target_transform(y)
        return x, y

does this look right to you? Need to think how to test it.


assertion fails:

    assert torch.equal(x, _x)
AssertionError
python-BaseException

code

def check_xs_align_cifar100():
    from pathlib import Path

    root = Path("~/data/").expanduser()
    # root = Path(".").expanduser()
    train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
    test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)

    concat = ConcatDataset([train, test])
    print(f'{len(concat)=}')
    print(f'{len(concat.labels)=}')

cifar100 doesn’t test the relabling of disjoint union of data sets but the x’s should’ve aligned…but they didnt


after converting images to tensor comparison still fails @ptrblck :

    img2tensor: Callable = torchvision.transforms.ToTensor()
    x, _x = img2tensor(x), img2tensor(_x)
    assert torch.equal(x, _x), f'Error for some reason, got: {data_idx=}, {x.norm()=}, {_x.norm()=}, {x=}, {_x=}'