How to use sklearn's train_test_split on PyTorch's dataset

Have a look at @kevinzakka’s approach here. It might give you a good starting point for your implementation.
Since you apparently would like to split your CIFAR10 dataset in a stratified fashion, you could pass the dataset's internal `targets` to the `stratify` argument:

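```python
import numpy as np
from sklearn.model_selection import train_test_split

# dataset is assumed to be a torchvision.datasets.CIFAR10 instance
targets = dataset.targets

# split the sample indices 80/20 while preserving the class proportions
train_idx, valid_idx = train_test_split(
    np.arange(len(targets)), test_size=0.2, random_state=42, shuffle=True, stratify=targets)

# per-class counts in each split, to check the stratification
print(np.unique(np.array(targets)[train_idx], return_counts=True))
print(np.unique(np.array(targets)[valid_idx], return_counts=True))
```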
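If you want to check what `stratify` does without downloading CIFAR10, here is a minimal sketch that uses a hypothetical imbalanced label list (900 samples of class 0, 100 of class 1) in place of `dataset.targets`. With `test_size=0.2`, the validation split should contain 180 and 20 samples of the two classes, i.e. the same 9:1 ratio as the full set.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels standing in for dataset.targets:
# 900 samples of class 0 and 100 samples of class 1.
targets = [0] * 900 + [1] * 100

# Split the indices, not the data, so they can later index any dataset.
train_idx, valid_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    random_state=42,
    shuffle=True,
    stratify=targets,
)

labels = np.array(targets)
train_classes, train_counts = np.unique(labels[train_idx], return_counts=True)
valid_classes, valid_counts = np.unique(labels[valid_idx], return_counts=True)
print(train_classes, train_counts)
print(valid_classes, valid_counts)
```

Without `stratify`, the class ratio in each split would only match the full set approximately, which matters most for small or rare classes.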

These indices can then be passed to `torch.utils.data.SubsetRandomSampler`, one sampler per split, to build the training and validation `DataLoader`s.
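A minimal sketch of that last step, assuming a small random `TensorDataset` as a hypothetical stand-in for CIFAR10 (100 fake 3x32x32 images, 10 balanced classes); the batch size of 16 is arbitrary. Both loaders wrap the same dataset, and each sampler only draws from its own index subset, reshuffled every epoch.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for CIFAR10: 100 random "images", 10 samples per class.
images = torch.randn(100, 3, 32, 32)
labels = torch.arange(10).repeat(10)
dataset = TensorDataset(images, labels)
targets = labels.tolist()

train_idx, valid_idx = train_test_split(
    np.arange(len(targets)), test_size=0.2, random_state=42, shuffle=True, stratify=targets)

# One sampler per split; each yields only its own indices, in random order.
train_sampler = SubsetRandomSampler(train_idx.tolist())
valid_sampler = SubsetRandomSampler(valid_idx.tolist())

# shuffle must stay False (the default) when a sampler is passed.
train_loader = DataLoader(dataset, batch_size=16, sampler=train_sampler)
valid_loader = DataLoader(dataset, batch_size=16, sampler=valid_sampler)

# Count the samples (and collect the labels) seen in one epoch of each loader.
train_labels = torch.cat([y for _, y in train_loader])
valid_labels = torch.cat([y for _, y in valid_loader])
print(len(train_labels), len(valid_labels))  # 80 20
```

If you don't need the validation order to be shuffled, wrapping the dataset in `torch.utils.data.Subset(dataset, valid_idx)` and using a plain `DataLoader` is an alternative.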
