It took a lot of digging through docs and discussions to finally figure out the right set of code to make a PyTorch pipeline completely reproducible.
There are two parts to this. First, you want a single seed that is used to seed PyTorch, NumPy, and Python's random module right at the start of the main process, whether you are training or running inference.
You would want to use this function:
import random
import numpy
import torch

def seed_all(seed):
    if not seed:
        seed = 10
    print("[ Using Seed : ", seed, " ]")

    torch.manual_seed(seed)               # seeds the CPU generator
    torch.cuda.manual_seed_all(seed)      # seeds every CUDA device (covers manual_seed too)
    numpy.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True   # force deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False      # disable non-deterministic autotuning
Just call it right at the start and you are good to go.
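The effect is easiest to see in isolation. Here is a minimal sketch using only Python's own random module (the same principle applies to the PyTorch and NumPy generators seeded above): the same seed always yields the same stream of numbers, which is the whole basis of reproducibility.

```python
import random

def draw(seed, n=3):
    """Seed the generator, then draw n pseudo-random ints."""
    random.seed(seed)
    return [random.randint(0, 100) for _ in range(n)]

run_1 = draw(42)
run_2 = draw(42)

# Same seed -> identical stream, no matter how often you re-run this.
assert run_1 == run_2
```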
Now comes the part where you use worker processes in the DataLoader to speed up training. Each worker is a separate process and knows nothing about the seed you set in the main process, except that PyTorch internally seeds each worker with "base_seed + worker_id", where base_seed is derived from the seed you used in the main process. So everything is still predictable up to this point. Keep in mind that PyTorch seeds its own generator in each worker whether or not you define a worker_init_fn.
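The "base_seed + worker_id" scheme can be sketched in pure Python (this is a simplified illustration, not PyTorch's actual implementation): every worker gets a distinct stream, yet the whole layout is reproducible from the base seed alone.

```python
import random

def worker_stream(base_seed, worker_id, n=3):
    """Simulate one DataLoader worker's private random stream."""
    rng = random.Random(base_seed + worker_id)  # per-worker seed
    return [rng.random() for _ in range(n)]

base_seed = 1234  # in PyTorch this would be derived from the main-process seed
streams = [worker_stream(base_seed, wid) for wid in range(4)]

# Re-running with the same base_seed reproduces every worker's stream exactly.
assert streams == [worker_stream(base_seed, wid) for wid in range(4)]
```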
The reason to implement a worker_init_fn is so that you can seed the libraries that PyTorch does not seed for you, such as NumPy and Python's random module. The Albumentations library, for example, uses Python's random module, so making sure each worker has a unique yet predictable seed is important. Use the worker seed to seed everything else except PyTorch, which has already been seeded for you:
import random
import numpy
import torch

def seed_worker(worker_id):
    # torch.initial_seed() returns this worker's seed (base_seed + worker_id).
    # numpy.random.seed only accepts values below 2**32, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
You would define this function at global scope (so it can be pickled for the worker processes) and use it like so:

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=4,
    worker_init_fn=seed_worker,
    shuffle=True,
)
I am training on a single GPU and this ensures A-to-Z reproducibility. The only catch is that you will get different results if you change the number of workers, for the obvious reason that each worker is seeded with "base_seed + worker_id", so the worker count changes which seed handles which sample.
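If you want to confirm that two runs really were identical, a simple trick is to hash the sequence of per-step losses (or any other scalar the run emits) and compare digests. This is a hypothetical helper, not part of the pipeline above; the loss values below are made up for illustration.

```python
import hashlib

def run_digest(losses):
    """Hash a sequence of per-step loss values into one comparable digest."""
    h = hashlib.sha256()
    for loss in losses:
        # Round before hashing so harmless float-formatting noise is ignored.
        h.update(repr(round(loss, 10)).encode())
    return h.hexdigest()

run_a = [0.693, 0.512, 0.431]   # made-up losses logged during run 1
run_b = [0.693, 0.512, 0.431]   # a correctly seeded run 2 should match

assert run_digest(run_a) == run_digest(run_b)
```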
I hope this post reaches people in time so that they don't waste precious GPU compute trying to figure out where on earth the randomness is coming from.