Random Transforms not consistent with the same seed

I was recently trying to train a ResNet on ImageNet with consistent image inputs across runs, while still using data augmentation such as cropping, flipping, and rotating.
I ran into a problem: there seems to be no way of consistently getting the same random crops.
Here is a minimal example I created:

import torch
from torchvision import transforms


torch.random.manual_seed(1)

x = torch.rand((3, 10, 10))

tf_crop = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(4),
    transforms.ToTensor(),
])
tf_flip = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
tf_rot = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(45),
    transforms.ToTensor(),
])

# Consistent x among the calls
print(x[:2, :2, :2])

# RandomRotation, RandomResizedCrop and RandomHorizontalFlip change the output
# even though the seed is reset to the same value
for idx in range(2):
    torch.random.manual_seed(1)
    print(f'Crop {idx + 1}')
    print(tf_crop(x)[:2, :2, :2].numpy())
for idx in range(2):
    torch.random.manual_seed(1)
    print(f'Flip {idx + 1}')
    print(tf_flip(x)[:2, :2, :2].numpy())
for idx in range(2):
    torch.random.manual_seed(1)
    print(f'Rotation {idx + 1}')
    print(tf_rot(x)[:2, :2, :2].numpy())

Each iteration of each for loop produces a different result. As I understand it, the results should be identical across iterations, since the seed is reset to the same value each time.

I would appreciate it if someone could point me in the right direction on how to get consistent results across iterations of the same for loop (NOT across different for loops).

Some torchvision transformations use Python's random module, not the PyTorch generator, to sample their parameters, so torch.manual_seed alone does not reset them.
Add random.seed(1) to the loop and you should get the same results.
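A minimal stdlib-only sketch of why this works, assuming a transform that samples its parameters through Python's random module as the answer describes. random_crop_box is a hypothetical stand-in, not a torchvision function; since torch.manual_seed never touches Python's RNG, only random.seed makes the sampled box repeat.

```python
import random


def random_crop_box(height: int, width: int, size: int) -> tuple:
    """Hypothetical stand-in for a crop transform that samples via Python's random."""
    top = random.randint(0, height - size)
    left = random.randint(0, width - size)
    return top, left, size, size


boxes = []
for idx in range(2):
    random.seed(1)  # reset Python's RNG (not only torch's) before each call
    boxes.append(random_crop_box(10, 10, 4))

print(boxes)  # both entries are identical because the RNG was re-seeded
```

Removing the random.seed(1) line would let the two boxes differ, which mirrors the behaviour reported in the question.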

3 Likes

Thank you, that worked flawlessly!

In PyTorch 1.6, you need to call torch.manual_seed(5) and random.seed(5) at the same time. See the issue at: https://github.com/pytorch/pytorch/issues/42331
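A small helper along those lines, assuming you want one call at the top of each loop iteration. seed_all is a hypothetical name, not a PyTorch API; it simply resets Python's random module and PyTorch's global generator to the same seed.

```python
import random

import torch


def seed_all(seed: int) -> None:
    """Hypothetical helper: reset both RNGs that torchvision transforms may draw from."""
    random.seed(seed)        # Python's RNG (used by some torchvision transforms)
    torch.manual_seed(seed)  # PyTorch's global RNG


seed_all(5)
first = (random.random(), torch.rand(3))
seed_all(5)
second = (random.random(), torch.rand(3))

# Both generators replay the same sequence after re-seeding
assert first[0] == second[0]
assert torch.equal(first[1], second[1])
```

In the loops from the question, seed_all(1) would take the place of torch.random.manual_seed(1).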

1 Like