`torch.multinomial`: sampling once vs in batches gives different results

Hi everyone,

I assumed that the following two functions would produce exactly the same results, since sampling row by row in a loop should be equivalent to sampling once from the concatenated weights, as in draw_once:

import torch


def draw_once(seed: int):
    # Draw one sample per row with a single batched multinomial call on 9 rows.
    torch.manual_seed(seed)
    weights = torch.arange(15).reshape(3, 5).float() * 100
    return torch.multinomial(torch.vstack([weights, weights, weights]), 1).squeeze(1)


def draw_many(seed: int):
    # Draw one sample per row, three rows at a time, calling multinomial in a loop.
    torch.manual_seed(seed)
    weights = torch.arange(15).reshape(3, 5).float() * 100
    return torch.vstack([torch.multinomial(weights, 1) for _ in range(3)]).squeeze(1)


# draw_once(0)
# tensor([2, 0, 1, 4, 1, 4, 2, 0, 1])
# draw_many(0)
# tensor([2, 0, 1, 3, 3, 0, 4, 1, 3])


once = torch.vstack([draw_once(i) for i in range(10_000)])
many = torch.vstack([draw_many(i) for i in range(10_000)])

once.float().std(dim=0) - many.float().std(dim=0)
once.float().mean(dim=0) - many.float().mean(dim=0)

However, the sampling results are different: the means differ by more than 5 sigma everywhere except for the first drawn sample, which is exactly the same every time.
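For concreteness, one rough way to check the 5-sigma claim (a sketch that reuses the once and many tensors from above and treats each position as an independent two-sample z-test):

n = once.shape[0]  # number of seeds, 10_000 here
se = ((once.float().var(dim=0) + many.float().var(dim=0)) / n).sqrt()
z = (once.float().mean(dim=0) - many.float().mean(dim=0)) / se
# positions with |z| > 5 indicate means that differ well beyond sampling noise
print(z)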

What can be the reason for this? I’ve also tried running with a newly instantiated Generator, but I get the same behavior: the first drawn sample is the same, the others aren’t.

I’m unsure if I understand your question correctly, but seeding the pseudorandom number generator (PRNG) only guarantees that the same sequence of random numbers is returned for the same calls into the PRNG. Your two approaches use different calls into multinomial and are thus not guaranteed to return the same sequence of random numbers.
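To illustrate what the guarantee does cover, here is a minimal sketch (on the CPU generator): re-seeding and then repeating the exact same call reproduces the exact same output.

import torch

weights = torch.arange(15).reshape(3, 5).float() * 100

torch.manual_seed(0)
a = torch.multinomial(weights, 1)

torch.manual_seed(0)
b = torch.multinomial(weights, 1)

# Identical calls after the same seed consume the PRNG identically,
# so the outputs match exactly.
print(torch.equal(a, b))  # True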

Do you mean that between the calls to multinomial in the loop there can be internal calls to the PRNG that I’m not making explicitly, but which still affect the sequence?

That would make sense to me, although then the fact that the first n samples do match doesn’t really make sense.

Is there a way to look at the PRNG’s state and see which numbers are being generated that don’t end up in the multinomial result?

No, I mean the calls themselves are not the same and thus not guaranteed to return the same values. E.g. the calls would advance the offset of the PRNG in different ways and the next call could thus return a different output.
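Regarding inspecting the PRNG’s state: a rough sketch below compares the CPU generator state after the two call patterns via torch.get_rng_state(), which returns the state as a ByteTensor (on the GPU the analogue is torch.cuda.get_rng_state()). Whether the final states end up equal depends on how each pattern consumes the generator.

import torch

weights = torch.arange(15).reshape(3, 5).float() * 100

torch.manual_seed(0)
torch.multinomial(torch.vstack([weights, weights, weights]), 1)  # one batched call
state_once = torch.get_rng_state()

torch.manual_seed(0)
for _ in range(3):
    torch.multinomial(weights, 1)  # three separate calls
state_many = torch.get_rng_state()

# If the two patterns advance the generator differently, the states differ.
print(torch.equal(state_once, state_many))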