Pass a generator to torch.nn.init functions


RNG functions like torch.normal and torch.Tensor.normal_ allow the caller to pass a generator object. But other functions, such as torch.Tensor.uniform_, all the functions in torch.nn.init, and the modules’ reset_parameters methods, don’t.
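
For reference, here is a minimal sketch of what the generator argument looks like on the functions that already accept it. It only uses torch.Generator, Tensor.normal_ and torch.normal; two generators seeded identically produce the same values no matter what the global RNG is doing.

```python
import torch

# Two independent generators with the same seed yield the same stream,
# independent of the global default generator.
g1 = torch.Generator().manual_seed(1234)
g2 = torch.Generator().manual_seed(1234)

a = torch.empty(3, 4).normal_(mean=0.0, std=1.0, generator=g1)
b = torch.empty(3, 4).normal_(mean=0.0, std=1.0, generator=g2)

# torch.normal takes a generator keyword argument as well.
g3 = torch.Generator().manual_seed(1234)
c = torch.normal(torch.zeros(3, 4), torch.ones(3, 4), generator=g3)

print(torch.equal(a, b))  # True
```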

Being able to pass a generator object is needed when determinism and consistency are required. Calling the global torch.manual_seed can sometimes serve as a workaround, but it doesn’t protect the user from issues such as independent networks being initialized concurrently, or third-party code calling that function outside of the user’s control.
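
To illustrate the third-party problem, here is a small sketch. The function third_party_code is hypothetical and stands in for any library that reseeds the global RNG; init_global and init_local are likewise illustrative helpers. The draw from the global generator changes, while the draw from a dedicated torch.Generator does not.

```python
import torch

def third_party_code():
    # Hypothetical library code that reseeds the global RNG behind our back.
    torch.manual_seed(42)

def init_global(shape):
    # Draws from the global default generator.
    return torch.empty(shape).normal_()

def init_local(shape, gen):
    # Draws only from the generator that is passed in.
    return torch.empty(shape).normal_(generator=gen)

# Reference run with no interference.
torch.manual_seed(0)
ref_global = init_global((2, 3))
ref_local = init_local((2, 3), torch.Generator().manual_seed(0))

# Same setup, but third-party code reseeds the global RNG before init.
torch.manual_seed(0)
gen = torch.Generator().manual_seed(0)
third_party_code()
new_global = init_global((2, 3))
new_local = init_local((2, 3), gen)

print(torch.equal(ref_global, new_global))  # False: global stream was reset to seed 42
print(torch.equal(ref_local, new_local))    # True: the dedicated generator was untouched
```

torch.manual_seed only reseeds the default generators, so an explicitly created generator keeps its own state.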

What’s the appetite for adding a generator argument to all those functions, the same way torch.Tensor.normal_ already has one? What other solutions are there to avoid relying on a global variable?
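
In the meantime, one possible workaround is to re-implement the needed init on top of a function that already accepts a generator, such as torch.rand. The sketch below is an assumption, not an existing PyTorch API: uniform_with_generator_ and reset_linear_ are hypothetical helpers. reset_linear_ draws weight and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), which is the range nn.Linear’s default reset_parameters produces, but the actual values will differ from the built-in init.

```python
import math
import torch
from torch import nn

def uniform_with_generator_(tensor, a, b, generator):
    # Hypothetical helper: fill `tensor` in place with U(a, b) samples drawn
    # only from `generator`. The generator must live on the tensor's device.
    with torch.no_grad():
        sample = torch.rand(tensor.shape, generator=generator,
                            dtype=tensor.dtype, device=tensor.device)
        tensor.copy_(sample * (b - a) + a)
    return tensor

def reset_linear_(layer, generator):
    # Hypothetical stand-in for nn.Linear.reset_parameters that uses a
    # caller-supplied generator instead of the global RNG.
    fan_in = layer.in_features
    bound = 1.0 / math.sqrt(fan_in) if fan_in > 0 else 0.0
    uniform_with_generator_(layer.weight, -bound, bound, generator)
    if layer.bias is not None:
        uniform_with_generator_(layer.bias, -bound, bound, generator)
    return layer

# Two networks initialized with equally seeded generators match exactly.
net_a = reset_linear_(nn.Linear(4, 3), torch.Generator().manual_seed(7))
net_b = reset_linear_(nn.Linear(4, 3), torch.Generator().manual_seed(7))
```

Each network gets its own generator, so concurrent initialization or global reseeding elsewhere does not affect the result. The drawback is that every module’s init has to be duplicated by hand.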


I think this would be better as an issue on GitHub, but prima facie it makes sense to me.
Note that the modules have a broader issue: their default init choices are not so good. (There is an issue on it and a stale PR or so.)

Best regards


Ok, happy to log this on GitHub. Thanks for the response.