If anyone’s stuck on creating train, validation and test splits for datasets in torchvision.datasets, I’ve created a small gist that supports transformations, shuffling, seeding, and optional plotting.
The main logic of the code is as follows:
- Generate a train dataset and a test dataset using the `train` argument in `torchvision.datasets.XYZ`, where `XYZ` is your desired dataset (e.g. CIFAR10 or ImageNet).
- Figure out the length of your validation set, `num_valid`. If, say, you want 10% of your training data to be used for validation, multiply the total length of the training set by 0.1.
- Create a list of indices of size `num_train`, shuffle it, then slice `num_valid` indices from it and call them `valid_idx`, storing the rest in `train_idx`.
- Feed these indices to separate instances of `SubsetRandomSampler`.
- Finally, feed these 2 samplers to `torch.utils.data.DataLoader` using the `sampler` argument.
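For anyone who wants to see the steps above in code, here is a minimal sketch. It is not the gist itself: the helper names `split_indices` and `make_loaders`, their default values, and the use of the standard library's `random` for shuffling are my own assumptions for illustration.

```python
import random


def split_indices(num_train, valid_size=0.1, shuffle=True, seed=0):
    """Return (train_idx, valid_idx) for a dataset of num_train samples.

    Hypothetical helper for illustration, not part of torchvision.
    """
    if not 0.0 <= valid_size <= 1.0:
        raise ValueError("valid_size must be in [0, 1]")
    indices = list(range(num_train))
    if shuffle:
        # A dedicated Random instance keeps the split reproducible
        # without touching the global random state.
        random.Random(seed).shuffle(indices)
    num_valid = int(valid_size * num_train)  # rounds down
    # The first num_valid indices go to validation, the rest to training.
    return indices[num_valid:], indices[:num_valid]


def make_loaders(train_dataset, valid_dataset, batch_size=64,
                 valid_size=0.1, seed=0):
    """Build train/validation loaders over two views of the same data.

    Both datasets are assumed to wrap the same training split, so the
    same index refers to the same sample in each.
    """
    # Imported here so split_indices stays usable without PyTorch installed.
    from torch.utils.data import DataLoader, SubsetRandomSampler

    train_idx, valid_idx = split_indices(
        len(train_dataset), valid_size, seed=seed
    )
    # A sampler replaces shuffle=True; each loader only draws its own indices.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size,
        sampler=SubsetRandomSampler(train_idx),
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=batch_size,
        sampler=SubsetRandomSampler(valid_idx),
    )
    return train_loader, valid_loader
```

In practice you would pass two `torchvision.datasets.CIFAR10(..., train=True)` instances that differ only in their `transform`, so the validation images can skip data augmentation, and build the test loader separately from a `train=False` instance.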
And voila!
Note that the code is inspired by Mamy Ratsimbazafy’s code which you can view here.