How to fix all workers' seed via worker_init_fn for every iter?

Although each worker is initially seeded with base_seed + worker_id, your code should still work, since you reseed every generator inside worker_init_fn, if I'm not mistaken.
Here is a small example, which also shows the different worker seeds, but the same result for torch.randn:

class MyDataset(Dataset):
    """Toy dataset with 10 items; each item is a fresh ``torch.randn(1)`` draw."""

    def __init__(self):
        """Nothing to set up: the dataset holds no state."""

    def __len__(self):
        """Return the fixed dataset size."""
        return 10

    def __getitem__(self, idx):
        """Print the current worker info and return one random value.

        ``idx`` is ignored; the returned tensor depends only on the state of
        torch's random number generator at call time.
        """
        # This is where the results are printed.
        worker_info = torch.utils.data.get_worker_info()
        print(worker_info)
        sample = torch.randn(1)
        return sample

def worker_init_fn(worker_id, seed=0):
    """Reseed every random number generator in a worker with one fixed seed.

    Intended to be passed as ``worker_init_fn`` to a ``DataLoader``.

    Args:
        worker_id: Index of the worker being initialised. Deliberately
            unused, so that all workers end up with identical RNG state.
        seed: Seed applied to torch (CPU and CUDA), NumPy and the stdlib
            ``random`` module. Defaults to 0. To use another value with a
            ``DataLoader``, bind it with
            ``functools.partial(worker_init_fn, seed=...)``.
    """
    torch.manual_seed(seed)
    # manual_seed_all covers every CUDA device, so a separate
    # torch.cuda.manual_seed call for the current device is not needed.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    
# Build the toy dataset and wrap it in a loader: batches of 2, no shuffling,
# 8 worker processes, with worker_init_fn supplied as the per-worker init hook.
# NOTE(review): on platforms whose multiprocessing start method is "spawn",
# this driver code probably needs an `if __name__ == "__main__":` guard - confirm.
dataset = MyDataset()
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=8,
    worker_init_fn=worker_init_fn,
)

# Print every batch the loader yields.
for data in dataloader:
    print(data)
1 Like