Strange behavior with torch.multinomial and seeding of the random generator

I understand that seeding a random generator in torch results in the same sequence of random numbers being generated. However, this behavior seems to be wrongly applied when using torch.multinomial. Here’s an example:

import torch

x = torch.tensor([7.0373e-02, 4.2156e-01, 1.6142e-02, 5.5568e-06, 8.9858e-06, 5.9234e-02,
         3.5102e-03, 1.7107e-05, 8.4964e-06, 8.9921e-02, 7.4014e-03, 2.0311e-04,
         2.7425e-02, 1.4030e-01, 8.1232e-03, 4.0852e-03, 2.7123e-02, 9.4809e-03,
         9.8476e-03, 1.3012e-03, 4.7310e-06, 2.5187e-06, 9.8657e-02, 2.1528e-05,
         3.5292e-05, 1.5198e-07, 5.2092e-03])
y = torch.tensor([[2.9672e-02, 3.5392e-01, 1.6525e-02, 3.6541e-06, 4.1269e-06, 5.7287e-02,
         8.0020e-04, 1.3440e-05, 1.1609e-04, 1.7075e-01, 3.6608e-03, 1.3809e-04,
         4.4828e-02, 1.1024e-01, 3.5802e-03, 1.9378e-02, 3.3693e-02, 5.2079e-03,
         3.5408e-03, 6.1396e-02, 3.7051e-06, 3.8998e-06, 8.2420e-02, 1.0640e-05,
         2.4842e-05, 9.1395e-08, 2.7842e-03]])
z = torch.tensor([[1.4468e-02, 3.7064e-01, 2.0479e-02, 2.7340e-06, 1.9460e-06, 5.4278e-02,
         3.3113e-04, 1.0805e-05, 1.6410e-03, 1.6552e-01, 2.2685e-03, 9.0160e-05,
         4.6111e-02, 9.7330e-02, 3.1741e-03, 3.2047e-02, 3.9629e-02, 2.9939e-03,
         2.0106e-03, 7.1221e-02, 2.4662e-06, 3.6253e-06, 7.2946e-02, 7.7427e-06,
         1.9314e-05, 6.4524e-08, 2.7684e-03]])
w = torch.tensor([[8.9696e-03, 4.3454e-01, 2.2658e-02, 1.9803e-06, 1.0033e-06, 4.5699e-02,
         1.6099e-04, 7.1766e-06, 6.0996e-03, 1.2350e-01, 1.2642e-03, 5.0448e-05,
         4.0526e-02, 8.7067e-02, 2.5062e-03, 4.2755e-02, 4.0775e-02, 1.9256e-03,
         1.2747e-03, 6.8042e-02, 2.0754e-06, 2.7696e-06, 6.9855e-02, 5.1217e-06,
         1.3413e-05, 5.3659e-08, 2.2944e-03]])

for probs in [x, y, z, w]: 
    g = torch.Generator().manual_seed(2147483647)
    print(torch.multinomial(probs, num_samples=1, generator=g).item())

In this code snippet, I’m iterating over 4 probability tensors with quite different values. Before generating 1 sample from each, I’m reinitializing the generator. Now, my expectation is that while the generator will cause the same sequence of random numbers to be generated, multinomial should at least honour the distribution it is given, per its definition. But that’s not happening. The behavior is strictly deterministic, to the point of ignoring whatever distribution is passed in.
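(To be concrete about what I mean by reseeding, here is a minimal illustration: re-creating the generator with the same seed makes any random op repeat itself,

g = torch.Generator().manual_seed(2147483647)
print(torch.rand(3, generator=g))  # same three values after every reseed

so I do expect determinism per seed, just not indifference to the input distribution.)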

Hi Abhishek!

The results you are getting are to be expected (and correct), even
if they look a little surprising.

First, your results do depend on the “distribution being sent,” but it
just so happens that the random-generator seed you are using gives
unchanged results for these particular tensors. Running with a different
seed does (sometimes) give different results, as the script below shows.

But drilling down further, why do you tend to get the same results
for different input probabilities (when reseeding the generator with
the same seed)? This is because of how multinomial() works.

Under the hood, it draws a vector of samples from .exponential_(),
divides the probs you pass in by those samples, and returns the
argmax() of the result. So the indices of smaller values in the
.exponential_() samples tend to be returned as the result of
multinomial(). But, likewise, the indices of larger values in your
probs tend to be returned.

Sometimes the .exponential_() values – which stay the same – win
out, even though the probs differ.
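Why does dividing by exponential noise give a legitimate categorical
sample? If the E_i are independent standard exponentials, then E_i / p_i
is exponential with rate p_i, and argmax (p_i / E_i) = argmin (E_i / p_i)
picks index i with probability p_i / sum_j p_j. As a quick empirical
sanity check of that claim (a sketch with made-up probabilities, not
anything from your post):

import torch

probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
g = torch.Generator().manual_seed(0)

# Fresh exponential noise for each of 100000 draws (no reseeding).
e = torch.empty(100000, 4).exponential_(generator=g)
samples = (probs / e).argmax(dim=1)

# Empirical frequencies come out close to [0.1, 0.2, 0.3, 0.4].
print(torch.bincount(samples, minlength=4) / 100000)

The determinism you saw comes purely from reseeding freezing the noise,
not from multinomial() ignoring its input.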

Here is a script that illustrates these points:

import torch
print(torch.__version__)

x = torch.tensor([7.0373e-02, 4.2156e-01, 1.6142e-02, 5.5568e-06, 8.9858e-06, 5.9234e-02,
         3.5102e-03, 1.7107e-05, 8.4964e-06, 8.9921e-02, 7.4014e-03, 2.0311e-04,
         2.7425e-02, 1.4030e-01, 8.1232e-03, 4.0852e-03, 2.7123e-02, 9.4809e-03,
         9.8476e-03, 1.3012e-03, 4.7310e-06, 2.5187e-06, 9.8657e-02, 2.1528e-05,
         3.5292e-05, 1.5198e-07, 5.2092e-03])
y = torch.tensor([2.9672e-02, 3.5392e-01, 1.6525e-02, 3.6541e-06, 4.1269e-06, 5.7287e-02,
         8.0020e-04, 1.3440e-05, 1.1609e-04, 1.7075e-01, 3.6608e-03, 1.3809e-04,
         4.4828e-02, 1.1024e-01, 3.5802e-03, 1.9378e-02, 3.3693e-02, 5.2079e-03,
         3.5408e-03, 6.1396e-02, 3.7051e-06, 3.8998e-06, 8.2420e-02, 1.0640e-05,
         2.4842e-05, 9.1395e-08, 2.7842e-03])
z = torch.tensor([1.4468e-02, 3.7064e-01, 2.0479e-02, 2.7340e-06, 1.9460e-06, 5.4278e-02,
         3.3113e-04, 1.0805e-05, 1.6410e-03, 1.6552e-01, 2.2685e-03, 9.0160e-05,
         4.6111e-02, 9.7330e-02, 3.1741e-03, 3.2047e-02, 3.9629e-02, 2.9939e-03,
         2.0106e-03, 7.1221e-02, 2.4662e-06, 3.6253e-06, 7.2946e-02, 7.7427e-06,
         1.9314e-05, 6.4524e-08, 2.7684e-03])
w = torch.tensor([8.9696e-03, 4.3454e-01, 2.2658e-02, 1.9803e-06, 1.0033e-06, 4.5699e-02,
         1.6099e-04, 7.1766e-06, 6.0996e-03, 1.2350e-01, 1.2642e-03, 5.0448e-05,
         4.0526e-02, 8.7067e-02, 2.5062e-03, 4.2755e-02, 4.0775e-02, 1.9256e-03,
         1.2747e-03, 6.8042e-02, 2.0754e-06, 2.7696e-06, 6.9855e-02, 5.1217e-06,
         1.3413e-05, 5.3659e-08, 2.2944e-03])

# Reseed before each draw. multinomial() then often returns the same
# index even for quite different probability vectors.
for iseed in range(5):
    print('iseed:', iseed)
    for probs in [x, y, z, w]:
        g = torch.Generator().manual_seed(2147483647 + iseed)
        s = torch.multinomial(probs, num_samples=1, generator=g).item()
        print(s)

# Reproduce multinomial() by hand: divide probs by exponential noise and
# take the argmax. Also print which exponential samples are smallest and
# which probs are largest, to show which one "wins" each time.
for iseed in range(5):
    print('iseed:', iseed)
    for probs in [x, y, z, w]:
        g = torch.Generator().manual_seed(2147483647 + iseed)
        e = torch.empty_like(probs).exponential_(generator=g)
        r = probs / e
        s = r.topk(1).indices.item()
        print(s, e.topk(8, largest=False).indices, probs.topk(8).indices)

And here is its output:

2.4.0
iseed: 0
13
13
13
13
iseed: 1
1
9
9
1
iseed: 2
0
12
12
12
iseed: 3
1
1
1
1
iseed: 4
13
13
19
1
iseed: 0
13 tensor([10, 22, 13, 21, 12, 15,  3,  4]) tensor([ 1, 13, 22,  9,  0,  5, 12, 16])
13 tensor([10, 22, 13, 21, 12, 15,  3,  4]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
13 tensor([10, 22, 13, 21, 12, 15,  3,  4]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
13 tensor([10, 22, 13, 21, 12, 15,  3,  4]) tensor([ 1,  9, 13, 22, 19,  5, 15, 16])
iseed: 1
1 tensor([ 3,  9, 14, 19,  6, 11, 12,  1]) tensor([ 1, 13, 22,  9,  0,  5, 12, 16])
9 tensor([ 3,  9, 14, 19,  6, 11, 12,  1]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
9 tensor([ 3,  9, 14, 19,  6, 11, 12,  1]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
1 tensor([ 3,  9, 14, 19,  6, 11, 12,  1]) tensor([ 1,  9, 13, 22, 19,  5, 15, 16])
iseed: 2
0 tensor([24, 12,  0, 17,  4,  3, 25,  8]) tensor([ 1, 13, 22,  9,  0,  5, 12, 16])
12 tensor([24, 12,  0, 17,  4,  3, 25,  8]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
12 tensor([24, 12,  0, 17,  4,  3, 25,  8]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
12 tensor([24, 12,  0, 17,  4,  3, 25,  8]) tensor([ 1,  9, 13, 22, 19,  5, 15, 16])
iseed: 3
1 tensor([ 8,  6,  3, 23,  1, 24,  5, 12]) tensor([ 1, 13, 22,  9,  0,  5, 12, 16])
1 tensor([ 8,  6,  3, 23,  1, 24,  5, 12]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
1 tensor([ 8,  6,  3, 23,  1, 24,  5, 12]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
1 tensor([ 8,  6,  3, 23,  1, 24,  5, 12]) tensor([ 1,  9, 13, 22, 19,  5, 15, 16])
iseed: 4
13 tensor([ 8, 20,  6, 10, 16, 19, 13, 26]) tensor([ 1, 13, 22,  9,  0,  5, 12, 16])
13 tensor([ 8, 20,  6, 10, 16, 19, 13, 26]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
19 tensor([ 8, 20,  6, 10, 16, 19, 13, 26]) tensor([ 1,  9, 13, 22, 19,  5, 12, 16])
1 tensor([ 8, 20,  6, 10, 16, 19, 13, 26]) tensor([ 1,  9, 13, 22, 19,  5, 15, 16])
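Finally, to address the original worry head-on: if you sample repeatedly
without reseeding, multinomial() does honour the distribution you pass
in. A quick check (a sketch added for illustration, reusing the x
defined above):

g = torch.Generator().manual_seed(2147483647)
samples = torch.multinomial(x, num_samples=100000, replacement=True, generator=g)
freq = torch.bincount(samples, minlength=x.numel()) / 100000
print((freq - x / x.sum()).abs().max())  # small: empirical frequencies track x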

Best.

K. Frank