Train Test Split using SubsetRandomSampler

https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets

I am trying to split the dataset into training and validation datasets using SubsetRandomSampler.

I want to achieve this:
Shuffle the training dataset, but do not shuffle the validation dataset.

I also need the first 80% of the data for training and the last 20% for validation.

Does this make sense?

import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler, SequentialSampler

# `dataset` is the custom Dataset to split
train_split = 0.8
random_seed = 42
torch.manual_seed(random_seed)  # seeds the global RNG used by SubsetRandomSampler

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(train_split * dataset_size))
# First 80% of the indices for training, last 20% for validation
train_indices, val_indices = indices[:split], indices[split:]

# SubsetRandomSampler reshuffles the training indices every epoch
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SequentialSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=64,
                                                sampler=valid_sampler)
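A minimal sketch to check which indices land in each split, assuming a hypothetical toy dataset of 10 samples built with TensorDataset (sample i holds the value i). It also confirms that SubsetRandomSampler only draws from the training indices, just in random order.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Hypothetical toy dataset: 10 samples, sample i holds the value i
dataset = TensorDataset(torch.arange(10))

train_split = 0.8
indices = list(range(len(dataset)))
split = int(train_split * len(dataset))  # 8

# First 80% of the indices for training, last 20% for validation
train_indices, val_indices = indices[:split], indices[split:]
print(train_indices)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(val_indices)    # [8, 9]

# SubsetRandomSampler yields a random permutation of exactly train_indices
torch.manual_seed(42)
sampled = list(SubsetRandomSampler(train_indices))
print(sorted(sampled) == train_indices)  # True
```

Note that slicing `indices[split:]` with a split computed from the validation fraction would do the opposite and put the last 80% into training, so the slice direction is worth double-checking.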

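SequentialSampler expects a Dataset as its input and simply yields range(len(data_source)), so passing val_indices would iterate positions 0 to len(val_indices) - 1 rather than the validation indices themselves. You could use Subset for the validation set instead: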

from torch.utils.data import DataLoader, Subset

val_dataset = Subset(dataset, val_indices)
validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

Besides that, the code looks alright.
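A small sketch of the difference, again assuming a hypothetical 10-sample TensorDataset where sample i holds the value i: SequentialSampler over the index list yields positions, while Subset maps each position back to the right sample. The training side uses a Subset with `shuffle=True`, which gives the shuffled-train / ordered-validation behaviour asked about.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data.sampler import SequentialSampler

# Hypothetical toy dataset: sample i holds the value i
dataset = TensorDataset(torch.arange(10))
train_indices, val_indices = list(range(8)), [8, 9]

# SequentialSampler iterates range(len(data_source)), so it yields
# positions 0..len-1 instead of the validation indices
print(list(SequentialSampler(val_indices)))  # [0, 1]

# Subset remaps position k to dataset[val_indices[k]]
val_dataset = Subset(dataset, val_indices)
validation_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
values = [x.item() for (batch,) in validation_loader for x in batch]
print(values)  # [8, 9]

# Training data: shuffled every epoch by the DataLoader
train_loader = DataLoader(Subset(dataset, train_indices), batch_size=64, shuffle=True)
```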


Thank you, it works fine now 🙂

Thanks for the information, but there is a big problem when your dataset is not randomly ordered. In my case, the samples of each class were stored back to back, one class after another. So this code put some classes entirely into the training set and assigned the remaining data, which belongs to the other classes, to the test set.
I tried the code below, and it randomizes the data for me:

import torch

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
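A minimal sketch of the same idea with a reproducible split, assuming a hypothetical toy dataset whose two classes are stored back to back. Passing a seeded `torch.Generator` to `random_split` makes the shuffled assignment identical across runs.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical toy dataset with classes stored back to back:
# 5 samples of class 0 followed by 5 samples of class 1
labels = torch.tensor([0] * 5 + [1] * 5)
full_dataset = TensorDataset(torch.arange(10), labels)

train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size

# A seeded generator makes the random split reproducible across runs
generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    full_dataset, [train_size, test_size], generator=generator
)

print(len(train_dataset), len(test_dataset))  # 8 2

# random_split returns Subsets whose indices are disjoint and cover the dataset
all_indices = sorted(int(i) for i in list(train_dataset.indices) + list(test_dataset.indices))
print(all_indices == list(range(10)))  # True
```

`random_split` shuffles before splitting, so stacked classes no longer end up wholly in one split, but it does not guarantee each class keeps its 80/20 proportion; that would need a stratified split.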