Given the CIFAR10 data with 5 training batches and 1 test batch in different folders, 20% of each train batch held out for validation, and an option to perform 5-fold CV within those train batches, is it possible to make a single dataset class to handle all of this? How do I keep track of the validation indices if I want to split randomly? Or should I use a default dataset class and then split as needed?
You could use `sklearn.model_selection.KFold` to create the fold indices and wrap your `Dataset` into `Subset`s for each fold.
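A minimal sketch of this approach, using a random `TensorDataset` as a stand-in for the CIFAR10 train set (shapes and batch size are just placeholders):

```python
import torch
from torch.utils.data import TensorDataset, Subset, DataLoader
from sklearn.model_selection import KFold

# Stand-in for the CIFAR10 train set (100 fake 3x32x32 images, 10 classes)
dataset = TensorDataset(torch.randn(100, 3, 32, 32),
                        torch.randint(0, 10, (100,)))

kfold = KFold(n_splits=5, shuffle=True, random_state=0)
fold_loaders = []
for train_idx, val_idx in kfold.split(list(range(len(dataset)))):
    # Subset keeps a reference to the full dataset plus the fold's indices,
    # so the validation indices are tracked for you
    train_loader = DataLoader(Subset(dataset, train_idx.tolist()),
                              batch_size=32, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx.tolist()),
                            batch_size=32)
    fold_loaders.append((train_loader, val_loader))

print(len(fold_loaders))  # one (train_loader, val_loader) pair per fold
```

Since each `Subset` stores its own index list (`subset.indices`), you can save or log those lists if you need to reproduce a split later.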
Alternatively, you could use `SubsetRandomSampler`s and pass them to your `DataLoader`.
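Here is a sketch of the sampler approach for the 80/20 train/validation split mentioned in the question, again with a random dataset as a placeholder for CIFAR10:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

# Stand-in for one CIFAR10 train batch
dataset = TensorDataset(torch.randn(100, 3, 32, 32),
                        torch.randint(0, 10, (100,)))

# Shuffle the indices once, then keep the split fixed
indices = torch.randperm(len(dataset)).tolist()
split = int(0.8 * len(dataset))  # 80% train, 20% validation

# Note: leave shuffle=False (the default) when passing a sampler
train_loader = DataLoader(dataset, batch_size=32,
                          sampler=SubsetRandomSampler(indices[:split]))
val_loader = DataLoader(dataset, batch_size=32,
                        sampler=SubsetRandomSampler(indices[split:]))
```

The samplers draw only from their index lists, so keeping `indices[:split]` and `indices[split:]` around is all the bookkeeping you need for the validation set.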
Also, skorch is a scikit-learn-compatible wrapper for PyTorch, which might make the whole cross-validation workflow easier.