Changing transforms after creating a dataset

i’m using torchvision.datasets.ImageFolder (which takes transform as input) to read my data, then i split it into train and test sets using torch.utils.data.Subset. until now i applied the same transforms to all images, regardless of whether they’re train or test, but now i want to change that.
is it possible to do so without writing a custom dataset? i don’t want to write a new ImageFolder like class just for that.
this happens because the images folder contains all the data and then i split it randomly after loading.
btw, is splitting randomly after loading also a bad idea for statistics consistency (i.e. should i be using the same train/test split every time)?

Hey, I have the same question. I define my train dataset first, which has transformations, then split it into train, val, and test datasets.
I tried:

test_dataset.transform = new_transforms

But it did not work.
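
A likely reason, assuming test_dataset came from random_split or Subset: the split object is a torch.utils.data.Subset wrapper that reads items through its .dataset attribute, so a transform attribute set on the wrapper is never consulted. A minimal sketch, where ToyDataset is a hypothetical stand-in for ImageFolder:

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for ImageFolder: returns transform(index)."""

    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor(float(idx))
        return self.transform(x) if self.transform is not None else x


full = ToyDataset(10, transform=lambda x: x + 100)  # "train" transform
test_dataset = Subset(full, [7, 8, 9])

# This only adds a new attribute to the Subset wrapper; nothing reads it.
test_dataset.transform = lambda x: x - 100
still_train_transform = test_dataset[0]  # still goes through full.transform

# Assigning on the wrapped dataset does take effect, but it changes the
# transform for every Subset that shares `full`, including the train split.
test_dataset.dataset.transform = lambda x: x - 100
now_changed = test_dataset[0]
```

So test_dataset.dataset.transform = ... would change the behaviour, but only by changing it for the train split as well, which is probably not what you want here.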

I would like to do this too

@ptrblck apologies for the direct ping but how does one correctly change the transforms of a data set?

e.g. the use case is to get the entire train set with the train transform, but if you then split it, the val set ends up with the same transform:

import torch
from torch.utils.data import DataLoader, Dataset, random_split


def get_train_val_split_with_split(
        train_dataset: Dataset,
        train_val_split: list[int],  # e.g. [50_000, 10_000] for mnist
        batch_size: int = 128,
        batch_size_eval: int = 64,
        num_workers: int = 4,
        pin_memory: bool = False
) -> tuple[DataLoader, DataLoader]:
    """
    Note:
        - the train and val sets will share the same transform.

    ref:
        - https://gist.github.com/MattKleinsmith/5226a94bad5dd12ed0b871aed98cb123
    """
    # random_split returns Subset objects that wrap the same underlying dataset
    train_dataset, valid_dataset = random_split(train_dataset, train_val_split)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size_eval, num_workers=num_workers,
                                               pin_memory=pin_memory)
    return train_loader, valid_loader

related: Changing transformation applied to data during training - #11 by ptrblck

Instead of replacing the internal transformations, I would rather create separate training and validation datasets by splitting the indices and using a Subset. Something like this would work:

train_dataset = MyDataset(train_transform)
val_dataset = MyDataset(val_transform)
train_indices, val_indices = sklearn.model_selection.train_test_split(indices)
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(train_dataset, val_indices)

I think you meant:

train_dataset = MyDataset(train_transform)
val_dataset = MyDataset(val_transform)
train_indices, val_indices = sklearn.model_selection.train_test_split(indices)
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)

I think the indices can be obtained as follows:

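    import numpy as np

    num_train = len(train_dataset)
    indices = list(range(num_train))
    # val_size is assumed to be the validation fraction here, e.g. 0.2
    split_idx = int(np.floor(val_size * num_train))

    # the first split_idx indices go to validation, the rest to training
    train_idx, valid_idx = indices[split_idx:], indices[:split_idx]
    assert len(train_idx) != 0 and len(valid_idx) != 0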
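
Note that slicing range(num_train) keeps the original sample order, and ImageFolder lists samples class by class, so an unshuffled slice could put whole classes on one side of the split. A rough sketch of a shuffled split with a fixed seed, which also gives the same train/val split on every run (relevant to the consistency question in the first post). The helper name split_indices and its parameters val_fraction and seed are made up for illustration:

```python
import torch


def split_indices(num_samples, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) as lists, shuffled with a fixed seed.

    `split_indices`, `val_fraction` and `seed` are illustrative names.
    """
    # A dedicated generator keeps the split independent of the global RNG state.
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_samples, generator=generator).tolist()
    num_val = int(num_samples * val_fraction)
    if num_val == 0 or num_val == num_samples:
        raise ValueError("val_fraction leaves one of the splits empty")
    # First num_val shuffled indices go to validation, the rest to training.
    return perm[num_val:], perm[:num_val]


train_idx, val_idx = split_indices(100, val_fraction=0.25, seed=42)
```

The two lists can then be passed to torch.utils.data.Subset as in the snippets above.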

You can also do everything you need with PyTorch alone:

train_dataset = MyDataset(train_transform)
val_dataset = MyDataset(val_transform)
indices = torch.randperm(len(train_dataset))
val_size = len(train_dataset)//4
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
val_dataset = torch.utils.data.Subset(train_dataset, indices[-val_size:])

hey, so your code is great, but there’s a typo that causes a really weird error if you don’t fix it: the last line should wrap val_dataset, not train_dataset

train_dataset = MyDataset(train_transform)
val_dataset = MyDataset(val_transform)
indices = torch.randperm(len(train_dataset))
val_size = len(train_dataset)//4
train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])

may I know what MyDataset refers to? is it a custom class for the dataset?

Does MyDataset refer to a custom class for datasets?
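
As far as I can tell from the snippets above, MyDataset is just a placeholder for any dataset class whose constructor accepts a transform; it does not have to be custom. A self-contained sketch with a toy class (the class body, train_transform and val_transform below are invented for illustration, not taken from the earlier posts):

```python
import torch
from torch.utils.data import Dataset, Subset


class MyDataset(Dataset):
    """Toy placeholder: any dataset whose constructor accepts a transform."""

    def __init__(self, transform=None, n=8):
        self.data = torch.arange(n, dtype=torch.float32)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        return self.transform(x) if self.transform is not None else x


def train_transform(x):
    # Stand-in for training-time augmentation.
    return x * 2


def val_transform(x):
    # Stand-in for deterministic evaluation preprocessing.
    return x


# Two views of the same data, differing only in the transform.
train_full = MyDataset(train_transform)
val_full = MyDataset(val_transform)

# One shuffled index list, shared by both views, so the splits never overlap.
generator = torch.Generator().manual_seed(0)
indices = torch.randperm(len(train_full), generator=generator).tolist()
val_size = len(train_full) // 4

train_dataset = Subset(train_full, indices[:-val_size])
val_dataset = Subset(val_full, indices[-val_size:])
```

With real images, the same pattern should work with two torchvision.datasets.ImageFolder instances built over the same root, one with transform=train_transform and one with transform=val_transform, assuming the folder contents do not change between the two constructions so that both list the files in the same order.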