Disable transform in validation dataset


I am currently implementing a modified ResNet CNN model. The data is in HDF5 format (read with h5py). I read the entire dataset into memory and then split it into training and test sets. The following is a simplified example.

import h5py
from torch.utils.data import Dataset

class mydataset(Dataset):
    def __init__(self, input_file, transforms=None):
        super(mydataset, self).__init__()
        self.transforms = transforms
        self.file = h5py.File(input_file, "r")
        self.label = ...  # read labels into memory (details elided in the original)
        self.img = ...  # read images into memory (details elided in the original)

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        img = self.img[idx]
        label = self.label[idx]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

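# Then I do the train/validation split as follows
import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler

def pytorch_train_test_split(dataset, random_state, batch_size, test_size, shuffle):
    """Split a dataset into train and test data loaders.

    Args:
        dataset (torch.utils.data.Dataset): a PyTorch dataset
        random_state (int): random seed used when shuffling
        batch_size (int): batch size of the training loader
        test_size (float): fraction of the dataset used for testing
        shuffle (bool): if True, shuffle the indices before splitting

    Returns:
        train dataloader, test dataloader
    """
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(test_size * dataset_size))
    if shuffle:
        # shuffle the indices reproducibly before splitting
        np.random.seed(random_state)
        np.random.shuffle(indices)
    train_indices, test_indices = indices[split:], indices[:split]
    # Creating data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)
    # both loaders wrap the same dataset object; only the samplers differ
    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    test_loader = DataLoader(dataset, batch_size=512, sampler=test_sampler)
    return train_loader, test_loader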

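# main()
from torchvision import transforms

# apply transforms (named train_transforms so the torchvision module is not shadowed)
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    # the rest of this list was cut off in the original post;
    # RandomHorizontalFlip + ToTensor are assumed here for illustration
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# read dataset
my_data = mydataset("somefile.h5", train_transforms)

# generate train and validation data loaders
# (the original call was cut off; these argument values are illustrative)
train_loader, valid_loader = pytorch_train_test_split(my_data, random_state=42, batch_size=64, test_size=0.2, shuffle=True)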

However, there is a problem with doing it this way. We normally don't validate/test with random flips (or random scaling, random crops, etc.), but because both loaders wrap the same dataset, the validation loader gets the same transforms as the training loader. How can I disable the transforms for the validation data loader only? I know I could write a script that splits the data into training and validation sets first and then create two datasets with different transforms, but is there another way to do this?

Any help would be appreciated.

Thank you.
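To see why a single switch is awkward here: both loaders returned by pytorch_train_test_split wrap the very same dataset object, so there is only one transforms attribute for them to read. A minimal, self-contained check, assuming a toy in-memory dataset (ToyDataset and the "+100" transform are hypothetical stand-ins for mydataset and a random augmentation):

```python
import torch
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler


class ToyDataset(Dataset):
    """Hypothetical stand-in for mydataset: 10 scalar 'images' with labels."""

    def __init__(self, transforms=None):
        self.img = torch.arange(10, dtype=torch.float32)
        self.label = torch.arange(10) % 2
        self.transforms = transforms

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        img = self.img[idx]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, self.label[idx]


data = ToyDataset(transforms=lambda x: x + 100)  # stand-in for an augmentation
train_loader = DataLoader(data, batch_size=4, sampler=SubsetRandomSampler(list(range(2, 10))))
valid_loader = DataLoader(data, batch_size=4, sampler=SubsetRandomSampler([0, 1]))

# Both loaders point at the same object, so they share one transforms attribute.
print(train_loader.dataset is valid_loader.dataset)  # True

# Consequently the validation batch is transformed as well.
imgs, labels = next(iter(valid_loader))
print(sorted(imgs.tolist()))  # [100.0, 101.0]
```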


The usual way of doing it is to pass an additional argument to the custom dataset class (e.g. transform=False) and set it to True only for the training dataset.

Then, in __getitem__, add a check like if self.transform: and only then perform the augmentation, as you currently do.
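A minimal sketch of that suggestion, assuming the augmentation is passed in as a callable and gated by a boolean flag. FlaggedDataset, augment, add_one and the toy tensors are hypothetical names for illustration, not taken from the posts:

```python
import torch
from torch.utils.data import Dataset


class FlaggedDataset(Dataset):
    """Applies `transform` only when `augment` is True (hypothetical example)."""

    def __init__(self, img, label, transform=None, augment=False):
        self.img = img
        self.label = label
        self.transform = transform
        self.augment = augment

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        img = self.img[idx]
        # Only the training instance is created with augment=True.
        if self.augment and self.transform is not None:
            img = self.transform(img)
        return img, self.label[idx]


def add_one(x):
    return x + 1  # stand-in for a random augmentation


img = torch.zeros(4, 2)
label = torch.tensor([0, 1, 0, 1])

train_ds = FlaggedDataset(img[:3], label[:3], transform=add_one, augment=True)
valid_ds = FlaggedDataset(img[3:], label[3:], transform=add_one, augment=False)
print(train_ds[0][0].tolist())  # [1.0, 1.0]
print(valid_ds[0][0].tolist())  # [0.0, 0.0]
```

Note that this needs the data to be split before the two datasets are constructed, which is exactly the limitation raised in the reply below.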

Thank you. That would work if I had already separated the training and validation data. However, I read the entire dataset first, then split it and wrap each part in a DataLoader, so there is no separate training dataset on which I could set transform=True.


My bad: I didn't see that the code block was scrollable, so I missed its second part. Thanks for not jumping on me!

In that case, I fear it's too complex to do it your way. You would need an additional argument, __getitem__(self, idx, transform=False), but who calls __getitem__, and how would you pass that argument to it? The DataLoader calls dataset[idx] with just the index, so you would have to thread the flag all the way from the DataLoader down to the code that actually indexes the dataset, which gets really complicated once the DataLoader's worker processes are involved.
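One way to convince yourself of this: for a map-style dataset, the sampler produces indices and the DataLoader calls dataset[idx] with nothing else. A small sketch that records what __getitem__ receives (RecordingDataset is hypothetical; num_workers=0 keeps everything in one process so the recorded list is visible afterwards):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RecordingDataset(Dataset):
    """Hypothetical dataset that logs every argument passed to __getitem__."""

    def __init__(self, n):
        self.n = n
        self.calls = []

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self.calls.append(idx)  # the only information the loader hands over
        return torch.tensor(float(idx))


ds = RecordingDataset(6)
for batch in DataLoader(ds, batch_size=4, shuffle=True, num_workers=0):
    pass

print(sorted(ds.calls))  # [0, 1, 2, 3, 4, 5]
print(all(isinstance(i, int) for i in ds.calls))  # True
```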

IMHO, the easiest way is to split the dataset first… Maybe someone else can provide a quick hack for this instead of overriding a lot of code?
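A minimal sketch of the split-first route that avoids writing the split to disk: shuffle the indices once, build two instances of the same dataset class with different transforms, and restrict each to its own indices with torch.utils.data.Subset. ToyDataset, add_noise and the 80/20 ratio are assumptions for illustration. With the real mydataset, each instance would load the HDF5 data into memory separately, doubling memory use.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for mydataset with an optional transform."""

    def __init__(self, transforms=None):
        self.img = torch.arange(10, dtype=torch.float32)
        self.label = torch.arange(10) % 2
        self.transforms = transforms

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        img = self.img[idx]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, self.label[idx]


def add_noise(x):
    return x + torch.randn(())  # stand-in for a random augmentation


# Shuffle once with a fixed seed so both subsets come from the same permutation.
g = torch.Generator().manual_seed(0)
indices = torch.randperm(10, generator=g).tolist()
split = int(0.2 * len(indices))
valid_idx, train_idx = indices[:split], indices[split:]

train_set = Subset(ToyDataset(transforms=add_noise), train_idx)
valid_set = Subset(ToyDataset(transforms=None), valid_idx)  # no augmentation

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=512)
```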


You could create a “wrapper” dataset like this (not tested!):

from torch.utils.data import Dataset

class WrapperDataset(Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        new_idx = self.indices[idx]
        img, label = self.dataset[new_idx]  # the base dataset returns (img, label)
        if self.transform is not None:
            img = self.transform(img)  # transform the image only, not the label
        return img, label

This way you can create two separate datasets for training/testing based on the indices you already have available.
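A usage sketch under the same assumptions as the wrapper: the base dataset is created once without transforms, and each split gets its own transform. The WrapperDataset definition is repeated so the snippet runs on its own; the TensorDataset base data and random_hflip are hypothetical stand-ins for mydataset and RandomHorizontalFlip.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class WrapperDataset(Dataset):
    """View of `dataset` restricted to `indices`, with its own transform."""

    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        img, label = self.dataset[self.indices[idx]]
        if self.transform is not None:
            img = self.transform(img)
        return img, label


def random_hflip(img):
    # Stand-in for RandomHorizontalFlip on a tensor: flip the last dim half the time.
    return torch.flip(img, dims=[-1]) if torch.rand(()) < 0.5 else img


# Hypothetical base data: 10 "images" of shape (1, 2, 2), loaded once, no transform.
base = TensorDataset(torch.arange(40, dtype=torch.float32).view(10, 1, 2, 2),
                     torch.arange(10) % 2)

indices = torch.randperm(len(base)).tolist()
split = int(0.2 * len(base))
train_ds = WrapperDataset(base, indices[split:], transform=random_hflip)
valid_ds = WrapperDataset(base, indices[:split], transform=None)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=512)
```

Because each wrapper already selects its own indices, the loaders use shuffle=True or sequential order instead of SubsetRandomSampler, and the underlying data is read into memory only once.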