Working with RNG states on CPU and GPU

I have a bit of confusion surrounding PyTorch's RNGs and how/when the CPU or GPU/CUDA generator is used. I have a model that saves the RNG state internally and essentially replicates the behavior of torch.random.fork_rng(). I do this because new layers may be added during the training process and I need to deterministically initialize their weights (the whole thing is painfully complex, but that roughly sums it up). At a high level, my code stashes the main RNG state from torch.get_rng_state() in a local variable and then swaps it out with the state that it keeps. E.g.,

master_rng_state = torch.get_rng_state()  # Get the "master" RNG state
torch.set_rng_state(self._rng_state)  # This is my "forked" RNG state

# ... deterministically initialize the new layers here ...

self._rng_state = torch.get_rng_state()  # Update the forked RNG state
torch.set_rng_state(master_rng_state)  # Restore the master RNG state

In this example, my forked state self._rng_state is taken from torch.get_rng_state() when the model is first initialized.
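
Put together, the pattern looks roughly like this (heavily simplified; the class and method names are just for illustration, not my actual code):

import torch
from contextlib import contextmanager

class DeterministicInit:
    """Keeps a private CPU RNG state and swaps it in around new-layer initialization."""

    def __init__(self):
        # Forked state, captured when the model is first initialized (a CPU torch.ByteTensor)
        self._rng_state = torch.get_rng_state()

    @contextmanager
    def forked_rng(self):
        master_rng_state = torch.get_rng_state()  # stash the "master" RNG state
        torch.set_rng_state(self._rng_state)      # swap in the forked state
        try:
            yield
        finally:
            self._rng_state = torch.get_rng_state()  # update the forked state
            torch.set_rng_state(master_rng_state)    # restore the master state

New layers then get initialized inside a with block over forked_rng(), so they always draw from the same private stream no matter what the main training loop has done to the global generator.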

This works quite well when running on the CPU, but I have a problem when I run this on the GPU.
When I try to set the RNG state with my forked state via torch.set_rng_state(self._rng_state), it turns out the forked state has been moved to the GPU: PyTorch gives me a TypeError: expected a torch.ByteTensor, but got torch.cuda.ByteTensor. Does anyone know why this happens? .cuda() is never called on the forked state, and all the models are moved to CPU memory before any of this happens.
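
For context, my understanding is that the CPU and CUDA generators hold completely separate states, saved and restored through separate calls, roughly like this:

import torch

cpu_state = torch.get_rng_state()            # CPU generator state, a torch.ByteTensor
if torch.cuda.is_available():
    cuda_state = torch.cuda.get_rng_state()  # state of the current CUDA device's generator

# ... draw some random numbers ...

torch.set_rng_state(cpu_state)               # expects a CPU ByteTensor
if torch.cuda.is_available():
    torch.cuda.set_rng_state(cuda_state)     # restores only the CUDA generator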

Did you register self._rng_state as a buffer or parameter?

I actually figured out what the issue was: I append the RNG state to the model's state_dict when saving a checkpoint, so when loading from a checkpoint, map_location was sending it to the GPU. Silly oversight on my part.
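
In case anyone else runs into this, the fix is just to make sure the state tensor ends up back on the CPU before it's handed to the generator. Roughly something like this (the file name and dict key are made up, adapt to your checkpoint layout):

import torch

# map_location moves every tensor in the checkpoint, including the stashed RNG state
checkpoint = torch.load("checkpoint.pt", map_location="cuda:0")

# torch.set_rng_state() only accepts a CPU torch.ByteTensor, so move it back first
rng_state = checkpoint["rng_state"].cpu()
torch.set_rng_state(rng_state)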

Thanks for responding anyway though! The state tensor being a parameter would have been a good explanation :slight_smile:


That saved me too!

Thanks! :+1: