Issues with torch.utils.data.random_split

It splits the data randomly, without taking the class labels into account. If you want a stratified split, you could use sklearn.model_selection.train_test_split and pass the stratify argument to create the training and validation indices, which can then be used in a Subset or SubsetRandomSampler.
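Here is a minimal sketch of that approach. It assumes a hypothetical toy TensorDataset with imbalanced labels (80 samples of class 0, 20 of class 1). The 80/20 split ratio, the batch sizes and the random_state are illustrative choices, not requirements. The idea is to split the *indices* with train_test_split(..., stratify=labels) and then either wrap the dataset in Subset objects or feed the index lists to SubsetRandomSampler.

```python
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler, TensorDataset

# Hypothetical toy dataset: 100 samples with imbalanced labels (80x class 0, 20x class 1).
features = torch.randn(100, 5)
targets = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
dataset = TensorDataset(features, targets)

# Split the indices, stratifying on the labels so both splits keep the 80/20 class ratio.
train_idx, val_idx = train_test_split(
    list(range(len(dataset))),
    test_size=0.2,
    stratify=targets.numpy(),
    random_state=0,
)

# Option 1: wrap the dataset in Subset objects, one per split.
train_set = Subset(dataset, train_idx)
val_set = Subset(dataset, val_idx)

# Option 2: keep the full dataset and let the sampler draw (shuffled) from the training indices.
train_loader = DataLoader(dataset, batch_size=16, sampler=SubsetRandomSampler(train_idx))
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
```

With either option, the validation split should end up with 16 samples of class 0 and 4 of class 1, mirroring the class ratio of the full dataset. Note that shuffle must not be set to True together with a sampler in the DataLoader, since the two options are mutually exclusive.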
