CUDA rng state is not reset when re-seeding! WHY is that?

  • Consider 1 GPU only.
  • PyTorch 1.0.0, Python 3.7.0.
  • Related.

Expected behavior:

The rng state changes in a deterministic way with respect to the seed: re-seeding with the same seed always resets it to the same state.

Reality:
The rng state ignores your seed. It behaves in its own way.

Code:

import copy


import torch

seed = 1
use_cuda = True
l = (torch.cuda.get_rng_state())
# np.savetxt("rng_state_init.txt".format(0), l.numpy())
for i in range(10):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    lastl = copy.deepcopy(l)
    l = (torch.cuda.get_rng_state())
    print((l == lastl).all())
    a = torch.rand(1, device="cuda")  # torch.bernoulli(torch.full((300, 300), 0.5, device='cuda:0'))
    print(a)

Observations about the above code:

  • The initial state of the rng is random: each run starts with a different rng state. This is fine. You can check it by storing the rng state on disk and comparing different runs. Re-seeding before the loop is not a problem, since re-seeding does not change the rng state.
  • The main issue is that after re-seeding the rng with the same seed, the rng state is random! To check this, run the above code many times and look at the value printed by print((l == lastl).all()). We expect that, after the first iteration, we ALWAYS find True (1), because re-seeding is expected to reset the rng state to a state that corresponds to the seed. But from time to time you will find False (0).
  • Yes, a is always the same. This is expected.
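The boolean from `(l == lastl).all()` only says *whether* two snapshots differ. To see *where* they differ, a small helper can list the differing byte positions. This is a sketch: `diff_positions` is a hypothetical name, and it assumes the state has first been converted to a NumPy uint8 array (e.g. `torch.cuda.get_rng_state().numpy()`, since that function returns a CPU byte tensor).

```python
import numpy as np


def diff_positions(state_a, state_b):
    """Return the byte indices at which two rng state snapshots differ.

    Both arguments are 1-D uint8 arrays (hypothetical helper, not a PyTorch API).
    """
    a = np.asarray(state_a, dtype=np.uint8)
    b = np.asarray(state_b, dtype=np.uint8)
    if a.shape != b.shape:
        raise ValueError("state snapshots have different lengths")
    return np.flatnonzero(a != b)


# Synthetic example: two 32-byte "states" that differ only in bytes 8..15.
s1 = np.zeros(32, dtype=np.uint8)
s2 = s1.copy()
s2[8:16] = 0xFF
print(diff_positions(s1, s2).tolist())  # -> [8, 9, 10, 11, 12, 13, 14, 15]
```

If the differences cluster at a few fixed offsets instead of being spread over the whole buffer, that hints at one specific field changing rather than the generator state as a whole.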

Now, my question is: why does re-seeding not reset the rng state (as it does in NumPy, for instance)?

Here are some runs:

run 1:

tensor(0, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor(0.2921, device='cuda:0')

run 2:

tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')

run 3:

tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')

run 4:

tensor(0, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')
tensor(1, dtype=torch.uint8)
tensor([0.2921], device='cuda:0')

You can see how random the comparison result is from one run to the next!

+++++++++++++++++++++++++++++++++++++++++++++++++

                   NUMPY

+++++++++++++++++++++++++++++++++++++++++++++++++

Code:

import copy


import numpy as np

seed = 1
use_cuda = True
np.random.seed(0)
l = np.random.get_state()
# np.savetxt("{}.txt".format(0), l[1])
for i in range(10):
    np.random.seed(seed)
    lastl = copy.deepcopy(l)
    l = (np.random.get_state())
    print((l[1] == lastl[1]).all())
    a = np.random.uniform(1)  # low=1, high=1.0 (default): always returns 1.0
    print(a)

run 1:

False
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0

run 2:

False
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0

run 3:

False
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0

run 4:

False
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0

run 5:

False
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0
True
1.0

See how deterministic NumPy is!
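Because `np.random.uniform(1)` sets `low=1` while `high` keeps its default of 1.0, it returns 1.0 whatever the rng state is, so the printed values alone do not show that the draws are reproducible. A minimal variant that draws from [0, 1) makes the effect of re-seeding on the values themselves visible:

```python
import numpy as np

seed = 1
draws = []
for i in range(3):
    np.random.seed(seed)
    draws.append(np.random.uniform(size=3))  # three values in [0, 1)

# Re-seeding resets the state, so every iteration yields the same values.
print(all(np.array_equal(draws[0], d) for d in draws))  # -> True
print(np.all((draws[0] >= 0) & (draws[0] < 1)))  # -> True
```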

Thanks!

So until very recently, PyTorch used two CUDA RNGs: MTGP32 and Philox 4x32-10.
My impression is that you are getting different state values for the former. I must admit I'm ignorant of the specifics of why that is (the states are generated using curandMakeMTGP32KernelState from the cuRAND API).
The happy news is that with the PyTorch nightlies, you can profit from Syed Ahmed's awesome work on random number generation in PyTorch. Since last week or so, master and the nightlies use only the Philox generator for CUDA, and the problem should be gone. (Note that the only state used now is the seed and offset of the Philox generator, i.e. two 64-bit numbers in the last 16 bytes of the state.)
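Reading the seed and offset out of such a snapshot could look like the sketch below. It assumes (not confirmed in this thread) that the two 64-bit values are stored little-endian, seed first, in the trailing 16 bytes; `philox_seed_offset` is a hypothetical helper, exercised here on a synthetic buffer rather than a real CUDA state.

```python
import numpy as np


def philox_seed_offset(state):
    """Split the last 16 bytes of a state snapshot into (seed, offset).

    ASSUMPTION: two little-endian 64-bit integers, seed first, offset second.
    Verify against your PyTorch version before relying on this.
    """
    raw = np.asarray(state, dtype=np.uint8)
    if raw.size < 16:
        raise ValueError("state is too short to contain seed and offset")
    seed, offset = raw[-16:].view("<u8")
    return int(seed), int(offset)


# Synthetic snapshot: 32 bytes of padding, then seed=1 and offset=4.
fake = np.concatenate([
    np.zeros(32, dtype=np.uint8),
    np.array([1, 4], dtype="<u8").view(np.uint8),
])
print(philox_seed_offset(fake))  # -> (1, 4)
```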

Best regards

Thomas

Thank you!

For now, I am not planning to move to a more recent/nightly PyTorch, since I would need to update my code. But I will definitely do that for my next code.

Now, based on your current comment and your previous one, I tried to RESET the rng state directly using a handcrafted/random rng state.

I found the following:

  1. One can reset the rng state (expected).
  2. Any rng state of the right length is accepted (824016 bytes for CUDA and 5048 for CPU).
  3. It is easy to set an rng state that generates inf numbers. This means that setting a handcrafted rng state must be done with care.
  4. However, if the rng does not like the handcrafted state, it does not raise an error; it seems to ignore it and continue with its own random rng state. The rng state may have some predefined structure (in terms of its numbers).
  5. One can control only the initial rng state. After it has been fixed, the next rng state is random (because of 4).
  6. To make sure that I am using a valid rng state, I store the CUDA and CPU rng states on disk during the FIRST run. In the subsequent runs, I load the saved rng states and set them. While passing a valid rng state around makes it possible to obtain the same rng state across iterations, the FIRST iteration of each SUBSEQUENT run does not respect the fixed rng state (the rng ignores it!).
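For comparison, the store-once/load-later workflow of point 6 can be sketched with NumPy's global generator, where the round trip behaves as expected. The file name and `save_or_load_state` are hypothetical, and pickle is just one convenient way to persist the state tuple.

```python
import os
import pickle
import tempfile

import numpy as np


def save_or_load_state(path):
    """First call: save the global NumPy rng state to `path`.
    Later calls: restore the state stored in `path`."""
    if not os.path.exists(path):
        with open(path, "wb") as f:
            pickle.dump(np.random.get_state(), f)
    else:
        with open(path, "rb") as f:
            np.random.set_state(pickle.load(f))


path = os.path.join(tempfile.mkdtemp(), "rng_state.pkl")

np.random.seed(123)
save_or_load_state(path)           # "first run": stores the state
first = np.random.uniform(size=3)

np.random.seed(999)                # scramble the state
save_or_load_state(path)           # "second run": restores the stored state
second = np.random.uniform(size=3)

print(np.array_equal(first, second))  # -> True
```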

Before you read the code/runs: do you think there is another variable, other than the CUDA/CPU rng state, that the CUDA RNG depends on? You will see clearly that even when using the exact same rng state for CUDA, it does not lead to the same rng state, in the sense that after torch.cuda.set_rng_state(stored_rngs); rngs = torch.cuda.get_rng_state(), rngs is different from stored_rngs. I suspect that something else is affecting the CUDA rng.
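The property being tested is that "set, then get" is the identity. A generic sketch of that check, demonstrated on NumPy (where it holds); `roundtrip_is_identity` and `numpy_states_equal` are hypothetical helpers. With PyTorch one would pass `torch.cuda.get_rng_state`, `torch.cuda.set_rng_state` and `torch.equal`, which is the check the post reports as failing on 1.0.0.

```python
import numpy as np


def roundtrip_is_identity(get_state, set_state, state, equal):
    """Set `state`, read it back, and report whether the read-back equals the input."""
    set_state(state)
    return equal(get_state(), state)


def numpy_states_equal(a, b):
    # np.random.get_state() returns a tuple:
    # (name, key array, pos, has_gauss, cached_gaussian)
    return a[0] == b[0] and np.array_equal(a[1], b[1]) and a[2:] == b[2:]


np.random.seed(1)
stored = np.random.get_state()
np.random.uniform(size=10)  # advance the generator away from `stored`
print(roundtrip_is_identity(np.random.get_state, np.random.set_state,
                            stored, numpy_states_equal))  # -> True
```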

Thanks!

Here is an example (the runs are not reproducible):

import copy
import os


import torch
import numpy as np


seed = 1
# reset the seed of Numpy to control its initial rng state.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

use_cuda = True

# set cuda rng state using Numpy rng state or any other randomly generated state.
state_cuda = np.random.randint(0, np.iinfo(np.uint8).max + 1, 824016) * 0.  # 824016 is the expected size of the rng
# state.
# A normal rng state is not random: it has some patterns between its internal numbers! A random rng state can easily
# lead to generating inf values. Try setting the rng state to all 1s or using the above np.random.randint() to see
# that the generator breaks at the inf test.

state_cpu = np.random.randint(0, np.iinfo(np.uint8).max + 1, 5048) * 0.

l_cuda = (torch.tensor(state_cuda).type(torch.uint8))
l_cpu = (torch.tensor(state_cpu).type(torch.uint8))

# use a VALID stored rng state
LOAD = True
filecuda = "prng/cuda.npy"
filecpu = "prng/cpu.npy"
os.makedirs("prng", exist_ok=True)  # torch.save() fails if the directory does not exist
if LOAD:
    if not os.path.exists(filecuda) and not os.path.exists(filecpu):
        l_cuda = (torch.cuda.get_rng_state(device=0))
        torch.save(l_cuda, filecuda)
        l_cpu = (torch.get_rng_state())
        torch.save(l_cpu, filecpu)
    else:
        arrcuda = torch.load(filecuda)
        l_cuda = (torch.tensor(arrcuda).type(torch.uint8))
        arrcpu = torch.load(filecpu)
        l_cpu = (torch.tensor(arrcpu).type(torch.uint8))

for i in range(10):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # force the cuda rng state to be equal to `state`.
    torch.cuda.set_rng_state(l_cuda, device=0)
    torch.set_rng_state(l_cpu)
    lastlcuda = copy.deepcopy(l_cuda)
    lastlcpu = copy.deepcopy(l_cpu)
    l_cuda = (torch.cuda.get_rng_state(device=0))
    st = bool((l_cuda == lastlcuda).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CUDA".format(
        i, l_cuda.sum(), lastlcuda.sum(), st))
    st = bool((l_cpu == lastlcpu).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CPU".format(
        i, l_cpu.sum(), lastlcpu.sum(), st))
    a = torch.rand(1000, device="cuda:0")  # torch.bernoulli(torch.full((300, 300), 0.5, device='cuda:0'))
    print(a.sum())

run 1: (store the rng state of CUDA and CPU)

Iter. 0: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 1: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 2: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 3: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 4: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 5: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 6: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 7: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 8: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 9: current (rngs.sum()55928053) == previous (rngs.sum()55928053)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU

You see that everything is perfect (the rng state is the same, including at iter. 0).

Now, if we run a second time where we reload the rng states, see what happens:

run 2: (load the stored rng state of CUDA and CPU)

Iter. 0: current (rngs.sum()55990853) == previous (rngs.sum()55928053)? **False** CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 1: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 2: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 3: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 4: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 5: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 6: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 7: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 8: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')
Iter. 9: current (rngs.sum()55990853) == previous (rngs.sum()55990853)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(498.2127, device='cuda:0')

At iter. 0, the loaded CUDA rng state was not respected! In other words:

torch.cuda.set_rng_state(l_cuda, device=0)
rngs = (torch.cuda.get_rng_state(device=0))

leads to rngs being different from l_cuda! Any insight into why?

If you are wondering why I am trying to fix this and obtain the same rng state: I am trying to make PyTorch reproducible over multiple GPUs. For that, I use:

  1. threading.Lock() to lock the zones that cause randomness, so that I can re-seed each thread safely, since they all share the same rng state (this is caused by torch.nn.DataParallel, whose threads share the rng state and mess things up).
  2. Re-seeding and resetting the rng state for each thread (in a thread-safe manner) to control the randomness of each thread.
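The lock-then-reseed idea in points 1 and 2 can be sketched with NumPy's global generator standing in for the shared PyTorch rng state. This is an illustration under that substitution, not the actual DataParallel setup; `worker` and `run` are hypothetical names. Because seeding and drawing happen inside one critical section, each worker's draws depend only on its id, not on thread scheduling.

```python
import threading

import numpy as np

_rng_lock = threading.Lock()  # guards the shared (global) NumPy rng


def worker(worker_id, base_seed, results):
    # Re-seed and draw inside one critical section, so no other thread
    # can touch the shared global generator in between.
    with _rng_lock:
        np.random.seed(base_seed + worker_id)
        results[worker_id] = np.random.uniform(size=2)


def run(n_workers=4, base_seed=1):
    results = {}
    threads = [threading.Thread(target=worker, args=(i, base_seed, results))
               for i in range(n_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


r1, r2 = run(), run()
print(all(np.array_equal(r1[i], r2[i]) for i in r1))  # -> True
```

A lock serializes all random draws; giving each thread its own generator object (e.g. one `np.random.RandomState(seed)` per worker) avoids sharing state in the first place.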

I will post the code here, as soon as I obtain reproducible results on a real code.

For now, on synthetic code (using simple instructions), I seem to be able to achieve reproducibility over multiple GPUs. However, when I tried the idea on real code, it does not work ALL the time.

I followed your comment to pass rng states around instead of just re-seeding. So far, no luck. Sometimes I obtain EXACTLY the same result with my real code, but most of the time reproducibility fails. After some checking, I realized that something is wrong with the rng states: outside the threads (i.e., after the previous parallelized forward call), the rng state of the main thread is not deterministic even though I force it to be (somehow, it ignores it).
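Passing states around amounts to snapshotting a generator before a region and restoring it afterwards, so that the region cannot perturb the caller's stream. A minimal sketch with NumPy; `preserved_numpy_rng` is a hypothetical helper. Recent PyTorch versions provide a similar context manager, `torch.random.fork_rng`, which saves and restores the CPU and selected CUDA rng states.

```python
import contextlib

import numpy as np


@contextlib.contextmanager
def preserved_numpy_rng():
    """Snapshot the global NumPy rng state on entry and restore it on exit."""
    saved = np.random.get_state()
    try:
        yield
    finally:
        np.random.set_state(saved)


np.random.seed(0)
with preserved_numpy_rng():
    np.random.uniform(size=100)   # consumes random numbers...
after = np.random.uniform()       # ...but the outer stream is unaffected

np.random.seed(0)
reference = np.random.uniform()
print(after == reference)  # -> True
```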

===================================================

                           NUMPY

===================================================
See how the NumPy rng state can easily be controlled using only a seed.

Code:

import copy

import numpy as np

seed = 1
use_cuda = True
np.random.seed(seed)
l = np.random.get_state()
for i in range(10):
    np.random.seed(seed)
    lastl = copy.deepcopy(l)
    l = (np.random.get_state())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {}".format(
        i, l[1].sum(), lastl[1].sum(), bool((l[1] == lastl[1]).all())))
    a = np.random.uniform(1)  # low=1, high=1.0 (default): always returns 1.0
    print(a)

run 1

Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0

run 2

Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0

run 3

Iter. 0: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 1: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 2: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 3: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 4: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 5: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 6: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 7: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 8: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0
Iter. 9: current (rngs.sum()1275581757218) == previous (rngs.sum()1275581757218)? True
1.0

Total reproducibility.

Sorry for the long message.

Thanks!

Note that the MTGP32 state includes a pointer to the kernel parameters (see curand_mtgp32.h). This is set to the right thing™ in the set procedure (in THCTensorRandom.cu) and the corresponding inputs are effectively ignored, but they aren’t zeroed in the get part (which would be needed for loops to be reliably idempotent).
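Following this explanation, a comparison that ignores the pointer bytes would tell whether two snapshots differ only in the kernel-params pointer. The sketch rests on layout assumptions that are not confirmed in this thread: 200 `curandStateMtgp32` structs of 4120 bytes each (a 1024-entry uint32 array, two ints, an 8-byte pointer, one more int, padding) followed by 16 trailing bytes. These numbers do add up to the 824016 bytes reported earlier, but check curand_mtgp32.h and THCTensorRandom for your version before relying on them. The helper names are hypothetical.

```python
import numpy as np

# ASSUMPTIONS (x86-64, PyTorch 1.0-era layout; verify before use):
#   * NUM_STATES curandStateMtgp32 structs stored back to back,
#   * each STATE_BYTES long: uint32 s[1024], int offset, int pIdx,
#     pointer k, int precise_double_flag, plus padding,
#   * the pointer k sits at byte PTR_OFFSET and is PTR_BYTES wide,
#   * TRAILER_BYTES extra bytes at the end.
NUM_STATES = 200
STATE_BYTES = 4120
PTR_OFFSET = 4 * 1024 + 4 + 4
PTR_BYTES = 8
TRAILER_BYTES = 16
EXPECTED_SIZE = NUM_STATES * STATE_BYTES + TRAILER_BYTES  # 824016


def mask_mtgp32_pointers(state):
    """Return a copy of `state` with the (assumed) pointer bytes zeroed."""
    raw = np.array(state, dtype=np.uint8)  # copy, so the input is untouched
    if raw.size != EXPECTED_SIZE:
        raise ValueError("unexpected state size: %d" % raw.size)
    blocks = raw[:NUM_STATES * STATE_BYTES].reshape(NUM_STATES, STATE_BYTES)
    blocks[:, PTR_OFFSET:PTR_OFFSET + PTR_BYTES] = 0  # writes through to raw
    return raw


def states_equal_ignoring_pointers(a, b):
    return np.array_equal(mask_mtgp32_pointers(a), mask_mtgp32_pointers(b))


# Synthetic check: two states that differ only inside one pointer field.
a = np.zeros(EXPECTED_SIZE, dtype=np.uint8)
b = a.copy()
b[5 * STATE_BYTES + PTR_OFFSET] = 0x7F
print(np.array_equal(a, b), states_equal_ignoring_pointers(a, b))  # -> False True
```

On 1.0.0, one would call it as `states_equal_ignoring_pointers(torch.cuda.get_rng_state().numpy(), l_cuda.numpy())`; if the assumptions hold, this should report equality even when the raw comparison does not.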

It should not be a surprise that feeding manipulated values into the internal state wreaks havoc.

I must admit that I’m not terribly keen on digging into this further, given that it’s a solved problem with standardizing on the Philox generator.

Best regards

Thomas

I tried the nightly build (https://download.pytorch.org/whl/nightly/cu100/torch_nightly-1.2.0.dev20190616-cp37-cp37m-linux_x86_64.whl).

  1. The prng state issue is fixed: within the same run, the prng state does not change when the seed is fixed.
  2. Now it is possible to obtain the same prng state across different runs (by using a fixed seed or a prng state stored on disk, for instance).

Code:

import copy
import os


import torch
import numpy as np


seed = 1
# reset the seed of Numpy to control its initial rng state.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

use_cuda = True

state_cpu = torch.get_rng_state()
state_cuda = torch.cuda.get_rng_state(device=0)

l_cuda = (torch.tensor(state_cuda).type(torch.uint8))
l_cpu = (torch.tensor(state_cpu).type(torch.uint8))

LOAD = True
filecuda = "prng/cuda.npy"
filecpu = "prng/cpu.npy"
os.makedirs("prng", exist_ok=True)  # torch.save() fails if the directory does not exist
if LOAD:
    if not os.path.exists(filecuda) and not os.path.exists(filecpu):
        l_cuda = (torch.cuda.get_rng_state(device=0))
        torch.save(l_cuda, filecuda)
        l_cpu = (torch.get_rng_state())
        torch.save(l_cpu, filecpu)
    else:
        arrcuda = torch.load(filecuda)
        l_cuda = (torch.tensor(arrcuda).type(torch.uint8))
        arrcpu = torch.load(filecpu)
        l_cpu = (torch.tensor(arrcpu).type(torch.uint8))

for i in range(10):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # force the cuda rng state to be equal to `state`.
    torch.cuda.set_rng_state(l_cuda, device=0)
    torch.set_rng_state(l_cpu)
    lastlcuda = copy.deepcopy(l_cuda)
    lastlcpu = copy.deepcopy(l_cpu)
    l_cuda = (torch.cuda.get_rng_state(device=0))
    st = bool((l_cuda == lastlcuda).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CUDA".format(
        i, l_cuda.sum(), lastlcuda.sum(), st))
    st = bool((l_cpu == lastlcpu).all())
    print("Iter. {}: current (rngs.sum(){}) == previous (rngs.sum(){})? {} CPU".format(
        i, l_cpu.sum(), lastlcpu.sum(), st))
    a = torch.rand(1000, device="cuda:0")  # torch.bernoulli(torch.full((300, 300), 0.5, device='cuda:0'))
    print(a.sum())

run 1: (store rng state on disk)

Iter. 0: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 1: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 2: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 3: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 4: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 5: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 6: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 7: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 8: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 9: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')

run 2: use the stored prng state. (same results when using the same seed without storing anything.)

Iter. 0: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? **True** CUDA
Iter. 0: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 1: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 1: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 2: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 2: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 3: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 3: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 4: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 4: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 5: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 5: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 6: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 6: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 7: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 7: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 8: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 8: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')
Iter. 9: current (rngs.sum()210120001) == previous (rngs.sum()210120001)? True CUDA
Iter. 9: current (rngs.sum()316421) == previous (rngs.sum()316421)? True CPU
tensor(488.7892, device='cuda:0')

Thanks!