Hi All!
A couple more data points on multinomial()'s consumption of random numbers:
First, multinomial() can be run in a single-precision or double-precision mode by passing in an input probability tensor of type float or double, respectively. (In either case, the multinomial samples returned are of type long.)
Using pytorch version 1.13.1, single- and double-precision multinomial calls, whether single-sample or multiple-sample, all consume the same number of rands (two per sample) and return the same sample values. The multiple-sample calls on version 2.0.0 behave the same way.
However, single-precision, single-sample calls on 2.0.0 consume only one rand per sample (while double-precision, single-sample calls consume two).
Also note that in my test on 2.0.0, the double-precision, single-sample calls return every other one of the samples that the single-precision, single-sample calls return. (This may be an artifact of the fact that the samples are discrete and that it is unlikely that a single-precision value generated from a single rand would fall close enough to the border of a probability bin that the nearly-equal double-precision value generated from two rands would fall instead in the neighboring bin.)
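The bin-border argument can be illustrated with a small pure-python sketch (hypothetical equal-probability bins, not pytorch's actual sampling code): two nearly-equal uniforms select the same bin unless they straddle a bin boundary, which happens with probability on the order of their difference.

```python
import bisect

# Cumulative bin boundaries for 100 equal-probability categories.
cum = [(i + 1) / 100 for i in range(100)]

def bin_of(u):
    # Index of the first boundary >= u, i.e. the category u falls in.
    return bisect.bisect_left(cum, u)

# A float32-like value and a nearly-equal float64-like value land in
# the same bin unless a boundary falls between them.
u32, u64 = 0.4290000, 0.4290001
assert bin_of(u32) == bin_of(u64) == 42
```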
Last, the 2.0.0 single-sample performance bug affects both the
single-precision and double-precision calls, in that both are about
ten times slower than the analogous calls on 1.13.1.
Here is a test script and its 1.13.1 and 2.0.0 outputs:
import torch
print (torch.__version__)
import time

seed = 2023
p = torch.ones (100)
n = 8
nBig = 1000000

g = torch.Generator().manual_seed (seed)
rnds = torch.rand (64, generator = g)   # to see how many rands are consumed
print ('rnds:', rnds)

for dtype in [torch.float, torch.double]:
    q = p.to (dtype)
    print ('multi-sample, dtype:', dtype)
    for i in range (1, n + 1):
        g = torch.Generator().manual_seed (seed)
        mlt = torch.multinomial (q, num_samples = i, replacement = True, generator = g)
        r = torch.rand (1, generator = g)
        print ('i:', i, ' r:', r, ' (r == rnds).nonzero():', (r == rnds).nonzero())
        if dtype == torch.float:
            mlt_multi_float = mlt.tolist()
        else:
            mlt_multi_double = mlt.tolist()
    g = torch.Generator().manual_seed (seed)
    t0 = time.time()
    mlt = torch.multinomial (q, num_samples = nBig, replacement = True, generator = g)
    print ('nBig:', nBig, ' time =', time.time() - t0)
    print ('single-sample, dtype:', dtype)
    for i in range (1, n + 1):
        g = torch.Generator().manual_seed (seed)
        mlt_single = []
        for j in range (i):
            mlt = torch.multinomial (q, num_samples = 1, replacement = True, generator = g)
            mlt_single.append (mlt.item())
        r = torch.rand (1, generator = g)
        print ('i:', i, ' r:', r, ' (r == rnds).nonzero():', (r == rnds).nonzero())
        if dtype == torch.float:
            mlt_single_float = mlt_single
        else:
            mlt_single_double = mlt_single
    g = torch.Generator().manual_seed (seed)
    t0 = time.time()
    for j in range (nBig):
        mlt = torch.multinomial (q, num_samples = 1, replacement = True, generator = g)
    print ('nBig:', nBig, ' time =', time.time() - t0)

print ('mlt_multi_float: ', mlt_multi_float)
print ('mlt_multi_double: ', mlt_multi_double)
print ('mlt_single_float: ', mlt_single_float)
print ('mlt_single_double: ', mlt_single_double)
1.13.1
rnds: tensor([0.4290, 0.7201, 0.9481, 0.4797, 0.5414, 0.9906, 0.4086, 0.2183, 0.1834,
0.2852, 0.7813, 0.1048, 0.6550, 0.8375, 0.1823, 0.5239, 0.2432, 0.9644,
0.5034, 0.0320, 0.8316, 0.3807, 0.3539, 0.2114, 0.9839, 0.6632, 0.7001,
0.0155, 0.3840, 0.7968, 0.4917, 0.4324, 0.5174, 0.6913, 0.1628, 0.5692,
0.0938, 0.3054, 0.1259, 0.7719, 0.6046, 0.9558, 0.0861, 0.7213, 0.0747,
0.0035, 0.4003, 0.5900, 0.1179, 0.7419, 0.6117, 0.9598, 0.4614, 0.9241,
0.1411, 0.9587, 0.9025, 0.5650, 0.2162, 0.9765, 0.6009, 0.5165, 0.3156,
0.3671])
multi-sample, dtype: torch.float32
i: 1 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 0.04687786102294922
single-sample, dtype: torch.float32
i: 1 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 4.480212926864624
multi-sample, dtype: torch.float64
i: 1 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 0.050467491149902344
single-sample, dtype: torch.float64
i: 1 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 4.6825549602508545
mlt_multi_float: [43, 58, 33, 26, 46, 25, 23, 45]
mlt_multi_double: [43, 58, 33, 26, 46, 25, 23, 45]
mlt_single_float: [43, 58, 33, 26, 46, 25, 23, 45]
mlt_single_double: [43, 58, 33, 26, 46, 25, 23, 45]
2.0.0
rnds: tensor([0.4290, 0.7201, 0.9481, 0.4797, 0.5414, 0.9906, 0.4086, 0.2183, 0.1834,
0.2852, 0.7813, 0.1048, 0.6550, 0.8375, 0.1823, 0.5239, 0.2432, 0.9644,
0.5034, 0.0320, 0.8316, 0.3807, 0.3539, 0.2114, 0.9839, 0.6632, 0.7001,
0.0155, 0.3840, 0.7968, 0.4917, 0.4324, 0.5174, 0.6913, 0.1628, 0.5692,
0.0938, 0.3054, 0.1259, 0.7719, 0.6046, 0.9558, 0.0861, 0.7213, 0.0747,
0.0035, 0.4003, 0.5900, 0.1179, 0.7419, 0.6117, 0.9598, 0.4614, 0.9241,
0.1411, 0.9587, 0.9025, 0.5650, 0.2162, 0.9765, 0.6009, 0.5165, 0.3156,
0.3671])
multi-sample, dtype: torch.float32
i: 1 r: tensor([0.7201]) (r == rnds).nonzero(): tensor([[1]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 0.05341219902038574
single-sample, dtype: torch.float32
i: 1 r: tensor([0.7201]) (r == rnds).nonzero(): tensor([[1]])
i: 2 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 3 r: tensor([0.4797]) (r == rnds).nonzero(): tensor([[3]])
i: 4 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 5 r: tensor([0.9906]) (r == rnds).nonzero(): tensor([[5]])
i: 6 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 7 r: tensor([0.2183]) (r == rnds).nonzero(): tensor([[7]])
i: 8 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
nBig: 1000000 time = 45.47208642959595
multi-sample, dtype: torch.float64
i: 1 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 0.046894073486328125
single-sample, dtype: torch.float64
i: 1 r: tensor([0.9481]) (r == rnds).nonzero(): tensor([[2]])
i: 2 r: tensor([0.5414]) (r == rnds).nonzero(): tensor([[4]])
i: 3 r: tensor([0.4086]) (r == rnds).nonzero(): tensor([[6]])
i: 4 r: tensor([0.1834]) (r == rnds).nonzero(): tensor([[8]])
i: 5 r: tensor([0.7813]) (r == rnds).nonzero(): tensor([[10]])
i: 6 r: tensor([0.6550]) (r == rnds).nonzero(): tensor([[12]])
i: 7 r: tensor([0.1823]) (r == rnds).nonzero(): tensor([[14]])
i: 8 r: tensor([0.2432]) (r == rnds).nonzero(): tensor([[16]])
nBig: 1000000 time = 42.14270281791687
mlt_multi_float: [43, 58, 33, 26, 46, 25, 23, 45]
mlt_multi_double: [43, 58, 33, 26, 46, 25, 23, 45]
mlt_single_float: [49, 23, 38, 85, 5, 3, 71, 25]
mlt_single_double: [23, 85, 3, 25, 28, 9, 39, 60]
Best.
K. Frank