How to split a dataset in two with a fixed seed to ensure reproducibility in PyTorch?

I am working on one of my university assignments, and one sub-task says: split the data in two (train and validation) while using a fixed seed to ensure reproducibility. I have written some code that works fine, but I want to know whether it is the correct way to do it.

```python
import torch

torch.manual_seed(0)  # seed PyTorch's global random number generator
mnist_train, mnist_val = torch.utils.data.random_split(mnist_rest, [54000, 6000])
```

I am working on the MNIST dataset.
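To try the two lines above without downloading MNIST, the sketch below swaps in a `TensorDataset` of 60,000 dummy items for `mnist_rest` (an assumption; the MNIST training set has 60,000 images, which matches 54000 + 6000). The helper `seeded_split` is hypothetical and just wraps the same two lines. It runs the split twice and checks that both runs agree and that the train and validation subsets are disjoint and together cover the dataset.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for mnist_rest: 60,000 dummy samples.
mnist_rest = TensorDataset(torch.arange(60000))


def seeded_split():
    # Same two lines as in the question: seed the global RNG, then split.
    torch.manual_seed(0)
    return random_split(mnist_rest, [54000, 6000])


train_a, val_a = seeded_split()
train_b, val_b = seeded_split()

# Both runs pick exactly the same indices.
print(train_a.indices == train_b.indices and val_a.indices == val_b.indices)  # True

# The two subsets are disjoint and together cover the whole dataset.
print(set(train_a.indices).isdisjoint(val_a.indices))  # True
print(len(train_a) + len(val_a) == len(mnist_rest))  # True
```

`random_split` returns `Subset` objects, so comparing their `.indices` lists is a simple way to confirm that two splits are identical.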

This works, but it’s not my preferred method.

`torch.manual_seed(0)` sets the state of the global random number generator, so all subsequent calls that use torch randomness become reproducible. However, suppose you refactored your code later on:

```python
torch.manual_seed(0)

# Adding another random call (this consumes numbers from the global generator)
x = torch.randn((3, 2))

mnist_train, mnist_val = torch.utils.data.random_split(mnist_rest, [54000, 6000])
```

Then your train and validation sets are no longer the same as before, because `torch.randn` advanced the global generator before `random_split` ran. I would recommend using PyTorch's `torch.Generator` for this.

```python
gen = torch.Generator()
gen.manual_seed(0)  # seed only this generator, not the global one
mnist_train, mnist_val = torch.utils.data.random_split(
    mnist_rest, [54000, 6000],
    generator=gen
)
```

Basically, we're creating a dedicated random number generator and telling `random_split` to use it instead of the global generator.
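The contrast described above can be checked directly. This is an illustrative sketch, not code from the original post: it uses a `TensorDataset` of 60,000 dummy items as a stand-in for `mnist_rest` (an assumption), and two hypothetical helpers, `global_seed_split` and `generator_split`, that return the validation indices with and without an extra `torch.randn` call before the split.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for mnist_rest: 60,000 dummy samples.
dataset = TensorDataset(torch.arange(60000))


def global_seed_split(extra_call: bool):
    """Seed the global RNG, optionally consume some randomness, then split."""
    torch.manual_seed(0)
    if extra_call:
        torch.randn((3, 2))  # the "refactoring" that advances the global RNG
    _, val = random_split(dataset, [54000, 6000])
    return val.indices


def generator_split(extra_call: bool):
    """Split with a dedicated generator; the global RNG state does not matter."""
    if extra_call:
        torch.randn((3, 2))  # advances the global RNG, not gen
    gen = torch.Generator()
    gen.manual_seed(0)
    _, val = random_split(dataset, [54000, 6000], generator=gen)
    return val.indices


# The extra call changes the split that relies on the global seed...
print(global_seed_split(False) == global_seed_split(True))  # False
# ...but not the split driven by the dedicated generator.
print(generator_split(False) == generator_split(True))  # True
```

Because the generator is created and seeded right next to the `random_split` call, any random calls added elsewhere in the script cannot shift this particular split.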

That being said, I do want to clarify that your method does work: if you never break those two lines apart, you will always get the same train/val split. I just prefer generators to make sure I'm precisely controlling the randomness in these calls.
