Changing transforms after creating a dataset

I think you meant:

    import torch
    from sklearn.model_selection import train_test_split

    # two instances over the same data, each with its own transform
    train_dataset = MyDataset(train_transform)
    val_dataset = MyDataset(val_transform)
    train_indices, val_indices = train_test_split(indices)
    train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(val_dataset, val_indices)
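Here is a self-contained sketch of the same idea that can be run end to end. `MyDataset`, `train_transform` and `val_transform` below are hypothetical toy stand-ins (a range of numbers and a "+100" shift instead of real images and augmentations), and the split uses `torch.randperm` instead of scikit-learn's `train_test_split` only to avoid the extra dependency:

```python
import torch
from torch.utils.data import Dataset, Subset


class MyDataset(Dataset):
    """Hypothetical toy dataset: the same samples, with a per-instance transform."""

    def __init__(self, transform):
        self.data = torch.arange(10, dtype=torch.float32)  # stand-in for real samples
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.transform(self.data[idx])


def train_transform(x):
    # hypothetical "augmentation": shift training samples so they are easy to spot
    return x + 100


def val_transform(x):
    # hypothetical validation transform: leave samples unchanged
    return x


# Two instances over the same underlying data, each with its own transform.
train_full = MyDataset(train_transform)
val_full = MyDataset(val_transform)

# Shuffle once (seeded for reproducibility), then carve out disjoint index sets.
generator = torch.Generator().manual_seed(0)
perm = torch.randperm(len(train_full), generator=generator).tolist()
n_val = int(0.2 * len(perm))  # 20% held out for validation
val_indices, train_indices = perm[:n_val], perm[n_val:]

# Each Subset only sees its own indices, through its own transform.
train_dataset = Subset(train_full, train_indices)
val_dataset = Subset(val_full, val_indices)
```

Because each `Subset` wraps a different dataset instance, the transforms never interfere with each other. The trade-off is that if `MyDataset.__init__` loads the data eagerly, it is loaded twice.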

I think the indices can be obtained as follows:

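    import numpy as np

    val_size = 0.2  # fraction of samples held out for validation (example value)
    num_train = len(train_dataset)
    indices = list(range(num_train))
    np.random.shuffle(indices)  # avoid a split that follows the dataset's ordering
    split_idx = int(np.floor(val_size * num_train))

    # the first split_idx indices (val_size of the data) go to validation
    valid_idx, train_idx = indices[:split_idx], indices[split_idx:]
    assert len(train_idx) != 0 and len(valid_idx) != 0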
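The index computation can also be wrapped in a small reusable function. This is a minimal sketch, assuming NumPy's `default_rng` for a seeded, reproducible shuffle; `split_indices` and its `seed` argument are hypothetical names, not part of PyTorch. The validation set takes the first `floor(val_size * num_samples)` shuffled indices and the training set takes the rest:

```python
import numpy as np


def split_indices(num_samples, val_size, seed=0):
    """Hypothetical helper: return disjoint (train_idx, valid_idx) covering range(num_samples)."""
    if not 0.0 < val_size < 1.0:
        raise ValueError("val_size must be in (0, 1)")
    rng = np.random.default_rng(seed)
    indices = rng.permutation(num_samples).tolist()  # shuffled 0..num_samples-1
    split_idx = int(np.floor(val_size * num_samples))
    # the first split_idx shuffled indices form the validation set
    valid_idx, train_idx = indices[:split_idx], indices[split_idx:]
    if not train_idx or not valid_idx:
        raise ValueError("split produced an empty train or validation set")
    return train_idx, valid_idx


train_idx, valid_idx = split_indices(num_samples=100, val_size=0.2)
print(len(train_idx), len(valid_idx))  # 80 20
```

Raising `ValueError` instead of using `assert` keeps the check active even when Python runs with `-O`, which strips assertions. The resulting lists can be passed straight to `torch.utils.data.Subset`.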