I wanted to concatenate multiple data sets where the labels are disjoint (so don’t share labels). I did:
class ConcatDataset(Dataset):
"""
ref: https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
"""
def __init__(self, datasets: list[Dataset]):
"""
"""
# I think concat is better than passing data to a self.data = x obj since concat likely using the getitem method of the passed dataset and thus if the passed dataset doesnt put all the data in memory concat won't either
self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
# maps a class label to a list of sample indices with that label.
self.labels_to_indices = defaultdict(list)
# maps a sample index to its corresponding class label.
self.indices_to_labels = defaultdict(None)
# - do the relabeling
offset: int = 0
new_idx: int = 0
for dataset_idx, dataset in enumerate(datasets):
assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
assert dataset == self.concat_datasets.datasets[dataset_idx]
for x, y in dataset:
y = int(y)
_x, _y = self.concat_datasets[new_idx]
_y = int(_y)
# assert y == _y
assert torch.equal(x, _x)
new_label = y + offset
self.indices_to_labels[new_idx] = new_label
self.labels_to_indices[new_label] = new_idx
num_labels_for_current_dataset: int = max([y for _, y in dataset])
offset += num_labels_for_current_dataset
new_idx += 1
assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
# contains the list of labels from 0 - total num labels after concat
self.labels = range(offset)
self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)
def __len__(self):
return len(self.concat_datasets)
def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
x = self.concat_datasets[idx]
y = self.indices_to_labels[idx]
if self.target_transform is not None:
y = self.target_transform(y)
return x, y
but it doesn’t even work to align the x images (so never mind if my relabling works!). Why?
def check_xs_align_cifar100():
from pathlib import Path
root = Path("~/data/").expanduser()
# root = Path(".").expanduser()
train = torchvision.datasets.CIFAR100(root=root, train=True, download=True)
test = torchvision.datasets.CIFAR100(root=root, train=False, download=True)
concat = ConcatDataset([train, test])
print(f'{len(concat)=}')
print(f'{len(concat.labels)=}')
error
Files already downloaded and verified
Files already downloaded and verified
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1491, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 405, in <module>
check_xs_align()
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 391, in check_xs_align
concat = ConcatDataset([train, test])
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 71, in __init__
assert torch.equal(x, _x)
TypeError: equal(): argument 'input' (position 1) must be Tensor, not Image
python-BaseException
Bonus: let me know if relabeling is correct please.
related discussion: Concat image datasets with different size and number of channels - #12 by Brando_Miranda
Edit 1: PIL comparison fails
I did a PIL image comparison according to Compare images Python PIL - Stack Overflow but it failed:
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevd.py", line 1491, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 419, in <module>
check_xs_align_cifar100()
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 405, in check_xs_align_cifar100
concat = ConcatDataset([train, test])
File "/Users/brandomiranda/ultimate-utils/ultimate-utils-proj-src/uutils/torch_uu/dataset/concate_dataset.py", line 78, in __init__
assert diff.getbbox(), f'comparison of imgs failed: {diff.getbbox()=}'
AssertionError: comparison of imgs failed: diff.getbbox()=None
python-BaseException
diff
PyDev console: starting.
<PIL.Image.Image image mode=RGB size=32x32 at 0x7FBE897A21C0>
code comparison:
diff = ImageChops.difference(x, _x) # https://stackoverflow.com/questions/35176639/compare-images-python-pil
assert diff.getbbox(), f'comparison of imgs failed: {diff.getbbox()=}'
this also failed:
assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
AssertionError: ...long msg...
assert statement was:
assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
Edit 2: Tensor comparison Fails
I tried to convert images to tensors but it still fails:
AssertionError: Error for some reason, got: data_idx=1, x.norm()=tensor(45.9401), _x.norm()=tensor(33.9407), x=tensor([[[1.0000, 0.9922, 0.9922, ..., 0.9922, 0.9922, 1.0000],
code:
class ConcatDataset(Dataset):
"""
ref:
- https://discuss.pytorch.org/t/concat-image-datasets-with-different-size-and-number-of-channels/36362/12
- https://stackoverflow.com/questions/73913522/why-dont-the-images-align-when-concatenating-two-data-sets-in-pytorch-using-tor
"""
def __init__(self, datasets: list[Dataset]):
"""
"""
# I think concat is better than passing data to a self.data = x obj since concat likely using the getitem method of the passed dataset and thus if the passed dataset doesnt put all the data in memory concat won't either
self.concat_datasets = torch.utils.data.ConcatDataset(datasets)
# maps a class label to a list of sample indices with that label.
self.labels_to_indices = defaultdict(list)
# maps a sample index to its corresponding class label.
self.indices_to_labels = defaultdict(None)
# - do the relabeling
img2tensor: Callable = torchvision.transforms.ToTensor()
offset: int = 0
new_idx: int = 0
for dataset_idx, dataset in enumerate(datasets):
assert len(dataset) == len(self.concat_datasets.datasets[dataset_idx])
assert dataset == self.concat_datasets.datasets[dataset_idx]
for data_idx, (x, y) in enumerate(dataset):
y = int(y)
# - get data point from concataned data set (to compare with the data point from the data set list)
_x, _y = self.concat_datasets[new_idx]
_y = int(_y)
# - sanity check concatanted data set aligns with the list of datasets
# assert y == _y
# from PIL import ImageChops
# diff = ImageChops.difference(x, _x) # https://stackoverflow.com/questions/35176639/compare-images-python-pil
# assert diff.getbbox(), f'comparison of imgs failed: {diff.getbbox()=}'
# assert list(x.getdata()) == list(_x.getdata()), f'\n{list(x.getdata())=}, \n{list(_x.getdata())=}'
# tensor comparison
x, _x = img2tensor(x), img2tensor(_x)
print(f'{data_idx=}, {x.norm()=}, {_x.norm()=}')
assert torch.equal(x, _x), f'Error for some reason, got: {data_idx=}, {x.norm()=}, {_x.norm()=}, {x=}, {_x=}'
# - relabling
new_label = y + offset
self.indices_to_labels[new_idx] = new_label
self.labels_to_indices[new_label] = new_idx
num_labels_for_current_dataset: int = max([y for _, y in dataset])
offset += num_labels_for_current_dataset
new_idx += 1
assert len(self.indices_to_labels.keys()) == len(self.concat_datasets)
# contains the list of labels from 0 - total num labels after concat
self.labels = range(offset)
self.target_transform = lambda data: torch.tensor(data, dtype=torch.int)
def __len__(self):
return len(self.concat_datasets)
def __getitem__(self, idx: int) -> tuple[Tensor, Tensor]:
x = self.concat_datasets[idx]
y = self.indices_to_labels[idx]
if self.target_transform is not None:
y = self.target_transform(y)
return x, y
reddit link: https://www.reddit.com/r/pytorch/comments/xurnu9/why_dont_the_images_align_when_concatenating_two/
related: Concat image datasets with different size and number of channels - #12 by Brando_Miranda