I have a bit of confusion surrounding PyTorch's RNGs and how/when the CPU or GPU/CUDA generator is used. I have a model that saves the RNG state internally and essentially replicates the behavior of `torch.random.fork_rng()`. I do this because new layers may be added during the training process and I need to deterministically initialize their weights (the whole thing is painfully complex, but that roughly sums it up). At a high level, my code stashes the main RNG state from `torch.get_rng_state()` in a local variable and then swaps it out with the state that it keeps. E.g.,
```python
master_rng_state = torch.get_rng_state()  # Get the "master" RNG state
torch.set_rng_state(self._rng_state)      # This is my "forked" RNG state
# ... do some initialization stuff here ...
self._rng_state = torch.get_rng_state()   # Update the forked RNG state
torch.set_rng_state(master_rng_state)     # Restore the master RNG state
```
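For comparison, here is a minimal sketch of the context-manager form this pattern mimics; on exit, `fork_rng()` restores the outer generator state, so draws inside the block don't perturb it:

```python
import torch

outer_state = torch.get_rng_state()  # snapshot of the outer CPU RNG state

with torch.random.fork_rng():
    # Inside the block, the RNG can be reseeded and consumed freely.
    torch.manual_seed(0)
    w = torch.randn(2, 2)  # deterministic init draw

# The outer RNG state is untouched after the block exits.
assert torch.equal(outer_state, torch.get_rng_state())
```

The difference in my case is that I need the forked state to *persist and advance* across calls (so each new layer continues the forked stream), which is why I stash it manually instead of letting `fork_rng()` discard it.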
In this example, my forked state `self._rng_state` is taken from `torch.get_rng_state()` when the model is first initialized.
This works quite well when running on the CPU, but I have a problem when I run this on the GPU.
When I try to set the RNG state with my forked state, `torch.set_rng_state(self._rng_state)`, it turns out the forked state has been moved to the GPU. PyTorch gives me a `TypeError: expected a torch.ByteTensor, but got torch.cuda.ByteTensor`. Does anyone know why this happens? CUDA is never called on the forked state, and all the models are moved to CPU memory before any of this happens.
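For context, here is a minimal sketch of what I *suspect* might be happening (this is an assumption, not confirmed): if the state tensor is registered as a buffer on the module, then `Module.cuda()` / `Module.to(device)` would move it to the GPU along with the parameters, even though nothing touches it explicitly.

```python
import torch
import torch.nn as nn

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        # torch.get_rng_state() returns a CPU torch.ByteTensor. If it is
        # registered as a buffer, .cuda()/.to(device) would move it to the
        # GPU together with the parameters.
        self.register_buffer("_rng_state", torch.get_rng_state())

m = Layer()
# m.cuda()  # after this, m._rng_state would be a torch.cuda.ByteTensor,
#           # and torch.set_rng_state(m._rng_state) would raise the
#           # TypeError above.

# Defensively copying back to the CPU before restoring avoids the error:
torch.set_rng_state(m._rng_state.cpu())
```

If that's the cause, keeping the state as a plain Python attribute (not a buffer), or calling `.cpu()` before `set_rng_state()`, should work around it, but I'd still like to understand when the move actually happens.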