Issues with torch.utils.data.random_split

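For your use case, I would probably use `Subset`s and pass the indices explicitly, as seen in this example, since this allows you to keep the specified transformations for each split. `random_split` returns `Subset`s that all wrap the same underlying dataset, so a transform set on that dataset applies to every split.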
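A minimal sketch of the idea follows. It creates two dataset instances over the same data, each with its own transform, then shuffles the indices once and wraps each instance in a `Subset` with disjoint index lists. `ToyDataset`, `train_transform`, and `val_transform` are hypothetical stand-ins for your own dataset and transforms (e.g. augmentation for training, plain preprocessing for validation). The 80/20 split and the seed are arbitrary choices.

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical dataset that applies an optional transform to each sample."""

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        if self.transform is not None:
            x = self.transform(x)
        return x


# Hypothetical transforms: "augment" the training data, leave validation as-is.
def train_transform(x):
    return x * 2


def val_transform(x):
    return x


data = torch.arange(10, dtype=torch.float32)

# Two dataset instances over the same data, each with its own transform.
train_dataset = ToyDataset(data, transform=train_transform)
val_dataset = ToyDataset(data, transform=val_transform)

# Shuffle the indices once (seeded for reproducibility) and split them.
generator = torch.Generator().manual_seed(0)
indices = torch.randperm(len(data), generator=generator).tolist()
split = int(0.8 * len(indices))
train_indices, val_indices = indices[:split], indices[split:]

# Each Subset wraps a different dataset instance, so the transforms stay separate.
train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(val_dataset, val_indices)
```

Because the index lists are disjoint, no sample appears in both splits. Because each `Subset` points at its own dataset object, changing `val_dataset.transform` does not affect the training samples.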