Is there a way to fix the random seed of every worker in the DataLoader?

transform = transforms.Compose([
    transforms.ColorJitter(0.1, 0.3, 0.3, 0.05),
    transforms.RandomHorizontalFlip(),
    my_folder.RandomRotate(),
    my_folder.RandomVerticalFlip(),
    transforms.ToTensor(),
    Cutout(n_holes=args.n_holes, length=args.cutout_length)
])

train_loader = DataLoader(dataset=train_dataset, shuffle=False,
                          batch_size=args.batch_size, num_workers=24)

RANDOM_SEED = random.randint(0, 100000)  # int bounds; random.randint rejects floats in recent Python

torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

for batch_idx, (data, _, _) in enumerate(train_loader):
    x1 = data
    print(x1[0])
    break

torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

for batch_idx, (data, _, _) in enumerate(train_loader):
    x2 = data
    print(x2[0])
    break

I’m trying to build some tricky networks, and I need to get exactly the same data in the same order twice. x1 and x2 are exactly the same if I set num_workers=0, but otherwise they aren’t. (Using a single worker is not an option for me.)

Is there a way to fix the random seed of all workers?
Thanks.


You could use the worker_init_fn from DataLoader to set the worker’s seed.
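For example, something along these lines (a minimal sketch that reuses train_dataset and args from your snippet; seed_worker and the fixed base value are just illustrative names, not part of the DataLoader API):

import random

import numpy as np
import torch
from torch.utils.data import DataLoader

FIXED_BASE_SEED = 12345  # any constant value

def seed_worker(worker_id):
    # Called once in every worker process with its worker id (0 .. num_workers - 1).
    # Deriving the seed from a fixed base makes each worker deterministic
    # across separate iterations over the DataLoader.
    worker_seed = FIXED_BASE_SEED + worker_id
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)

train_loader = DataLoader(dataset=train_dataset, shuffle=False,
                          batch_size=args.batch_size, num_workers=24,
                          worker_init_fn=seed_worker)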


It worked. Thanks very much!

Sorry, my mistake. It’s still not working.
I tried all the methods suggested here (https://github.com/pytorch/pytorch/issues/7068), but nothing worked.

import torch
torch.manual_seed(0)

An even dumber question: is it enough to set the seed just in the main file, or does every file imported by the main file need to set its own seed?

Hmmm, doesn’t this solution imply the same seed will be used at every epoch?

Yes, the worker seed would be the same, but this would also be the current behavior.
In these lines of code the seed is set as the base_seed + i, where i is the worker id.
Inside the worker, the seed will be used here.

Note that this would not force the same ordering of the data, since the sampler won’t use the same seed.
The worker seed should be used for the current worker process, i.e. you could use the worker_init_fn to set the seed for all libraries you are using inside e.g. Dataset.__getitem__, such as numpy etc.
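For instance, a small sketch of that idea (the function name is arbitrary; inside a worker, torch.initial_seed() returns the base_seed + worker_id that PyTorch already used to seed its own RNG):

import random

import numpy as np
import torch

def seed_other_libraries(worker_id):
    # Forward PyTorch's per-worker seed to the other libraries.
    # numpy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_other_libraries)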


Interesting. Thanks for your reply. Doesn’t this mean that the worker will always perform the same sequence of random transforms, each time on different data?

To be clear with what I mean:

the sampler won’t use the same seed

means that the sample distribution across workers is randomised.

But at the same time:

the worker seed would be the same

means that any random transformations performed by the worker will be the same.

I guess this is not problematic, since workers will rarely be used with a very low number of samples to augment, but it is unexpected.

No, that shouldn’t be the case and sorry for the “half-answer”.

The base_seed will be recreated in each call to _BaseDataLoaderIter.__init__ here, which is invoked in each epoch: once the DataLoader iterator is exhausted, it will be recreated.

Each worker thus gets a new seed and uses it for the current epoch, so the random transformations are pseudo-random again across epochs.

You could, however, use the worker seed to force the same transformation, as seen in this small code snippet:

import random

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.arange(4).float().view(4, 1, 1, 1).repeat(1, 3, 224, 224)

    def __getitem__(self, index):
        x = self.data[index]
        # each worker stores its seed (base_seed + worker_id) in its worker info
        seed = torch.utils.data.get_worker_info().seed
        print('worker seed {}'.format(seed))
        random.seed(seed)
        # apply random transform
        i, j, h, w = transforms.RandomCrop.get_params(
            x, output_size=(200, 200))
        print('sampled crop indices {}'.format((i, j, h, w)))
        return x

    def __len__(self):
        return len(self.data)


if __name__ == '__main__':

    dataset = MyDataset()

    loader = DataLoader(dataset, batch_size=2, num_workers=2)
    for epoch in range(2):
        print('epoch {}'.format(epoch))
        for data in loader:
            pass

epoch 0
worker seed 7627862169094115211
sampled crop indices (16, 18, 200, 200)
worker seed 7627862169094115211
sampled crop indices (16, 18, 200, 200)
worker seed 7627862169094115212
sampled crop indices (11, 17, 200, 200)
worker seed 7627862169094115212
sampled crop indices (11, 17, 200, 200)
epoch 1
worker seed 3323761461075438105
sampled crop indices (23, 23, 200, 200)
worker seed 3323761461075438105
sampled crop indices (23, 23, 200, 200)
worker seed 3323761461075438104
sampled crop indices (4, 2, 200, 200)
worker seed 3323761461075438104
sampled crop indices (4, 2, 200, 200)

As you can see, each worker will sample its own random numbers in the current epoch.

Of course, you could also skip the base seed completely by providing your own logic in the worker_init_fn.

Let me know if that clears things up.


Thanks for your answer; now I think it’s clear.

The only thing still bugging me is that, given how the seed is used

random.seed(seed)
torch.manual_seed(seed)

the only sources of randomness for reproducible data augmentation should be random and torch. Right?

Otherwise one would need to set a seed with worker_init_fn, which would get us back to the pseudo-random samples repeating for each epoch.

Thanks for the clear explanation!

Four years later, I was wondering whether this is still managed the same way. I’ve been trying to find the place in the source code where a worker’s seed is set to the base seed (resampled at each epoch, i.e. whenever the iterator is regenerated) plus the worker id.

Here: pytorch/torch/utils/data/dataloader.py at b539c61631fdda520d46cf653f1c30d3c332cf2c · pytorch/pytorch · GitHub
it seems like _worker_loop is called only with the base_seed as its seed argument.

Confusingly, running your code snippet, the workers do have different seeds (so it should be implemented the way you described it), but using the same seed twice (once per sample) still gives different crop indices (so your example doesn’t seem to work on my machine):

epoch 0
worker seed 5326672603560951676
sampled crop indices (2, 12, 200, 200)
worker seed 5326672603560951677
worker seed 5326672603560951676
sampled crop indices (9, 19, 200, 200)
sampled crop indices (21, 14, 200, 200)
worker seed 5326672603560951677
sampled crop indices (22, 22, 200, 200)

epoch 1
worker seed 2890333324556720703
sampled crop indices (4, 17, 200, 200)
worker seed 2890333324556720703
sampled crop indices (22, 19, 200, 200)
worker seed 2890333324556720704
sampled crop indices (22, 11, 200, 200)
worker seed 2890333324556720704
sampled crop indices (4, 13, 200, 200)

My PyTorch versions:
torch==2.1.0+cu118
torchvision==0.16.0+cu118
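
One likely explanation: recent torchvision versions sample the crop offsets in RandomCrop.get_params with torch.randint rather than Python’s random module, so random.seed(seed) alone no longer pins them down; adding torch.manual_seed(seed) next to it in __getitem__ should restore the behavior of the snippet above. A minimal self-contained check of that assumption (the tensor shape is arbitrary):

import random

import torch
from torchvision import transforms

x = torch.zeros(3, 224, 224)

torch.manual_seed(0)
random.seed(0)
print(transforms.RandomCrop.get_params(x, output_size=(200, 200)))

random.seed(0)  # reseeding only `random` does not reproduce the crop ...
print(transforms.RandomCrop.get_params(x, output_size=(200, 200)))

torch.manual_seed(0)  # ... but reseeding torch does
print(transforms.RandomCrop.get_params(x, output_size=(200, 200)))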