How to fix all workers' seeds via worker_init_fn for every iteration?

I thought I had fixed every worker's seed to 0 via worker_init_fn, but it turns out the workers still report different seeds.

Is there a way to set all workers' seeds to the same constant for every iteration?

In dataset.py:

import random

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.core.datamodule import LightningDataModule

class MyDataset(Dataset):
    ...
    def __getitem__(self, idx):
        # this is where the results below are printed
        print(torch.utils.data.get_worker_info())

class MyDataModule(LightningDataModule):
    ...
    def worker_init_fn(self, worker_id):
        # seed every RNG in this worker with the same constant
        seed = 0
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # seeds all GPUs
        np.random.seed(seed)
        random.seed(seed)

    def val_dataloader(self):
        return DataLoader(dataset,
                          batch_size=self.args.bsz,
                          shuffle=False,
                          num_workers=8,
                          worker_init_fn=self.worker_init_fn)

This code gives me workers with different seeds :confused: :

WorkerInfo(id=2, num_workers=8, seed=900450186894289457, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=7, num_workers=8, seed=900450186894289462, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=3, num_workers=8, seed=900450186894289458, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=6, num_workers=8, seed=900450186894289461, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=5, num_workers=8, seed=900450186894289460, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=2, num_workers=8, seed=900450186894289457, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=7, num_workers=8, seed=900450186894289462, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=3, num_workers=8, seed=900450186894289458, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=0, num_workers=8, seed=900450186894289455, dataset=<data.MyDataset object at 0x7f11d3613110>)


Thank you :smiley:

While each worker will use base_seed + worker_id as its internal seed, your code should nevertheless work, since you are reseeding in worker_init_fn, if I'm not mistaken.
Here is a small example, which also shows the different worker seeds, but the same result for torch.randn:

import random

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, idx):
        # this is where the results are printed
        print(torch.utils.data.get_worker_info())
        return torch.randn(1)

    def __len__(self):
        return 10

def worker_init_fn(worker_id):
    # seed every RNG in this worker with the same constant
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds all GPUs
    np.random.seed(seed)
    random.seed(seed)

dataset = MyDataset()
dataloader = DataLoader(dataset,
                        batch_size=2,
                        shuffle=False,
                        num_workers=8,
                        worker_init_fn=worker_init_fn)

for data in dataloader:
    print(data)
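
As an extra sanity check (my own addition, not part of the example above), you could also print torch.initial_seed() inside __getitem__: after the reseeding in worker_init_fn it should report 0 in every worker, even though get_worker_info().seed keeps its per-worker value:

    def __getitem__(self, idx):
        # WorkerInfo.seed still differs per worker (base_seed + worker_id) ...
        print(torch.utils.data.get_worker_info())
        # ... but the default generator was reseeded to 0 by worker_init_fn
        print(torch.initial_seed())  # -> 0 in every worker
        return torch.randn(1)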

It shows the same results for every run! Thank you. :smiley:

[image: screenshot of the output, still showing a different seed for each worker]
But it still shows different seeds for different workers. Shouldn't they all have the same seed, since I set it in worker_init_fn?

Thank you.

That is because the worker_info is created (with the seed passed in from the DataLoader) before worker_init_fn is called:
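
Roughly, each worker does the following on startup (a simplified paraphrase of torch/utils/data/_utils/worker.py; the exact code varies between PyTorch versions):

# simplified sketch of what each DataLoader worker does before fetching data
seed = base_seed + worker_id          # this value is what WorkerInfo.seed reports
random.seed(seed)
torch.manual_seed(seed)

worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                         seed=seed, dataset=dataset)

if init_fn is not None:
    init_fn(worker_id)                # your reseeding only runs here

So get_worker_info().seed will always show base_seed + worker_id no matter what you do in worker_init_fn, but your reseeding still governs every random call made afterwards.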