Thank you!
For now, I am not planning to move to a more recent or nightly PyTorch version, since I would need to update my code. I will definitely do that for my next project, though.
Now, based on your current and previous comments, I tried to RESET the rng state directly using a handcrafted/random rng state. I found the following:
1. One can reset the rng state (expected).
2. Any rng state with the right length can be set (824016 bytes for CUDA and 5048 bytes for the CPU).
3. It is easy to set an rng state that generates inf numbers, so setting a handcrafted rng state needs to be done with care.
4. However, if the rng does not like the handcrafted rng state, it does not raise an error. Instead, it seems to ignore the state and go by its own random rng state. A valid rng state may have some predefined structure (in terms of its numbers).
5. One can control only the initial rng state. Once it has been fixed, the next rng state is random (because of 4).
6. To make sure that I am using a valid rng state, I store the CUDA and CPU rng states on disk at the FIRST run. In the subsequent runs, I load the saved rng states and set them. Passing a valid rng state around yields the same rng state over the iterations. However, the FIRST iteration of the SUBSEQUENT runs does not obey the fixed rng state (the rng ignores it!).
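The store-then-reload approach in point 6 can be sketched as follows, with a round-trip check after each set. This is a minimal sketch, not the code used here. The helper names save_rng_states / load_rng_states and the temporary file path are hypothetical. It only relies on torch.get_rng_state/set_rng_state and their torch.cuda counterparts, and it skips CUDA when no GPU is present.

```python
import os
import tempfile

import torch


def save_rng_states(path):
    """Store the current CPU rng state (and CUDA device 0's, if a GPU is present)."""
    states = {"cpu": torch.get_rng_state()}
    if torch.cuda.is_available():
        states["cuda"] = torch.cuda.get_rng_state(device=0)
    torch.save(states, path)


def load_rng_states(path):
    """Apply the stored rng states and report, per device, whether get == stored."""
    states = torch.load(path)
    torch.set_rng_state(states["cpu"])
    ok = {"cpu": torch.equal(torch.get_rng_state(), states["cpu"])}
    if "cuda" in states and torch.cuda.is_available():
        torch.cuda.set_rng_state(states["cuda"], device=0)
        ok["cuda"] = torch.equal(torch.cuda.get_rng_state(device=0), states["cuda"])
    return ok


# "First run": seed, store the state, then draw some numbers.
path = os.path.join(tempfile.mkdtemp(), "rng_states.pt")
torch.manual_seed(1)
save_rng_states(path)
first = torch.rand(3)

# "Subsequent run": disturb the state, reload the stored one, and draw again.
torch.manual_seed(123)
ok = load_rng_states(path)
second = torch.rand(3)
print(ok, torch.equal(first, second))
```

On the CPU, the reloaded state should reproduce the same draws. The "cuda" entry is where the mismatch described in this post would show up.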
Before you read the code and runs: do you think there is another variable, other than the CUDA/CPU rng state, that the CUDA RNG depends on? You will see clearly that even the exact same CUDA rng state does not lead back to the same rng state. That is, after torch.cuda.set_rng_state(stored_rngs); rngs = torch.cuda.get_rng_state(), you find that rngs is different from stored_rngs. I suspect that something else is affecting the CUDA rng.
Thanks!
Here is an example (the runs are not reproducible):
import copy
import os
import torch
import numpy as np

seed = 1
# Reset the NumPy seed to control its initial rng state.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
use_cuda = True

# Set the CUDA rng state using the NumPy rng or any other randomly generated state.
# 824016 is the expected size (in bytes) of the CUDA rng state.
state_cuda = np.random.randint(0, np.iinfo(np.uint8).max + 1, 824016) * 0.
# A normal rng state is not random: its internal numbers follow some pattern!
# A random rng state can easily lead to generating inf values. Try setting the
# rng state to all 1s, or using np.random.randint() above without `* 0.`,
# to see that the generator breaks at the inf test.
state_cpu = np.random.randint(0, np.iinfo(np.uint8).max + 1, 5048) * 0.
l_cuda = torch.tensor(state_cuda).type(torch.uint8)
l_cpu = torch.tensor(state_cpu).type(torch.uint8)

# Use a VALID stored rng state.
LOAD = True
filecuda = "prng/cuda.npy"
filecpu = "prng/cpu.npy"
if LOAD:
    if not os.path.exists(filecuda) and not os.path.exists(filecpu):
        # First run: store the current (valid) rng states on disk.
        os.makedirs(os.path.dirname(filecuda), exist_ok=True)
        l_cuda = torch.cuda.get_rng_state(device=0)
        torch.save(l_cuda, filecuda)
        l_cpu = torch.get_rng_state()
        torch.save(l_cpu, filecpu)
    else:
        # Subsequent runs: reload the stored rng states (already uint8 tensors).
        l_cuda = torch.load(filecuda).clone().type(torch.uint8)
        l_cpu = torch.load(filecpu).clone().type(torch.uint8)

for i in range(10):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Force the rng states to be equal to the stored ones.
    torch.cuda.set_rng_state(l_cuda, device=0)
    torch.set_rng_state(l_cpu)
    lastlcuda = copy.deepcopy(l_cuda)
    lastlcpu = copy.deepcopy(l_cpu)
    # Read both states back to check that they were actually applied.
    l_cuda = torch.cuda.get_rng_state(device=0)
    l_cpu = torch.get_rng_state()
    st = bool((l_cuda == lastlcuda).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CUDA".format(
        i, l_cuda.sum(), lastlcuda.sum(), st))
    st = bool((l_cpu == lastlcpu).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CPU".format(
        i, l_cpu.sum(), lastlcpu.sum(), st))
    a = torch.rand(1000, device="cuda:0")  # torch.bernoulli(torch.full((300, 300), 0.5, device='cuda:0'))
    print(a.sum())
run 1: (store the rng state of CUDA and CPU)
Iter. 0: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 1: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 2: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 3: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 4: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 5: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 6: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 7: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 8: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 9: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
As you can see, everything is fine: the rng state is the same at every iteration, including iter. 0.
Now, see what happens if we run it a second time, reloading the stored rng states:
run 2: (load the stored rng state of CUDA and CPU)
Iter. 0: current (rngs.sum()55990853) == previous (rngs.sum()55928053)? **False** CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 1: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 2: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 3: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 4: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 5: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 6: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 7: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 8: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 9: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
At iter. 0, the loaded CUDA rng state was not taken into account! In other words:
torch.cuda.set_rng_state(l_cuda, device=0)
rngs = torch.cuda.get_rng_state(device=0)
leaves rngs different from l_cuda. Any insight into why?
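The logs above show the same a.sum() (498.2127) in every iteration of both runs. This includes iteration 0 of run 2, where the CUDA state comparison reports False. So a check that compares generated samples, rather than raw state bytes, may be a useful complement when narrowing this down. This is a minimal sketch, not a diagnosis: same_draws is a hypothetical helper, and the CUDA branch only runs when a GPU is present.

```python
import torch


def same_draws(state, n=1000, device="cpu"):
    """Apply `state` twice and check that both applications yield identical samples.

    This compares generated numbers instead of the raw rng state bytes.
    """
    draws = []
    for _ in range(2):
        if device == "cpu":
            torch.set_rng_state(state)
        else:
            torch.cuda.set_rng_state(state, device=device)
        draws.append(torch.rand(n, device=device))
    return torch.equal(draws[0], draws[1])


torch.manual_seed(1)
cpu_state = torch.get_rng_state()
cpu_ok = same_draws(cpu_state)
print("CPU draws identical:", cpu_ok)

if torch.cuda.is_available():
    cuda_state = torch.cuda.get_rng_state(device=0)
    print("CUDA draws identical:", same_draws(cuda_state, device="cuda:0"))
```

To compare across runs, the drawn tensor (or its sum) could be saved at the first run and compared with the same draw in the subsequent runs.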
If you are wondering why I am trying to fix this and obtain the same rng state: I am trying to make PyTorch reproducible over multiple GPUs. For that, I use:
- threading.Lock() to lock the zones that cause randomness, so that I can reseed each thread safely, since all threads share the same rng state. (This is caused by the threads used in torch.nn.DataParallel, which share rng states and mess things up.)
- Re-seeding and resetting the rng state of each thread (in a thread-safe manner) to control the randomness of each thread.
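The locking idea in the list above can be sketched as follows, assuming one module-level threading.Lock shared by all threads. The names _RNG_LOCK, run_with_rng_state and make_state are hypothetical, not PyTorch APIs. Each thread keeps its own rng state, installs it into the shared generator under the lock, draws, and reads the advanced state back before releasing the lock. The sketch uses the CPU generator so it runs without a GPU; the assumption is that the same pattern carries over to torch.cuda.set_rng_state/get_rng_state.

```python
import threading

import torch

# Hypothetical lock guarding every use of the shared global generator.
_RNG_LOCK = threading.Lock()


def run_with_rng_state(state, fn):
    """Run fn() under the lock with the global CPU rng state set to `state`.

    Returns fn's result and the rng state right after fn, so the caller can
    pass that state into its next call instead of re-seeding.
    """
    with _RNG_LOCK:
        torch.set_rng_state(state)
        out = fn()
        new_state = torch.get_rng_state()
    return out, new_state


def make_state(seed):
    """Build a CPU rng state for `seed` without touching the global generator."""
    return torch.Generator().manual_seed(seed).get_state()


results = {}


def worker(name, seed):
    state = make_state(seed)
    samples = []
    for _ in range(3):
        out, state = run_with_rng_state(state, lambda: torch.rand(2))
        samples.append(out)
    results[name] = torch.cat(samples)


threads = [threading.Thread(target=worker, args=(name, seed))
           for name, seed in [("a", 0), ("b", 1)]]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)
```

Because the lock serializes set, draw and get, each thread's sequence depends only on its own seed, not on how the threads interleave. An alternative worth weighing is a dedicated torch.Generator per thread, passed via generator= to the ops that accept it. However, not every op (e.g. torch.nn.functional.dropout) takes a generator argument.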
I will post the code here as soon as I obtain reproducible results on real code.
For now, on synthetic code (using simple instructions), I seem to be able to achieve reproducibility over multiple GPUs. However, when I tried the idea on real code, it does not work ALL the time.
I followed your suggestion to pass rng states around instead of just re-seeding. So far, no luck. Sometimes I obtain EXACTLY the same result with my real code, but most of the time reproducibility fails. After some checking, I realized that something is wrong with the rng states. Outside the threads (i.e., after the call to the previous parallelized forward), the rng state of the main thread is not deterministic, even though I force it to be (somehow, it ignores what I set).
===================================================
NUMPY
===================================================
See how the NumPy rng state can easily be controlled using only a seed.
Code:
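import copy
import numpy as np

seed = 1
use_cuda = True
np.random.seed(seed)
l = np.random.get_state()
for i in range(10):
    np.random.seed(seed)
    lastl = copy.deepcopy(l)
    l = np.random.get_state()
    # l[1] holds the Mersenne Twister key array of the state.
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {}".format(
        i, l[1].sum(), lastl[1].sum(), bool((l[1] == lastl[1]).all())))
    # Note: np.random.uniform(1) sets low=1 with the default high=1.0, so it always
    # returns 1.0; np.random.uniform() would sample from [0, 1) instead.
    a = np.random.uniform(1)
    print(a)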
run 1
Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
run 2
Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
run 3
Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Total reproducibility.
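For comparison with the state-passing approach used for PyTorch above, NumPy's global state can be captured and restored the same way with np.random.get_state() and np.random.set_state(). The minimal sketch below draws from [0, 1) with np.random.uniform(size=3), so the check reflects actual random values.

```python
import numpy as np

np.random.seed(1)
saved = np.random.get_state()       # capture the full global state
first = np.random.uniform(size=3)

np.random.seed(123)                 # disturb the global state
np.random.set_state(saved)          # restore the captured state
second = np.random.uniform(size=3)

print(np.array_equal(first, second))  # True
```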
Sorry for the long message.
Thanks!