Hi Forum!
@jeffc has found a bug in torch.multinomial() – see his thread
“Duplicate results due to lack of entropy using multinomial without replacement?”
(@ptrblck, if this looks like a bug to you, could you pass it on to the
powers that be? I haven’t filed this as a GitHub issue because I still
haven’t set up a new GitHub account.)
The concept is that you can use multinomial() without replacement
to generate random permutations and that these permutations should
be uniformly generated. (That is, any specific permutation of n
“letters” should be generated with probability 1 / n!.)
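(As a quick illustration of this usage – just a minimal sketch, not from @jeffc’s thread – drawing n samples without replacement from n equal weights returns each index exactly once, that is, a random permutation of 0 through n - 1:)
import torch

weights = torch.ones (5)                                   # five equal "buckets"
perm = torch.multinomial (weights, 5, replacement=False)   # each index drawn exactly once
print (perm)                                               # e.g., tensor([3, 0, 4, 1, 2])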
In @jeffc’s example, he generates permutations of 26 letters. Were
the permutations generated uniformly, it turns out that you would
have to generate about 10**13 permutations in order to have a 50%
chance of generating a duplicate, but duplicates show up in sets of
only tens of thousands of samples.
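(The 10**13 figure is just the standard birthday-problem estimate – here is a rough sketch of that arithmetic:)
import math

# with N = 26! equally likely permutations, the probability of at least one
# duplicate among k uniform samples is roughly 1 - exp (-k**2 / (2 * N)),
# so a 50% collision chance needs k of about sqrt (2 * N * ln 2)
N = math.factorial (26)                     # about 4.03e26 possible permutations
k_half = math.sqrt (2 * N * math.log (2))
print (k_half)                              # about 2.4e13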
(I don’t have any intuition about what kind of flaw in multinomial()’s
distribution would lead to this effect. I also don’t see any coarse flaw
in multinomial()’s distribution, but somehow generating permutations
is picking up on some sort of minor fishiness.)
Here is a script that reproduces the bug @jeffc found. It also uses
randperm()
to generate permutations and shows that these are
much more uniform.
import torch
print (torch.__version__)
high_bits_for_seed = 16000000000000000000 # to use "good quality" seed
_ = torch.manual_seed (high_bits_for_seed + 2024)
n = 26 # permutations of 26 "letters"
k = 1000000 # number of permutations to generate
print ('n:', n)
print ('k:', k)
# generate random permutations using multinomial() -- not quite uniform
prob = torch.ones (n)
dups_mult = 0
perm_counts_mult = {}
for _ in range (k):
    p = tuple (torch.multinomial (prob, n, replacement=False).tolist())
    if p in perm_counts_mult:
        dups_mult += 1
        perm_counts_mult[p] += 1
    else:
        perm_counts_mult[p] = 1
print ('duplicate multinomial perms: ', dups_mult)
print ('multiple multinomial perms: ', (torch.tensor (list (perm_counts_mult.values())) > 1).sum().item())
print ('max of perm_counts_mult: ', torch.tensor (list (perm_counts_mult.values())).max().item())
print ('len (perm_counts_mult): ', len (perm_counts_mult))
_ = torch.manual_seed (high_bits_for_seed + 2024) # just to be consistent
# generate random permutations using randperm() -- much more uniform
dups_rand = 0
perm_counts_rand = {}
for _ in range (k):
    p = tuple (torch.randperm (n).tolist())
    if p in perm_counts_rand:
        dups_rand += 1
        perm_counts_rand[p] += 1
    else:
        perm_counts_rand[p] = 1
print ('duplicate randperm perms: ', dups_rand)
print ('multiple randperm perms: ', (torch.tensor (list (perm_counts_rand.values())) > 1).sum().item())
print ('max of perm_counts_rand: ', torch.tensor (list (perm_counts_rand.values())).max().item())
print ('len (perm_counts_rand): ', len (perm_counts_rand))
And here is its output:
2.4.0
n: 26
k: 1000000
duplicate multinomial perms: 207
multiple multinomial perms: 207
max of perm_counts_mult: 2
len (perm_counts_mult): 999793
duplicate randperm perms: 0
multiple randperm perms: 0
max of perm_counts_rand: 1
len (perm_counts_rand): 1000000
I don’t really think that it’s the cause, but this multinomial() GitHub issue
might be related. Note that passing the input probabilities as float64
(which would presumably cause multinomial() to process its probability
“buckets” in double precision) does not seem to help the situation.
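(For concreteness, the float64 variant I have in mind is simply a sketch along these lines – passing double-precision weights into multinomial():)
import torch

n = 26
prob64 = torch.ones (n, dtype=torch.float64)   # double-precision probability "buckets"
p = torch.multinomial (prob64, n, replacement=False)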
Best.
K. Frank