Stratified Sampling for multiclass

I am quite new to PyTorch, so please bear with me here. I have a CSV dataset which looks like this:

class_label,image_location
1, /some/loc0
2, /some/loc1
0, /some/loc2
1, /some/loc4

where class_label is my target and image_location is my input to the NN.

I would like to use a dataloader to somehow split this into train and test sets, with a stratified sampling of each class in the CSV (20 classes, 4 images per class in the CSV).

When googling around, I see some pandas and scikit-learn solutions, so, beginning with something like:

sss = StratifiedShuffleSplit(df['event'], n_iter=1, test_size=0.2)

which I presume gives a single train/test split that I can wrap in a Dataset and later batch with a DataLoader. However, I am not sure whether this is an elegant solution, and I was wondering if someone could point me in the right direction to achieve this.
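To make the question concrete, here is a minimal sketch of the kind of split I mean, written with only the standard library so it runs anywhere. The function name stratified_split, the 0.25 test fraction (chosen so that 4 images per class gives exactly 1 test image per class) and the toy labels are my own assumptions for illustration, not something from a library. As far as I can tell, scikit-learn's train_test_split with its stratify argument is the library route to the same kind of split.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction=0.25, seed=0):
    """Return (train_idx, test_idx) with every class split in the same ratio.

    Hypothetical helper for illustration; not part of PyTorch or scikit-learn.
    """
    rng = random.Random(seed)

    # Group row indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, test_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Aim for at least one test sample per class, but never the whole class.
        n_test = max(1, round(len(indices) * test_fraction))
        n_test = min(n_test, len(indices) - 1)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return train_idx, test_idx


# Toy labels mirroring the CSV described above: 20 classes, 4 rows per class.
labels = [c for c in range(20) for _ in range(4)]
train_idx, test_idx = stratified_split(labels, test_fraction=0.25)
```

If something like this is reasonable, I assume the two index lists could then be passed to torch.utils.data.Subset (one Subset for train, one for test) and each Subset handed to its own DataLoader.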