Discussion about datasets and dataloaders

If anyone’s stuck on creating train, validation and test splits for datasets in torchvision.datasets, I’ve created a small gist that supports transformations, shuffling, seeding, and optional plotting.

The main logic of the code is as follows:

  • Generate a train dataset and a test dataset using the train argument of torchvision.datasets.XYZ, where XYZ is your desired dataset (e.g. CIFAR10 or MNIST).
  • Decide the size of your validation set, num_valid. For example, to hold out 10% of the training data for validation, multiply the length of the training set by 0.1 and round down to an integer.
  • Create a list of indices of length num_train, shuffle it, then slice num_valid indices from it as valid_idx and keep the rest as train_idx.
  • Feed these indices to separate instances of SubsetRandomSampler.
  • Finally, feed these two samplers to separate instances of torch.utils.data.DataLoader via the sampler argument.
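
For anyone who wants to see the steps above in one place, here is a minimal sketch. It is not the gist itself: the names split_indices, get_train_valid_loaders and data_dir are made up for illustration, CIFAR10 with a plain ToTensor transform is just an assumed example, and plotting is left out. The test set would come from the same dataset class with train=False.

```python
import random


def split_indices(num_train, valid_size=0.1, shuffle=True, seed=0):
    """Return (train_idx, valid_idx) covering range(num_train) with no overlap."""
    if not 0.0 <= valid_size <= 1.0:
        raise ValueError("valid_size must be in [0, 1]")
    indices = list(range(num_train))
    num_valid = int(valid_size * num_train)  # fraction of the train set, rounded down
    if shuffle:
        # A local RNG keeps the split reproducible without touching global state.
        random.Random(seed).shuffle(indices)
    return indices[num_valid:], indices[:num_valid]


def get_train_valid_loaders(data_dir, batch_size=64, valid_size=0.1, seed=0):
    """Build CIFAR10 train/validation loaders over one underlying dataset."""
    # Imported here so split_indices stays usable without PyTorch installed.
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SubsetRandomSampler
    from torchvision import datasets, transforms

    dataset = datasets.CIFAR10(root=data_dir, train=True, download=True,
                               transform=transforms.ToTensor())
    train_idx, valid_idx = split_indices(len(dataset), valid_size, seed=seed)

    # Each sampler draws only from its own index list, in random order.
    # Note that DataLoader does not allow shuffle=True together with a sampler.
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_idx))
    valid_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(valid_idx))
    return train_loader, valid_loader
```

Because both loaders wrap the same dataset object here, they also share the same transform. If you want augmentation on the training split only, one option is to create two dataset instances with train=True and different transforms, and give each its own sampler.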

And voila!

Note that the code is inspired by Mamy Ratsimbazafy’s code which you can view here.
