I am trying to implement Information Invariant Clustering and it requires presenting a network with 2 transforms of the same image. I didn’t understand their code exactly. If I always load the data in order, I believe I can just have two loaders with the different transforms and load each batch together but I want the dataloader to be able to load random images as usual.
My thought was to create a custom dataset class that performs the transforms on the incoming data:
def Transformed_Dataset_With_Indices(cls):
"""
Modifies the given Dataset class to return a tuple data, target, transformed data, index
instead of just data, target.
Usage:
MNISTWithIndices = dataset_with_indices(MNIST)
dataset = MNISTWithIndices('~/datasets/mnist')
Modified from https://discuss.pytorch.org/t/how-to-retrieve-the-sample-indices-of-a-mini-batch/7948/19
By Cassidy Laidlaw
And (perturb_imagedata) by Rudolf A. Braun from https://github.com/RuABraun/phone-clustering
"""
def __init__(self, normalizer=None,
transform_list=[transforms.RandomCrop(28, padding=4),
transforms.ColorJitter(brightness=(0.2, 0.75,)),
transforms.RandomAffine(1, (0.1, 0.1,), (0.95, 1,))]):
self.norm = normalizer
if self.norm is None:
self.norm = lambda a: a
self.trans = transforms.Compose(transform_list)
def __getitem__(self):
data, target = cls.__getitem__(self, index)
transformed_data = self.norm(tf.to_tensor(self.trans(tf.to_pil_image(data))))
return data, target, index, transformed_data
return type(cls.__name__, (cls,), {
'__getitem__': __getitem__,
})
But I can’t get it to use the parameters if I try to use it with:
MNISTWithIndices = Transformed_Dataset_With_Indices(torchvision.datasets.MNIST, normalize_data=transforms.Normalize((0.1307,), (0.3081,)))
Gives and error:
TypeError: Transformed_Dataset_With_Indices() got an unexpected keyword argument 'normalize_data'
I also tried:
def Transformed_Dataset_With_Indices(cls):
"""
Modifies the given Dataset class to return a tuple data, target, transformed data, index
instead of just data, target.
Usage:
MNISTWithIndices = dataset_with_indices(MNIST)
dataset = MNISTWithIndices('~/datasets/mnist')
Modified from https://discuss.pytorch.org/t/how-to-retrieve-the-sample-indices-of-a-mini-batch/7948/19
By Cassidy Laidlaw
And (perturb_imagedata) by Rudolf A. Braun from https://github.com/RuABraun/phone-clustering
"""
def __getitem__(self, index, normalizer=transforms.Normalize((0.1307,), (0.3081,)),
transform_list=[transforms.RandomCrop(28, padding=4),
transforms.ColorJitter(brightness=(0.2, 0.75,)),
transforms.RandomAffine(1, (0.1, 0.1,), (0.95, 1,))]):
data, target = cls.__getitem__(self, index)
trans = transforms.Compose(transform_list)
norm = normalizer
if norm is None:
norm = lambda a: a
transformed_data = norm(tf.to_tensor(trans(tf.to_pil_image(data))))
return data, target, index, transformed_data
return type(cls.__name__, (cls,), {
'__getitem__': __getitem__,
})
but it gives the same error.
You can test it with this code:
MNISTWithIndices = Transformed_Dataset_With_Indices(torchvision.datasets.MNIST)
train_set = MNISTWithIndices('../Neurobaby/Data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor()
# , transforms.Normalize((0.1307,), (0.3081,))
]))
test_set = torchvision.datasets.MNIST('../Neurobaby/Data', train=False, download=True,
transform=transforms.Compose([
transforms.ToTensor()
# , transforms.Normalize((0.1307,), (0.3081,))
]))
x, _, _, transformed_x = train_set[7777] # x is now a torch.Tensor
plt.imshow(x.numpy()[0], cmap='gray')
plt.show()
plt.imshow(transformed_x.numpy()[0], cmap='gray')
plt.show()
Any ideas on how to pass normalizations and tensor_lists to the custom dataset please?
It will make it a lot easier with different data sets.
Thanks in advance.