While following Andrej Karpathy’s video on makemore, I ran into a puzzling result from the call to multinomial.
TL;DR: Using a fixed seed, torch.multinomial returns a different value at position [0] when num_samples is 1 than it does for any other value of num_samples. I would expect the first sampled index to be identical in every case. See the code example with output below. Is this expected?
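The comparison can be reduced to a minimal form. The toy distribution below is a stand-in I made up for this sketch, not the actual makemore counts, so the sampled indices will differ from my output further down:

```python
import torch

# Toy distribution standing in for the makemore bigram row
# (an assumption for illustration, not the real counts).
p = torch.tensor([0.1, 0.2, 0.3, 0.4])

def first_sample(num_samples, seed=2147483647):
    # Re-seed on every call so each run starts from the identical RNG state.
    g = torch.Generator().manual_seed(seed)
    out = torch.multinomial(p, num_samples=num_samples,
                            replacement=True, generator=g)
    return out[0].item()

# Re-seeding makes each call deterministic, so repeating the same call
# must reproduce the same first index:
print(first_sample(2) == first_sample(2))  # True
# The question is whether first_sample(1) starts with the same index
# as first_sample(2); with the real counts in my example, it does not.
print(first_sample(1), first_sample(2))
```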
Output:
Python version: 3.8.10 (default, Nov 14 2022, 12:59:47)
[GCC 9.4.0]
PyTorch version: 2.0.0+cu117
tensor([0.0000, 0.1377, 0.1784, 0.2266, 0.2793, 0.3271, 0.3401, 0.3610, 0.3883,
0.4068, 0.4824, 0.5749, 0.6240, 0.7032, 0.7390, 0.7513, 0.7673, 0.7702,
0.8214, 0.8855, 0.9264, 0.9288, 0.9405, 0.9501, 0.9543, 0.9710, 1.0000])
With num_samples = 1
Rand tensor([0.7081])
P tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,
0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,
0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])
tensor([10])
Cumulative 0.4823775589466095
Value 0.07560952752828598
With num_samples = 2
Rand tensor([0.7081, 0.3542])
P tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,
0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,
0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])
tensor([13, 19])
Cumulative 0.7031810879707336
Value 0.07923079282045364
With num_samples = 3
Rand tensor([0.7081, 0.3542, 0.1054])
P tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,
0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,
0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])
tensor([13, 19, 14])
Cumulative 0.7031810879707336
Value 0.07923079282045364
Sample code:
import sys
import torch

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")

# This is a simplified version of the code example from makemore.
# I kept it as close as possible to the scenario in which I noticed this anomaly.
N = torch.zeros((1, 27), dtype=torch.int32)
values = [0, 4410, 1306, 1542, 1690, 1531, 417, 669, 874, 591, 2422, 2963,
          1572, 2538, 1146, 394, 515, 92, 1639, 2055, 1308, 78, 376, 307,
          134, 535, 929]
for i, x in enumerate(values):
    N[0, i] = x

p = N[0].float()
p = p / p.sum()

# Shorthand cumulative sum to show where the sampling breakpoints "should" be.
cp = torch.cumsum(p, dim=0)
print(cp)

# Runs the identical check for each num_samples, so nothing else varies
# between the scenarios being executed.
def sanity_check(num_samples):
    print(f"With num_samples = {num_samples}")
    g = torch.Generator().manual_seed(2147483647)
    print(f"Rand {torch.rand(num_samples, generator=g)}")
    print(f"P {p}")
    g = torch.Generator().manual_seed(2147483647)
    output = torch.multinomial(p, num_samples=num_samples, replacement=True, generator=g)
    ix = output[0]
    print(output)
    print(f"Cumulative {cp[ix]}")
    print(f"Value {p[ix]}")
    print()

sanity_check(1)
sanity_check(2)
sanity_check(3)

# Other tests are omitted. All produced the same result for [0]: 13. Only sanity_check(1) produces 10.
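For reference, the "breakpoint" logic that cp is meant to visualize is plain inverse-CDF sampling, sketched below in pure Python with a made-up toy distribution. This only illustrates which bin a given uniform draw should land in; it assumes nothing about how multinomial consumes the generator internally, which may well differ from torch.rand:

```python
from bisect import bisect_left

def inverse_cdf_sample(cdf, u):
    # Index of the first breakpoint >= u, i.e. the bin the uniform draw falls in.
    return bisect_left(cdf, u)

# Toy breakpoints for p = [0.1, 0.2, 0.3, 0.4] (not the makemore counts).
cdf = [0.1, 0.3, 0.6, 1.0]
print(inverse_cdf_sample(cdf, 0.05))    # -> 0
print(inverse_cdf_sample(cdf, 0.7081))  # -> 3
```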