Hi Sam!
I think I understand what you are asking …
PyTorch, roughly speaking, has a global random number generator that
is used for all random operations. So when you call multinomial(), it
“consumes” some random numbers so that future random operations
become different (but should be statistically equivalent). The “random
operations” in your training code could be things like random initialization
of model weights, Dropout operations, or a RandomSampler in a DataLoader.
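For example, a single Dropout forward pass will advance the global
generator and change subsequent random results. Here is a minimal
sketch of that effect (the seed and tensor shapes are arbitrary):
import torch
torch.random.manual_seed(2020)
torch.randn((5,))                                   # baseline result
torch.random.manual_seed(2020)
torch.nn.functional.dropout(torch.ones(5), p=0.5)   # consumes random numbers
torch.randn((5,))                                   # now differs from the baseline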
One approach would be to pre-calculate the multinomial values you will
need before setting your random seed and starting your training.
The following script illustrates this:
import torch
torch.__version__
torch.random.manual_seed (2020)
torch.randn ((5,)) # desired reproducible result
torch.random.manual_seed (2020)
torch.ones(10).multinomial(num_samples=2, replacement=False) # this uses up some random numbers
torch.randn ((5,)) # therefore changing this result
torch.random.manual_seed (1010)
cached_multinomial = torch.ones(10).multinomial(num_samples=2, replacement=False) # cache some values for future use
torch.random.manual_seed (2020)
cached_multinomial # use cached values for whatever purpose
torch.randn ((5,)) # doesn't mess up desired result
Here is its output:
>>> import torch
>>> torch.__version__
'1.6.0'
>>>
>>> torch.random.manual_seed (2020)
<torch._C.Generator object at 0x7fe662956910>
>>> torch.randn ((5,)) # desired reproducible result
tensor([ 1.2372, -0.9604, 1.5415, -0.4079, 0.8806])
>>>
>>> torch.random.manual_seed (2020)
<torch._C.Generator object at 0x7fe662956910>
>>> torch.ones(10).multinomial(num_samples=2, replacement=False) # this uses up some random numbers
tensor([7, 9])
>>> torch.randn ((5,)) # therefore changing this result
tensor([-0.3136, 0.6418, 1.1961, 0.9936, 1.0911])
>>>
>>> torch.random.manual_seed (1010)
<torch._C.Generator object at 0x7fe662956910>
>>> cached_multinomial = torch.ones(10).multinomial(num_samples=2, replacement=False) # cache some values for future use
>>>
>>> torch.random.manual_seed (2020)
<torch._C.Generator object at 0x7fe662956910>
>>> cached_multinomial # use cached values for whatever purpose
tensor([2, 0])
>>> torch.randn ((5,)) # doesn't mess up desired result
tensor([ 1.2372, -0.9604, 1.5415, -0.4079, 0.8806])
You could also instantiate a non-global Generator object that you then
pass into your multinomial() calls so that global random numbers
won’t be consumed.
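Here is a minimal sketch of that approach (the seed values are
arbitrary); multinomial() accepts a keyword-only generator argument:
import torch
torch.random.manual_seed(2020)
gen = torch.Generator()
gen.manual_seed(1010)
# this draw consumes random numbers only from gen, not from the
# global generator
torch.ones(10).multinomial(num_samples=2, replacement=False, generator=gen)
torch.randn((5,))   # same result as if multinomial() had never been called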
Best.
K. Frank