Hi @snowe, I’ve been looking into a similar issue.
I’m having trouble with reproducibility when a run crashes midway and is restarted (restarted run), compared to a run that executes to completion (full run).
In my case there are 3 places where random number generators (RNGs) are used during the run:
- Parameter initialisation
- Data loader sample ordering
- Data augmentation (random affine transformations)
Parameter initialisation
Parameter initialisation was easy to solve: I simply set all seeds (using the PyTorch Lightning seed_everything function) at the beginning of the run. When resuming a run, parameters are loaded from disk, so no seeding is required here.
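For completeness, this is roughly what the seeding looks like at the top of my training script (a minimal sketch; 42 is just an example seed, and workers=True requires a reasonably recent Lightning version):

    from pytorch_lightning import seed_everything

    # Seed Python, NumPy and PyTorch RNGs; workers=True also seeds dataloader workers.
    seed_everything(42, workers=True)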
Data loader sample ordering
Data loader sample ordering was trickier, as the RandomSampler uses a random seed to shuffle the training data at the beginning of each epoch. I couldn’t get equivalent sample ordering for restarted and full runs when using seed_everything, as there must have been RNG calls that I wasn’t aware of. In the end, I found a solution using a custom Sampler whose seed is based on “random_seed” (which identifies the unique run) and “epoch”. Now the restarted and full runs have equivalent sample ordering for all epochs.
import torch
from torch.utils.data import Sampler
from typing import Sized


class RandomSampler(Sampler):
    def __init__(
            self,
            data_source: Sized,
            epoch: int = 0,
            random_seed: int = 0):
        super().__init__(data_source)
        self.__epoch = epoch
        self.__n_samples = len(data_source)
        self.__random_seed = random_seed

    def __len__(self):
        return self.__n_samples

    def __iter__(self):
        # Create a random number generator.
        # The seed is based on both 'random_seed' and 'epoch'. This allows for a deterministic
        # sampling order for a particular 'random_seed', even if training is resumed from a checkpoint.
        seed = self.__random_seed + self.__epoch
        generator = torch.Generator()
        generator.manual_seed(seed)

        # Shuffle indices using the new generator.
        indices = torch.randperm(self.__n_samples, generator=generator).tolist()

        # Increment the epoch so the next epoch gets a different (but still deterministic) ordering.
        self.__epoch += 1

        return iter(indices)
The epoch was then passed to the sampler on resume:
from torch.utils.data import DataLoader

random_seed = 42
epoch = ...  # Loaded from checkpoint.
sampler = RandomSampler(train_ds, epoch=epoch, random_seed=random_seed)
train_loader = DataLoader(..., sampler=sampler)
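For reference, this is roughly how I recover the epoch when resuming (a sketch only; the checkpoint path is a placeholder, and it assumes the Lightning checkpoint dict stores the epoch under an "epoch" key):

    # Load the checkpoint on CPU and read back the epoch counter.
    ckpt = torch.load("path/to/last.ckpt", map_location="cpu")  # placeholder path
    epoch = ckpt["epoch"]  # assumes the checkpoint dict exposes the saved epoch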
Data augmentation
I haven’t yet managed to solve the data augmentation problem. Data augmentation is performed within the Dataset.__getitem__ method of my custom Dataset. Ideally, I would like to seed each transformation with a combination of “random_seed”, “epoch”, and the sample index (the “index” argument of __getitem__) to produce unique and deterministic transformations. However, I’m not sure how to get access to the epoch at this level. @ptrblck could you provide some more info on your suggestion of seeding workers with an “epoch seed” each epoch? A rough sketch of what I’m after is below.
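To make the idea concrete, here is a minimal sketch of the per-sample seeding I have in mind. It assumes a hypothetical set_epoch method that the training loop would call at the start of each epoch (which is exactly the part I don’t know how to wire up), and the random flip is just a stand-in for my actual random affine transform:

    import torch
    from torch.utils.data import Dataset


    class AugmentedDataset(Dataset):  # hypothetical dataset, for illustration only
        def __init__(self, samples, random_seed: int = 0):
            self.samples = samples
            self.random_seed = random_seed
            self.epoch = 0

        def set_epoch(self, epoch: int):
            # Hypothetical hook: something in the training loop would need to call
            # this each epoch - this is the part I haven't figured out yet.
            self.epoch = epoch

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, index):
            image = self.samples[index]
            # Per-sample generator seeded from (random_seed, epoch, index), so the
            # augmentation is deterministic but unique per sample and per epoch.
            generator = torch.Generator()
            generator.manual_seed(self.random_seed * 1_000_003 + self.epoch * 100_003 + index)
            # Example augmentation driven by the generator (stand-in for a random affine).
            if torch.rand(1, generator=generator).item() < 0.5:
                image = torch.flip(image, dims=[-1])
            return image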
Thanks,
Brett