CUDA RNG state does NOT change when re-seeding! WHY is that?

Thank you!

For now, I am not planning to move to a more recent or nightly PyTorch version, since that would require updating my code. But I will definitely do that for my next project.

Now, based on your current and previous comments, I tried to RESET the RNG state directly using a handcrafted/random RNG state.

I found the following:

  1. One can reset the RNG state (expected).
  2. Any state tensor with the right length is accepted (824016 bytes for CUDA and 5048 for CPU).
  3. It is easy to set an RNG state that generates inf values, so setting a handcrafted RNG state must be done with care.
  4. However, when the RNG does not "like" a handcrafted state, it does not raise an error; it seems to ignore it and continue with its own random state. The RNG state may therefore have some predefined internal structure.
  5. One can control only the initial RNG state. After it has been fixed, the next RNG state is random (because of point 4).
  6. To make sure that I am using a valid RNG state, I store the CUDA and CPU RNG states on disk during the FIRST run. In the subsequent runs, I load the saved RNG states and set them. While passing a valid RNG state around gives the same RNG state across iterations, the FIRST iteration of each SUBSEQUENT run does not respect the fixed RNG state (the RNG ignores it!!!).
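Points 5 and 6 amount to persisting a known-good state and restoring it later. Below is a minimal sketch of that round trip, assuming a PyTorch build where `torch.get_rng_state`/`torch.set_rng_state` and, when a GPU is present, `torch.cuda.get_rng_state`/`torch.cuda.set_rng_state` are available; the helpers `save_rng_states` and `load_rng_states` are hypothetical names. It judges success by comparing the numbers drawn after restoring, not only the raw state bytes:

```python
import os
import tempfile

import torch


def save_rng_states(path):
    """Hypothetical helper: save the CPU (and CUDA device 0, if any) RNG states."""
    states = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        states["cuda"] = torch.cuda.get_rng_state(device=0)
    torch.save(states, path)


def load_rng_states(path):
    """Hypothetical helper: restore the states written by save_rng_states()."""
    states = torch.load(path)
    torch.set_rng_state(states["cpu"])
    if "cuda" in states and torch.cuda.is_available():
        torch.cuda.set_rng_state(states["cuda"], device=0)


torch.manual_seed(1)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rng_states.pt")
    save_rng_states(path)
    first = torch.rand(5)   # draws right after saving

    torch.rand(100)         # advance the generator past the saved point
    load_rng_states(path)
    second = torch.rand(5)  # repeats `first` if the restore worked

print(torch.equal(first, second))  # → True
```

The sketch only exercises the CPU path on a machine without a GPU; on a GPU machine the same helpers also cover device 0.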

Before you read the code and the runs: do you think there is another variable, besides the CUDA/CPU RNG state, that the CUDA RNG depends on? You will see clearly that even when using the exact same CUDA RNG state, reading it back does not give the same state, in the sense that after torch.cuda.set_rng_state(stored_rngs); rngs2 = torch.cuda.get_rng_state(), rngs2 is different from stored_rngs. I suspect that something else affects the CUDA RNG.

Thanks!

Here is an example (the runs are not reproducible):

import copy
import os


import torch
import numpy as np


seed = 1
# reset the seed of Numpy to control its initial rng state.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

use_cuda = True

# set the CUDA RNG state from a NumPy-generated array (or any other handcrafted state).
# 824016 is the expected size of the CUDA RNG state; multiplying by 0. makes it all zeros.
state_cuda = np.random.randint(0, np.iinfo(np.uint8).max + 1, 824016) * 0.
# A normal RNG state is not random: its internal numbers follow some pattern!!! A random RNG state can easily
# lead to generating inf values. Try setting the RNG state to all 1s, or drop the `* 0.` to use the random
# values above, and see that the generator breaks at the inf test.

state_cpu = np.random.randint(0, np.iinfo(np.uint8).max + 1, 5048) * 0.  # 5048 is the expected size of the CPU RNG state

l_cuda = (torch.tensor(state_cuda).type(torch.uint8))
l_cpu = (torch.tensor(state_cpu).type(torch.uint8))

# use a VALID stored RNG state (written with torch.save, despite the .npy extension)
LOAD = True
filecuda = "prng/cuda.npy"
filecpu = "prng/cpu.npy"
if LOAD:
    if not os.path.exists(filecuda) or not os.path.exists(filecpu):
        # first run: store the current (valid) RNG states
        os.makedirs("prng", exist_ok=True)
        l_cuda = (torch.cuda.get_rng_state(device=0))
        torch.save(l_cuda, filecuda)
        l_cpu = (torch.get_rng_state())
        torch.save(l_cpu, filecpu)
    else:
        # subsequent runs: load the stored states (already uint8 tensors)
        l_cuda = torch.load(filecuda).type(torch.uint8)
        l_cpu = torch.load(filecpu).type(torch.uint8)

for i in range(10):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
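    # force the CUDA and CPU RNG states to be equal to `l_cuda` and `l_cpu`.
    torch.cuda.set_rng_state(l_cuda, device=0)
    torch.set_rng_state(l_cpu)
    lastlcuda = copy.deepcopy(l_cuda)
    lastlcpu = copy.deepcopy(l_cpu)
    # read both states back to check whether the set actually took effect.
    l_cuda = (torch.cuda.get_rng_state(device=0))
    l_cpu = (torch.get_rng_state())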
    st = bool((l_cuda == lastlcuda).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CUDA".format(
        i, l_cuda.sum(), lastlcuda.sum(), st))
    st = bool((l_cpu == lastlcpu).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CPU".format(
        i, l_cpu.sum(), lastlcpu.sum(), st))
    a = torch.rand(1000, device="cuda:0")  # torch.bernoulli(torch.full((300, 300), 0.5, device='cuda:0'))
    print(a.sum())

run 1: (store the rng state of CUDA and CPU)

Iter. 0: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 1: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 2: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 3: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 4: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 5: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 6: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 7: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 8: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 9: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU

You see that everything is perfect (the RNG state is the same, including at iter. 0).

Now, see what happens when we run a second time and reload the stored RNG states:

run 2: (load the stored rng state of CUDA and CPU)

Iter. 0: current (rngs.sum()55990853) == previous (rngs.sum()55928053)? **False** CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 1: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 2: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 3: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 4: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 5: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 6: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 7: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 8: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 9: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')

At iter. 0, the loaded CUDA RNG state was not taken into account!!! In other words:

torch.cuda.set_rng_state(l_cuda, device=0)
rngs = (torch.cuda.get_rng_state(device=0))

gives rngs different from l_cuda!!! Any insight into why?
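The two lines above can be extended into a check that separates "the state was ignored" from "the bytes read back differ but the sequence is the same". In the logs above, the CUDA tensor sum is 498.2127 in both runs, including the iteration where the byte comparison fails, so comparing draws is worth doing alongside comparing bytes. A minimal sketch; `check_round_trip` is a hypothetical helper, and the CUDA part only runs when a GPU is present. One possible explanation for the byte mismatch, an assumption not verified here, is that the older MTGP32-based CUDA state also stores per-process data (such as device pointers) that `set_rng_state` rewrites:

```python
import torch


def check_round_trip(get_state, set_state, draw, saved_state):
    """Hypothetical helper: set `saved_state`, then compare bytes and draws.

    get_state/set_state read and write one generator's state; `draw` is a
    zero-argument callable that samples from that same generator.
    Returns (bytes_equal, draws_equal).
    """
    set_state(saved_state)
    bytes_equal = torch.equal(get_state(), saved_state)  # raw bytes read back
    first = draw()

    set_state(saved_state)                               # rewind and repeat
    second = draw()
    return bytes_equal, torch.equal(first, second)


cpu_result = check_round_trip(torch.get_rng_state, torch.set_rng_state,
                              lambda: torch.rand(1000), torch.get_rng_state())
print("CPU  (bytes_equal, draws_equal):", cpu_result)

if torch.cuda.is_available():
    cuda_result = check_round_trip(
        lambda: torch.cuda.get_rng_state(device=0),
        lambda s: torch.cuda.set_rng_state(s, device=0),
        lambda: torch.rand(1000, device="cuda:0"),
        torch.cuda.get_rng_state(device=0))
    print("CUDA (bytes_equal, draws_equal):", cuda_result)
```

To reproduce the iter. 0 situation, pass the state loaded from disk in a second run as `saved_state` instead of a freshly read one.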

If you are wondering why I am trying to fix this and obtain the same RNG state: I am trying to make PyTorch reproducible across multiple GPUs. For that, I use:

  1. threading.Lock() to lock the zones that introduce randomness, so I can safely reseed each thread, since all threads share the same RNG state (torch.nn.DataParallel runs its replicas in threads that share the RNG state, which messes things up).
  2. Re-seeding and resetting the RNG state for each thread (in a thread-safe manner) to control the randomness of each thread.
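The lock-and-reseed idea in points 1 and 2 can be sketched as follows. This is a minimal CPU-only sketch under assumptions: a single global lock, and a hypothetical `ThreadRNG` helper that keeps each thread's own state and swaps it into the shared global generator only while holding the lock; per-device CUDA generators are not covered, and it is only deterministic if every random draw in the threads goes through the lock. For ops that accept a `generator=` argument, a per-thread `torch.Generator` avoids the lock entirely; `torch.nn.functional.dropout` has no such argument and draws from the global generator, which is why the swap pattern is shown here:

```python
import threading

import torch

# One lock guards the global CPU generator shared by every thread.
_rng_lock = threading.Lock()


class ThreadRNG:
    """Hypothetical helper: a private RNG state for one worker thread.

    Under the lock, the thread swaps its own state into the shared global
    generator, draws, saves the advanced state, and puts the previous global
    state back, so its draws do not depend on how threads interleave.
    """

    def __init__(self, seed):
        g = torch.Generator()  # temporary CPU generator, used only to build a state
        g.manual_seed(seed)
        self.state = g.get_state()

    def rand(self, *size):
        with _rng_lock:
            outer = torch.get_rng_state()
            torch.set_rng_state(self.state)
            out = torch.rand(*size)  # stands in for any op using the global generator
            self.state = torch.get_rng_state()
            torch.set_rng_state(outer)
        return out


def worker(rank, results):
    rng = ThreadRNG(seed=1000 + rank)  # hypothetical per-thread seed scheme
    results[rank] = torch.cat([rng.rand(3) for _ in range(4)])


def run(num_threads=4):
    results = {}
    threads = [threading.Thread(target=worker, args=(rank, results))
               for rank in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


first, second = run(), run()
print(all(torch.equal(first[r], second[r]) for r in first))  # → True
```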

I will post the code here as soon as I obtain reproducible results on real code.

For now, on synthetic code (using simple instructions), I seem to be able to achieve reproducibility over multiple GPUs. However, when I tried the idea on real code, it does not work ALL the time.

I followed your comment to pass RNG states around instead of just re-seeding. No luck so far. Sometimes I obtain EXACTLY the same result on my real code, but most of the time reproducibility fails. After some checking, I realized that something is wrong with the RNG states: outside the threads (after the previous parallelized forward call), the RNG state of the main thread is not deterministic even though I force it to be (somehow, it does not care).
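To locate where the main-thread state stops being deterministic, it can help to log a fingerprint of each state at several points and diff the logs between runs. The comparisons in this post use `.sum()`, which can coincide for different states (any permutation of the bytes has the same sum); a hash is a much stricter check. A minimal sketch with hypothetical helper names `rng_fingerprint` and `log_rng`:

```python
import hashlib

import torch


def rng_fingerprint(state):
    """Hypothetical helper: short SHA-256 digest of an RNG state ByteTensor."""
    return hashlib.sha256(bytes(state.cpu().tolist())).hexdigest()[:16]


def log_rng(tag):
    """Hypothetical helper: one log line with the CPU and CUDA:0 fingerprints."""
    line = "{} cpu={}".format(tag, rng_fingerprint(torch.get_rng_state()))
    if torch.cuda.is_available():
        cuda_state = torch.cuda.get_rng_state(device=0)
        line += " cuda={}".format(rng_fingerprint(cuda_state))
    print(line)
    return line


torch.manual_seed(1)
log_rng("after seeding")
torch.rand(10)
log_rng("after 10 draws")
```

Calling `log_rng` just before and just after the parallelized forward in two runs, then diffing the two logs, points to the first place where the states diverge.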

===================================================

                           NUMPY

===================================================
See how the NumPy RNG state can easily be controlled using only a seed.

Code:

import copy

import numpy as np

seed = 1
use_cuda = True
np.random.seed(seed)
l = np.random.get_state()
for i in range(10):
    np.random.seed(seed)
    lastl = copy.deepcopy(l)
    l = (np.random.get_state())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {}".format(
        i, l[1].sum(), lastl[1].sum(), bool((l[1] == lastl[1]).all())))
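    # note: uniform(1) means low=1, high=1.0 (the default), so it always returns 1.0;
    # np.random.uniform() would draw from [0, 1) instead.
    a = np.random.uniform(1)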
    print(a)

run 1

Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0

run 2

Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0

run 3

Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0

Total reproducibility.
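For comparison with the PyTorch experiment, the same save/restore check can be done with NumPy's legacy global generator, using `np.random.get_state()`/`np.random.set_state()` and drawing with `np.random.uniform(size=5)` so that the draws come from [0, 1) and actually vary. A minimal sketch; `numpy_round_trip` is a hypothetical helper name:

```python
import numpy as np


def numpy_round_trip(seed=1):
    """Hypothetical helper: save, advance, restore, and compare state and draws."""
    np.random.seed(seed)
    saved = np.random.get_state()      # ('MT19937', key, pos, has_gauss, cached_gaussian)
    first = np.random.uniform(size=5)  # five draws from [0, 1)

    np.random.uniform(size=100)        # advance the generator
    np.random.set_state(saved)         # restore the saved state
    restored = np.random.get_state()
    same_state = (np.array_equal(restored[1], saved[1])
                  and restored[2] == saved[2])
    second = np.random.uniform(size=5)
    return same_state, np.array_equal(first, second)


print(numpy_round_trip())  # → (True, True)
```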

Sorry for the long message.

Thanks!