Recommended PyTorch way for cross-validation with DataLoaders and Subsets

What’s the recommended way to perform cross-validation in PyTorch? So far I’ve seen two ways. One is to use SubsetRandomSampler with a set of indices and pass it to the DataLoader as its sampler. The other is to write your own method that generates the indices for each fold, build a custom Dataset around them that yields the train/valid split, and pass that to the DataLoader.

I guess both are valid approaches, but what I’m asking is: what’s the native, PyTorch-recommended way, so that I avoid errors and stay consistent with best practices? (Side note: I’ve read pretty much all the posts I could find related to the subject, but couldn’t find a clarification on this topic.)
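For concreteness, here is a minimal sketch of the first approach (SubsetRandomSampler). The toy dataset, the fold count, and the way the indices are split are all assumptions made for illustration, not something PyTorch prescribes:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy dataset (assumption): 10 samples, 3 features, binary labels.
dataset = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))

n_folds = 5  # hypothetical fold count
indices = torch.randperm(len(dataset)).tolist()
# Strided slicing gives n_folds disjoint index lists covering every sample.
folds = [indices[i::n_folds] for i in range(n_folds)]

fold_sizes = []
for k, valid_idx in enumerate(folds):
    # Training indices are everything outside the current validation fold.
    train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
    # The samplers restrict each loader to its own indices of the same dataset.
    train_loader = DataLoader(dataset, batch_size=4,
                              sampler=SubsetRandomSampler(train_idx))
    valid_loader = DataLoader(dataset, batch_size=4,
                              sampler=SubsetRandomSampler(valid_idx))
    n_train = sum(x.shape[0] for x, _ in train_loader)
    n_valid = sum(x.shape[0] for x, _ in valid_loader)
    fold_sizes.append((n_train, n_valid))
```

Note that `sampler` and `shuffle=True` can't be combined in a DataLoader; SubsetRandomSampler already draws its indices in random order.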

I’ll try to answer my own question in case someone else finds it useful. One way to achieve this (although I’m not 100% sure this is the correct PyTorch way) is to use the Subset class to retrieve the corresponding data and targets according to a set of indices.

For instance:

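import torch

for tr_idx, te_idx in kfold(...):
     # Wrap the full dataset so that only this fold's training indices are visible
     train_subset = torch.utils.data.Subset(train_dataset, tr_idx)
     train_loader = torch.utils.data.DataLoader(train_subset, batch_size=64)
     model.train()  # set training mode once per fold rather than on every batch
     for x, y in train_loader:
          y_hat = model(x)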
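A fuller sketch of the same Subset idea, with the missing pieces filled in by assumption: `kfold_indices` is a hypothetical helper standing in for the `kfold(...)` call above, and the toy dataset, linear model, optimizer, and hyperparameters are placeholders rather than recommendations.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Subset, TensorDataset


def kfold_indices(n_samples, n_folds, seed=0):
    """Yield (train_idx, valid_idx) index lists per fold (hypothetical helper)."""
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n_samples, generator=g).tolist()
    folds = [perm[i::n_folds] for i in range(n_folds)]
    for k in range(n_folds):
        train_idx = [i for j, f in enumerate(folds) if j != k for i in f]
        yield train_idx, folds[k]


torch.manual_seed(0)
# Toy data (assumption): 20 samples, 4 features, 2 classes.
dataset = TensorDataset(torch.randn(20, 4), torch.randint(0, 2, (20,)))
loss_fn = nn.CrossEntropyLoss()

fold_accuracies = []
for train_idx, valid_idx in kfold_indices(len(dataset), n_folds=5):
    # Re-create the model each fold so weights don't carry over between folds.
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=8, shuffle=True)
    valid_loader = DataLoader(Subset(dataset, valid_idx), batch_size=8)

    model.train()
    for x, y in train_loader:  # a single epoch, for brevity
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()

    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in valid_loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
    fold_accuracies.append(correct / len(valid_idx))
```

If scikit-learn is available, `sklearn.model_selection.KFold(n_splits=5, shuffle=True).split(range(len(dataset)))` yields equivalent index pairs and could replace the helper. Either way, the key design choice is that one underlying dataset is shared and only the index lists change from fold to fold.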