Parallel transformed dataset loading

I am trying to implement Information Invariant Clustering and it requires presenting a network with 2 transforms of the same image. I didn’t understand their code exactly. If I always load the data in order, I believe I can just have two loaders with the different transforms and load each batch together but I want the dataloader to be able to load random images as usual.

My thought was to create a custom dataset class that performs the transforms on the incoming data:

def Transformed_Dataset_With_Indices(cls):
    """
    Modifies the given Dataset class to return a tuple data, target, transformed data, index
    instead of just data, target.
    Usage:
        MNISTWithIndices = dataset_with_indices(MNIST)
        dataset = MNISTWithIndices('~/datasets/mnist')

    Modified from https://discuss.pytorch.org/t/how-to-retrieve-the-sample-indices-of-a-mini-batch/7948/19
    By Cassidy Laidlaw
    And (perturb_imagedata) by Rudolf A. Braun from https://github.com/RuABraun/phone-clustering
    """
    def __init__(self, normalizer=None,
                    transform_list=[transforms.RandomCrop(28, padding=4),
                                             transforms.ColorJitter(brightness=(0.2, 0.75,)),
                                             transforms.RandomAffine(1, (0.1, 0.1,), (0.95, 1,))]):
        self.norm = normalizer
        if self.norm is None:
            self.norm = lambda a: a
        self.trans = transforms.Compose(transform_list)

    def __getitem__(self):
        data, target = cls.__getitem__(self, index)
        transformed_data = self.norm(tf.to_tensor(self.trans(tf.to_pil_image(data))))
        return data, target, index, transformed_data

    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })

But I can’t get it to use the parameters if I try to use it with:

MNISTWithIndices = Transformed_Dataset_With_Indices(torchvision.datasets.MNIST, normalize_data=transforms.Normalize((0.1307,), (0.3081,)))

Gives and error:

TypeError: Transformed_Dataset_With_Indices() got an unexpected keyword argument 'normalize_data'

I also tried:

def Transformed_Dataset_With_Indices(cls):
    """
    Modifies the given Dataset class to return a tuple data, target, transformed data, index
    instead of just data, target.
    Usage:
        MNISTWithIndices = dataset_with_indices(MNIST)
        dataset = MNISTWithIndices('~/datasets/mnist')

    Modified from https://discuss.pytorch.org/t/how-to-retrieve-the-sample-indices-of-a-mini-batch/7948/19
    By Cassidy Laidlaw
    And (perturb_imagedata) by Rudolf A. Braun from https://github.com/RuABraun/phone-clustering
    """
    def __getitem__(self, index, normalizer=transforms.Normalize((0.1307,), (0.3081,)),
                    transform_list=[transforms.RandomCrop(28, padding=4),
                                             transforms.ColorJitter(brightness=(0.2, 0.75,)),
                                             transforms.RandomAffine(1, (0.1, 0.1,), (0.95, 1,))]):
        data, target = cls.__getitem__(self, index)
        trans = transforms.Compose(transform_list)
        norm = normalizer
        if norm is None:
            norm = lambda a: a
        transformed_data = norm(tf.to_tensor(trans(tf.to_pil_image(data))))
        return data, target, index, transformed_data

    return type(cls.__name__, (cls,), {
        '__getitem__': __getitem__,
    })

but it gives the same error.

You can test it with this code:

MNISTWithIndices = Transformed_Dataset_With_Indices(torchvision.datasets.MNIST)
train_set = MNISTWithIndices('../Neurobaby/Data', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor()
                                          # , transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
test_set = torchvision.datasets.MNIST('../Neurobaby/Data', train=False, download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor()
                                         # , transforms.Normalize((0.1307,), (0.3081,))
                                      ]))

x, _, _, transformed_x = train_set[7777] # x is now a torch.Tensor
plt.imshow(x.numpy()[0], cmap='gray')
plt.show()
plt.imshow(transformed_x.numpy()[0], cmap='gray')
plt.show()

Any ideas on how to pass normalizations and tensor_lists to the custom dataset please?

It will make it a lot easier with different data sets.

Thanks in advance.