Combining and splitting a dataset

What is a straightforward way to split a dataset in PyTorch?
Say I have the CIFAR-10 train (50,000 samples) and test (10,000 samples) sets, and I want to combine them (itertools.chain?) and split the combined set into two sets of 30,000 samples each. How can I do it?
Thanks!

You can use the sampler argument of DataLoader to restrict a loader to a subset of indices, for example with a SubsetRandomSampler.
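Here is a minimal sketch of that idea, not a definitive recipe. It assumes the two sets are first joined with torch.utils.data.ConcatDataset (not mentioned above, but it plays the role of itertools.chain for datasets). The tiny TensorDatasets are hypothetical stand-ins for the CIFAR-10 train and test sets so the example runs without a download, and the batch size and seed are arbitrary choices.

```python
import torch
from torch.utils.data import (
    ConcatDataset,
    DataLoader,
    SubsetRandomSampler,
    TensorDataset,
)

# Hypothetical stand-ins for the CIFAR-10 train and test sets, sized to
# match (50,000 and 10,000 samples). Swap in the real datasets here.
train_set = TensorDataset(torch.zeros(50_000, 1), torch.zeros(50_000, dtype=torch.long))
test_set = TensorDataset(torch.ones(10_000, 1), torch.ones(10_000, dtype=torch.long))

# Join the two datasets into a single 60,000-sample dataset.
combined = ConcatDataset([train_set, test_set])

# Shuffle all indices once (fixed seed for reproducibility), then cut the
# permutation into two disjoint halves of 30,000 indices each.
generator = torch.Generator().manual_seed(0)
perm = torch.randperm(len(combined), generator=generator).tolist()
first_idx, second_idx = perm[:30_000], perm[30_000:]

# Each loader samples only from its own index list, in random order.
first_loader = DataLoader(combined, batch_size=128, sampler=SubsetRandomSampler(first_idx))
second_loader = DataLoader(combined, batch_size=128, sampler=SubsetRandomSampler(second_idx))
```

If you would rather have two dataset objects than two samplers, torch.utils.data.random_split(combined, [30_000, 30_000]) should give an equivalent split as two Subset objects.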
