I have a bit of confusion surrounding PyTorch's RNGs and how/when the CPU or GPU/CUDA generator is used. I have a model that saves the RNG state internally and essentially replicates the behavior of `torch.random.fork_rng()`. I do this because new layers may be added during training and I need to initialize their weights deterministically (the whole thing is painfully complex, but that kinda sums it up). At a high level, my code stashes the main RNG state from `torch.get_rng_state()` in a local variable and then swaps in the state it keeps, e.g.:
```python
master_rng_state = torch.get_rng_state()  # Get the "master" RNG state
torch.set_rng_state(self._rng_state)      # Swap in my "forked" RNG state
# ... do some initialization stuff here ...
self._rng_state = torch.get_rng_state()   # Update the forked RNG state
torch.set_rng_state(master_rng_state)     # Restore the master RNG state
```
In this example, my forked state `self._rng_state` is taken from `torch.get_rng_state()` when the model is first initialized.
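For reference, this is the same save/restore pattern that the built-in context manager wraps; a minimal sketch of the equivalent usage (CPU generator only) would be:

```python
import torch

# fork_rng saves the global RNG state on entry and restores it on exit,
# so draws inside the block do not perturb the outer random stream.
with torch.random.fork_rng(devices=[]):  # devices=[] -> fork only the CPU RNG
    torch.manual_seed(1234)              # seed the forked stream deterministically
    w = torch.randn(3, 3)                # e.g., initialize a new layer's weights
```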
This works quite well when running on the CPU, but I have a problem when I run it on the GPU. When I try to set the RNG state with my forked state, `torch.set_rng_state(self._rng_state)`, it turns out the forked state tensor has been moved to the GPU, and torch raises `TypeError: expected a torch.ByteTensor, but got torch.cuda.ByteTensor`. Does anyone know why this happens? CUDA is never called on the forked state, and all the models are moved to CPU memory before any of this happens.
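One way I can imagine this happening (an assumption on my part, not something I've confirmed in my code): if the state were stored as a registered buffer, `model.cuda()` / `model.to(device)` would silently move it along with the parameters, while `torch.set_rng_state()` only accepts a CPU `torch.ByteTensor`. A minimal sketch that reproduces the error under that assumption:

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical setup: storing the forked state as a buffer means
        # .cuda()/.to(device) will move it along with the parameters.
        self.register_buffer("_rng_state", torch.get_rng_state())

net = Net()
if torch.cuda.is_available():
    net.cuda()                            # silently moves _rng_state to the GPU
    print(net._rng_state.device)          # cuda:0
    torch.set_rng_state(net._rng_state)   # TypeError: expected a torch.ByteTensor
```

If something like that is the cause, I'd guess that keeping the state as a plain attribute rather than a buffer, or calling `torch.set_rng_state(self._rng_state.cpu())`, would sidestep it, but I'd still like to understand what's actually moving the tensor.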