You can use the built-in function torch.utils.data.random_split(dataset, lengths), which randomly splits a dataset into non-overlapping subsets of the given lengths.
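Below is a minimal sketch of how the call might look. It assumes a small toy TensorDataset; the data, the 8/2 split, and the seed are made up for illustration. Passing a seeded torch.Generator makes the split reproducible across runs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical toy dataset: 10 samples, each a 1-feature vector with an integer label
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# Split into 8 training and 2 validation samples; the lengths must sum to len(dataset)
train_set, val_set = random_split(
    dataset, [8, 2], generator=torch.Generator().manual_seed(42)
)

# Each split is a torch.utils.data.Subset, so it can be wrapped in a DataLoader as usual
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

print(len(train_set), len(val_set))  # 8 2
```

Recent PyTorch versions also accept fractions that sum to 1 (e.g. [0.8, 0.2]) for lengths; on older versions, pass integer counts as above.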
Check the docs here: https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split
Check the source here: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#random_split
Also, you can see this nice example.