I believed I had fixed every worker's seed to 0 via `worker_init_fn`, but it turns out the workers still report different seeds.
Is there any way to set every worker's seed to the same fixed constant on every iteration?
In `dataset.py`:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pytorch_lightning.core.datamodule import LightningDataModule
class MyDataset(Dataset):
    """Dataset used to inspect which DataLoader worker serves each item."""

    ...  # remaining members were elided in the original snippet

    def __getitem__(self, idx):
        """Print the ``WorkerInfo`` of the process handling this item.

        ``torch.utils.data.get_worker_info()`` returns ``None`` when called
        in the main process (i.e. outside a DataLoader worker).

        As written, the method only prints and implicitly returns ``None``;
        the actual sample loading is not part of this snippet.
        """
        # This is where the results are printed.
        # NOTE: ``WorkerInfo.seed`` is the seed the DataLoader assigned to the
        # worker (base seed + worker id) before ``worker_init_fn`` ran. It is
        # not updated when ``worker_init_fn`` re-seeds the RNGs, so it keeps
        # showing distinct values; ``torch.initial_seed()`` reports the seed
        # that is actually in effect for torch in this process.
        print(torch.utils.data.get_worker_info())
class MyDataModule(LightningDataModule):
    """Data module whose validation workers are all re-seeded identically."""

    ...  # remaining members were elided in the original snippet

    def worker_init_fn(self, worker_id):
        """Re-seed every RNG in a DataLoader worker with the same constant.

        The DataLoader calls this once in each worker process, after it has
        already seeded that worker with ``base_seed + worker_id``; the calls
        below override that per-worker seeding.

        Note that ``torch.utils.data.get_worker_info().seed`` keeps reporting
        the DataLoader-assigned value; check ``torch.initial_seed()`` inside
        the worker to see the seed that is actually in effect.

        Args:
            worker_id: Index of the worker being initialised (unused, since
                every worker receives the same seed).
        """
        seed = 0
        torch.manual_seed(seed)
        # Covers the current GPU as well, so no separate
        # ``torch.cuda.manual_seed`` call is needed.
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

    def val_dataloader(self):
        """Build and return the validation DataLoader.

        Returns:
            A non-shuffling ``DataLoader`` with 8 workers, each initialised
            through ``self.worker_init_fn``.
        """
        # NOTE(review): ``dataset`` is not defined anywhere in this snippet;
        # presumably it is created in the elided code -- confirm.
        dataloader = DataLoader(
            dataset,
            batch_size=self.args.bsz,
            shuffle=False,
            num_workers=8,
            worker_init_fn=self.worker_init_fn,
        )
        # Lightning uses the returned loader; without this the hook yields None.
        return dataloader
This code gives me workers with different seeds:
WorkerInfo(id=2, num_workers=8, seed=900450186894289457, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=7, num_workers=8, seed=900450186894289462, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=3, num_workers=8, seed=900450186894289458, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=6, num_workers=8, seed=900450186894289461, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=5, num_workers=8, seed=900450186894289460, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=2, num_workers=8, seed=900450186894289457, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=7, num_workers=8, seed=900450186894289462, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=3, num_workers=8, seed=900450186894289458, dataset=<data.MyDataset object at 0x7f11d3613110>)
WorkerInfo(id=0, num_workers=8, seed=900450186894289455, dataset=<data.MyDataset object at 0x7f11d3613110>)
Is there any way to set every worker's seed to the same fixed constant on every iteration?
Thank you