Reproducibility with checkpoints

Hi everyone :slight_smile:

I have a script that trains a CNN and I am able to reproduce the results using:

import random
import numpy as np
import torch

def set_seed(seed):
    # seed Python, NumPy and PyTorch RNGs
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # for cuda
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False

I also save a checkpoint whenever the accuracy on the validation set increases. I do so like this:

checkpoint = {
    'run': run_count,
    'epoch': epoch_count,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict()
}
torch.save(checkpoint, 'path to folder')

However, when I resume training from a checkpoint with the same seed, I get different results from when I train the CNN from scratch up to the epoch I compare it to. Say, for example, I train the network for 25 epochs and the best one is at epoch 15. Then I load the checkpoint from epoch 15 and continue training. I would expect these results to be the same as those from the first training run at epoch 16 and upwards. But they are not…
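For reference, the resume step looks roughly like this (the path is a placeholder, as above, and the variable names are just illustrative):

checkpoint = torch.load('path to folder')
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optimizer_state'])
start_epoch = checkpoint['epoch'] + 1  # continue from the epoch after the checkpoint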

Does anyone know why this could be the case? Any help is very much appreciated!

All the best
snowe

The difference might come from e.g. the data shuffling, since you are re-seeding the code at epoch 15 again.
You could try to iterate the DataLoader for 15 epochs, or alternatively seed the workers in each epoch with an “epoch seed” so that it would be easier to restore.

Hi @ptrblck, thank you for your response!

What do you mean with iterate the DataLoader for 15 epochs?

Also: when I set the seed again within my loop over the epochs and then resume training, I get the same results, you are right. However, these results then differ from the ones I get when I don’t seed every epoch, and the difference is quite big. How does that make sense?

All the best
snowe

The order of the returned batches from your DataLoader would still be different, which would yield different (non-reproducible) results compared to the original run. Note that this is usually not a problem, as your model should converge with different seeds.

If the DataLoader were the only source of randomness, you could use:

# fast-forward the DataLoader by consuming the batches of the already-trained epochs,
# so the shuffling RNG reaches the same state as in the original run
for _ in range(epochs_to_restore):
    for batch in loader:
        pass

However, this approach is brittle and will not work if there are other calls into the random number generator (e.g. from random data augmentation), since these would advance the RNG state differently.
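If you really needed a bit-exact resume, one option (not shown in your code) would be to also store the global RNG states in the checkpoint; a rough sketch, using the same imports as your set_seed snippet:

# sketch: capture the global RNG states together with the model/optimizer states
checkpoint['torch_rng_state'] = torch.get_rng_state()
checkpoint['cuda_rng_state'] = torch.cuda.get_rng_state_all()
checkpoint['numpy_rng_state'] = np.random.get_state()
checkpoint['python_rng_state'] = random.getstate()

# ... and restore them after loading the checkpoint
torch.set_rng_state(checkpoint['torch_rng_state'])
torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state'])
np.random.set_state(checkpoint['numpy_rng_state'])
random.setstate(checkpoint['python_rng_state'])

Note that this still wouldn’t cover the per-worker RNGs in the DataLoader, so treat it as a sketch rather than a complete solution.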

Thank you @ptrblck, it does work like that! :slight_smile:

Out of curiosity… is it safe to say that although without this approach the results are not exactly the same as when I left off, they are still comparable and reproducible, because the only difference is the order of the batches? This could be seen as almost another layer of data shuffling, which a CNN should be able to handle anyway, if we aim to generalise the network?

Yes, I think as long as you shuffle the data and stick to your workflow, the final results should be comparable, i.e. the model should converge to the same final accuracy (+/- a small difference).
You might have some use cases where you need to restore exactly the same data ordering etc., which is more tricky as explained before. Usually these steps are necessary to debug some issues and your “standard” training shouldn’t depend on a specific seed or ordering of the shuffled data.


Hi @snowe, I’ve been looking into a similar issue.

I’m having trouble with reproducibility when a run crashes midway and is restarted (restarted run), compared to a run that executes to completion (full run).

In my case there are 3 places where random number generators (RNGs) are used during the run:

  • Parameter initialisation
  • Data loader sample ordering
  • Data augmentation (random affine transformations)

Parameter initialisation

Parameter initialisation was easy to solve as I simply set all seeds (using the PyTorch Lightning seed_everything method) at the beginning of the run. When resuming the run, the parameters are loaded from disk, so no seeding is required here.
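A minimal sketch of this setup (the seed value is just an example):

import pytorch_lightning as pl

# Fresh run: seed everything once, before the model is constructed,
# so that parameter initialisation is reproducible.
pl.seed_everything(42)

# Resumed run: parameters are restored from the checkpoint instead,
# so no re-seeding is needed for the initialisation itself.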

Data loader sample ordering

Data loader sample ordering was trickier as the RandomSampler uses a random seed to shuffle training data at the beginning of the epoch. I couldn’t get equivalent sample ordering for restarted and full runs when using seed_everything as there must have been RNG calls that I wasn’t aware of. In the end, I found a solution that used a custom Sampler with a seed based on “random_seed” (determines unique run) and “epoch”. Now the restarted and full runs have equivalent sample ordering for all epochs.

import torch
from torch.utils.data import Sampler
from typing import Sized

class RandomSampler(Sampler):
    def __init__(
        self,
        data_source: Sized,
        epoch: int = 0,
        random_seed: int = 0):
        super().__init__(data_source)
        self.__epoch = epoch
        self.__n_samples = len(data_source)
        self.__random_seed = random_seed

    def __len__(self):
        return self.__n_samples

    def __iter__(self):
        # Create random number generator.
        # Seed is based on both 'random_seed' and 'epoch'. This allows for deterministic sampling
        # order for a particular 'random_seed', even if training is resumed from a checkpoint.
        seed = self.__random_seed + self.__epoch
        generator = torch.Generator()
        generator.manual_seed(seed)

        # Shuffle indices using the new generator.
        indices = torch.randperm(self.__n_samples, generator=generator).tolist()

        # Increment epoch so the next epoch uses a different (but still deterministic) seed.
        self.__epoch += 1

        return iter(indices)

Epoch was then passed to the sampler on resume.

random_seed = 42
epoch = ... # Loaded from checkpoint.
sampler = RandomSampler(train_ds, epoch=epoch, random_seed=random_seed)
train_loader = DataLoader(..., sampler=sampler)

Data augmentation

I haven’t yet managed to solve the data augmentation problem. Data augmentation is performed within the Dataset.__getitem__ method of my custom Dataset. Ideally, I would like to seed each transformation with a combination of “random_seed”, “epoch”, and the sample index (the “index” argument of __getitem__) to produce unique and deterministic transformations, roughly as sketched below. However, I’m not sure how to get access to the epoch at this level. @ptrblck, could you provide some more info on your suggestion of seeding the workers with an “epoch seed” each epoch?
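For reference, this is roughly what I have in mind (all names here are placeholders, and set_epoch is a hypothetical hook; getting the epoch into the dataset is exactly the part I haven’t solved):

import torch
from torch.utils.data import Dataset

class AugmentedDataset(Dataset):
    def __init__(self, samples, random_seed: int = 0):
        self.samples = samples
        self.random_seed = random_seed
        self.epoch = 0

    def set_epoch(self, epoch: int):
        # Hypothetical hook: would need to be called at the start of every epoch.
        self.epoch = epoch

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Per-sample generator seeded by (random_seed, epoch, index):
        # unique per sample and epoch, but deterministic across restarts.
        generator = torch.Generator()
        generator.manual_seed(self.random_seed + self.epoch * len(self.samples) + index)

        # Draw the affine parameters from this generator, e.g. a rotation angle in [-10, 10).
        angle = torch.rand(1, generator=generator).item() * 20 - 10

        sample = self.samples[index]
        # ... apply the random affine transform using 'angle' here ...
        return sample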

Thanks,
Brett