Save random generator states and resume training for DDP

I’m running DDP training on a cluster with a time limit. When the time limit hits, I have to checkpoint the states of the model, optimizer, etc., and resubmit the job, which loads these states.
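
For reference, here is a minimal sketch of the model and optimizer part of such a checkpoint. It assumes a single process, a stand-in `nn.Linear` model, SGD with momentum, a hypothetical resume epoch, and a hypothetical temporary path. Under DDP, a common pattern is to save `model.module.state_dict()` from rank 0 only; this sketch leaves that out.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real model and optimizer.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step, so the optimizer holds momentum buffers worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")  # hypothetical path
torch.save(
    {
        "epoch": 3,  # hypothetical: the epoch to resume from
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    ckpt_path,
)

# In the resubmitted job: rebuild the objects first, then load their states.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
ckpt = torch.load(ckpt_path, map_location="cpu")
new_model.load_state_dict(ckpt["model"])
new_optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]
```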

Besides these regular states, if I want the training curve to be exactly the same as if there were no time limit, I have to save the random generators’ states as well. I tried the following (the actual code differs, but all key components are included below):


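import random

import numpy
import torch

# Each rank saves its own RNG states before the time limit hits
rng_state_dict = {
    'cpu_rng_state': torch.get_rng_state(),
    'gpu_rng_state': torch.cuda.get_rng_state(),  # current CUDA device only
    'numpy_rng_state': numpy.random.get_state(),
    'py_rng_state': random.getstate()
}
torch.save(rng_state_dict, f'rng_state_{rank}.ckpt')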


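# At startup (local_rank and global rank already known), each rank restores its own states
rng_state_dict = torch.load(f'rng_state_{rank}.ckpt', map_location='cpu', weights_only=False)  # NumPy state is not a tensor
torch.set_rng_state(rng_state_dict['cpu_rng_state'])
torch.cuda.set_rng_state(rng_state_dict['gpu_rng_state'])
numpy.random.set_state(rng_state_dict['numpy_rng_state'])
random.setstate(rng_state_dict['py_rng_state'])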
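
A self-contained sketch of the same save/restore round trip, useful for ruling out the RNG checkpoint itself as the source of the divergence. It assumes a single process, a hypothetical `rank = 0`, and hypothetical helpers `capture_rng_state`, `restore_rng_state`, and `draw`; the CUDA state is only captured when a GPU is available. `weights_only=False` is passed because recent PyTorch releases default `torch.load` to `weights_only=True`, which may reject the NumPy array inside the NumPy state.

```python
import os
import random
import tempfile

import numpy
import torch


def capture_rng_state():
    """Collect the global RNG states that this process draws from."""
    state = {
        "cpu_rng_state": torch.get_rng_state(),
        "numpy_rng_state": numpy.random.get_state(),
        "py_rng_state": random.getstate(),
    }
    if torch.cuda.is_available():
        # State of the current CUDA device only.
        state["gpu_rng_state"] = torch.cuda.get_rng_state()
    return state


def restore_rng_state(state):
    """Put each global generator back into the captured state."""
    torch.set_rng_state(state["cpu_rng_state"])
    numpy.random.set_state(state["numpy_rng_state"])
    random.setstate(state["py_rng_state"])
    if "gpu_rng_state" in state:
        torch.cuda.set_rng_state(state["gpu_rng_state"])


def draw():
    """Draw a few values from each generator."""
    return (
        torch.rand(3).tolist(),
        numpy.random.rand(3).tolist(),
        random.random(),
    )


rank = 0  # hypothetical; in DDP this would come from torch.distributed.get_rank()
ckpt_path = os.path.join(tempfile.mkdtemp(), f"rng_state_{rank}.ckpt")

torch.manual_seed(0)
numpy.random.seed(0)
random.seed(0)

torch.save(capture_rng_state(), ckpt_path)
expected = draw()  # what an uninterrupted run would draw next

# Pretend the job was killed here and resubmitted.
restore_rng_state(torch.load(ckpt_path, map_location="cpu", weights_only=False))
resumed = draw()  # what the resumed run draws after restoring
```

If this round trip matches but the training curves still differ, the difference has to come from something that uses the generators between the restore and the point where the uninterrupted run would have used them.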

I use the above strategy to save the states and resubmit the job to continue training. However, the training curve differs from the run without a time limit. What piece is missing from my code?

Are you using a PyTorch DataLoader? If so, are you using shuffle=True, and is the RNG used anywhere before the training loop begins?

My hypothesis is as follows:
Suppose you save the RNG state s_t at the end of an epoch. Without a checkpoint, training would continue into the next epoch, and the DataLoader with shuffle=True would consume the next RNG state, s_{t+1}. With a checkpoint and reload, however, if anything uses the RNG before the DataLoader draws from it, the DataLoader no longer sees s_{t+1} but some s_{t+k} with k > 1.
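
The hypothesis can be reproduced in a few lines. This sketch assumes a hypothetical helper `shuffled_order` and uses `torch.randperm` on the global CPU generator as a simplified stand-in for the DataLoader's shuffle. Restoring the same saved state and then making one extra draw before shuffling yields a different order.

```python
import torch

torch.manual_seed(0)
saved_state = torch.get_rng_state()  # plays the role of s_t in the checkpoint


def shuffled_order(extra_draws_before_shuffle):
    """Restore s_t, optionally use the RNG, then shuffle ten indices."""
    torch.set_rng_state(saved_state)
    for _ in range(extra_draws_before_shuffle):
        torch.rand(1)  # stands in for any RNG use at startup (init, augmentation, ...)
    # Simplified stand-in for the DataLoader's shuffle on the global generator.
    return torch.randperm(10).tolist()


uninterrupted = shuffled_order(0)  # the order the DataLoader would have used
resumed = shuffled_order(1)  # one extra draw before shuffling after the reload
```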

If you do not want to trace every possible internal use of the RNG state, could you try setting the seed via torch.manual_seed(epoch) at the beginning of each epoch, before iterating over the data loader? Even if my hypothesis is wrong, this may still fix your issue. Let me know if you still see differing training curves.
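
A sketch of the per-epoch reseeding suggestion, assuming a toy `TensorDataset` of ten indices, `num_workers=0`, and a hypothetical helper `epoch_order`. Because `torch.manual_seed(epoch)` runs right before the loader is iterated, the shuffle order is the same no matter how many unrelated draws happened earlier.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=5, shuffle=True)  # num_workers=0 by default


def epoch_order(epoch, extra_draws):
    """Indices seen in one epoch, after some unrelated RNG consumption."""
    for _ in range(extra_draws):
        torch.rand(1)  # RNG use that happens before the epoch starts
    torch.manual_seed(epoch)  # reseed right before iterating the loader
    # Each batch is a one-element list holding a tensor of indices.
    return [int(i) for (batch,) in loader for i in batch]
```

Note that `torch.manual_seed` also resets everything else drawn from the global generator during that epoch, such as dropout masks, which is what lets a resumed epoch match an uninterrupted one.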

Thank you for your helpful thoughts! I’m using torch’s DataLoader with shuffling enabled and call set_epoch(epoch) on its sampler at the beginning of each epoch. I believe the sampler has a “local” random number generator that is fully controlled by seed=base_seed+epoch, see lines 98-100 of PyTorch’s
So anything consuming the “global” RNG states outside the DataLoader shouldn’t matter?
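
A quick way to check this for `DistributedSampler`: the sketch below passes hypothetical `num_replicas=2` and `rank=0` explicitly so no process group is needed, and uses a hypothetical helper `rank_order`. The index order after `set_epoch(epoch)` does not change when the global generator is consumed beforehand. This only covers the sampler's index order; dropout and other random ops in the training step still draw from the global generators, so their restored states still matter.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# num_replicas and rank are given explicitly (hypothetical values), so this
# sketch does not need an initialized process group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)


def rank_order(epoch, extra_draws):
    """Indices this rank gets for one epoch, after unrelated global RNG use."""
    for _ in range(extra_draws):
        torch.rand(1)  # consumes the global CPU generator only
    sampler.set_epoch(epoch)  # the shuffle is seeded with seed + epoch
    return list(sampler)
```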