apologies if I am not being clear. But I think this might be what I’m looking for, where the labeling is re-done assuming mutually exclusive labels for each data set:
class ConcatDataset(Dataset):
"""
ref: https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
"""
def __init__(self, datasets: list[Dataset]):
"""
"""
# I think concat is better than passing data to a self.data = x obj since concat likely using the getitem method of the passed dataset and thus if the passed dataset doesnt put all the data in memory concat won't either
self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
# maps a class label to a list of sample indices with that label.
self.labels_to_indices = defaultdict(list)
# maps a sample index to its corresponding class label.
self.indices_to_labels = defaultdict(None)
# - do the relabeling
offset: int = 0
new_idx: int = 0
for dataset_idx, dataset in enumerate(datasets):
assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
assert dataset == self.concat_datasets.datasets[dataset_idx]
for x, y in dataset:
y = int(y)
new_label = y + offset
self.indices_to_labels[new_idx] = new_label
_x, _y = self.concat_datasets[new_idx]
_y = int(_y)
assert y == _y
assert torch.equal(x, _x)
self.labels_to_indices[new_label] = new_idx
num_labels_for_current_dataset: int = max([y for _, y in dataset])
offset += num_labels_for_current_dataset
new_idx += 1
assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
# contains the list of labels from 0 - total num labels after concat
self.labels = range(offset)
self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)
def __len__(self):
return len(self.concat_datasets)
def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
x = self.concat_datasets[idx]
y = self.indices_to_labels[idx]
if self.target_transform is not None:
y = self.target_transform(y)
return x, y
does this look right to you? Need to think how to test it.
assertion fails:
assert torch.equal(x, _x)
AssertionError
python-BaseException
code
def check_xs_align_cifar100():
from pathlib import Path
root = Path("~/data/").expanduser()
# root = Path(".").expanduser()
train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
concat = ConcatDataset([train, test])
print(f'{len(concat)=}')
print(f'{len(concat.labels)=}')
cifar100 doesn’t test the relabling of disjoint union of data sets but the x’s should’ve aligned…but they didnt
after converting images to tensor comparison still fails @ptrblck :
img2tensor: Callable = torchvision.transforms.ToTensor()
x, _x = img2tensor(x), img2tensor(_x)
assert torch.equal(x, _x), f'Error for some reason, got: {data_idx=}, {x.norm()=}, {_x.norm()=}, {x=}, {_x=}'