You could either use `torch.utils.data.random_split` and pass the lengths you would like for each split or, if you want a stratified split, use sklearn's `train_test_split` with its `stratify` argument, as seen here.
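A minimal sketch of both options follows. It assumes a toy `TensorDataset` with 100 samples and imbalanced binary labels (80/20). The dataset, split sizes and seeds are illustrative choices, not anything from your setup. For the stratified case, the idea is to split the *indices* with `train_test_split` and then wrap them in `torch.utils.data.Subset`:

```python
import torch
from torch.utils.data import TensorDataset, Subset, random_split
from sklearn.model_selection import train_test_split

# Toy dataset (assumption for illustration): 100 samples, 3 features,
# imbalanced labels with 80 zeros and 20 ones
features = torch.randn(100, 3)
targets = torch.cat([torch.zeros(80, dtype=torch.long), torch.ones(20, dtype=torch.long)])
dataset = TensorDataset(features, targets)

# Option 1: plain random split by lengths; a seeded generator makes it reproducible
generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [80, 20], generator=generator)

# Option 2: stratified split on the indices, then wrap each index list in a Subset
indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(
    indices,
    test_size=0.2,
    stratify=targets.numpy(),  # keep the class ratio the same in both splits
    random_state=42,
)
train_strat = Subset(dataset, train_idx)
val_strat = Subset(dataset, val_idx)

print(len(train_ds), len(val_ds))
print(len(train_strat), len(val_strat))
# Number of positive labels that ended up in the stratified validation split
print(int(targets[val_idx].sum()))
```

With `random_split` the class ratio in each split depends on chance. With the stratified index split, the validation subset should keep the same 80/20 ratio as the full dataset (here 4 of its 20 samples are positives). Both approaches return dataset views that you can pass straight to a `DataLoader`.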