torch.multinomial with generator is not deterministic under certain conditions

Hello,

I think there is a bug in torch.multinomial that sporadically appears when the number of input probabilities is large (e.g. greater than 10k, and it gets worse as the number grows) and replacement=True. It also seems worse when the dynamic range of the probabilities is very high, i.e. many of them are very close to zero. I would love it if this were user error, but I'm pretty sure it is not.

My guess is that it's some sort of numerical / floating-point error?

What happens is that, sometimes, the nth row of the returned indices is off by one. For example, in one trial the 57th row out of 1000 might be index 57291, but in another trial it might end up as 57292 or 57290 (never something far away like 1272).
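
For intuition, here is a minimal sketch of inverse-CDF sampling, which I assume is roughly what torch.multinomial does internally when replacement=True (an assumption on my part, not the actual CUDA kernel). If the float32 cumulative sum of probs shifts by even one ULP between runs, a draw that lands exactly on a bucket boundary resolves to the adjacent index, which would produce exactly the off-by-one pattern above:

import torch


def multinomial_inverse_cdf(
    probs: torch.Tensor,
    n_draws: int,
    generator: torch.Generator,
) -> torch.Tensor:
    # Hypothetical illustration only, not PyTorch's real implementation.
    # Build the cumulative distribution over candidates. A float32 cumsum
    # rounds differently depending on reduction order, so cdf[i] can move
    # by an ULP between runs.
    cdf = probs.cumsum(dim=0)
    # Draw uniforms in [0, cdf[-1]) ...
    u = torch.rand(n_draws, device=probs.device, generator=generator) * cdf[-1]
    # ... and binary-search for the first bucket whose CDF value exceeds
    # each draw. A draw sitting on the boundary between buckets i and i+1
    # flips to the neighbouring index if cdf[i] moves at all.
    return torch.searchsorted(cdf, u, right=True).clamp_(max=probs.numel() - 1)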

Here is a reproducible example. NB: we are on torch==2.5.1+cu118; if this is fixed in a newer version, apologies.

import torch
import torch.nn.functional as F


def run_test(
        device: torch.device,
        n_candidates: int,
        n_draws: int,
        n_trials: int,
        replacement: bool,
        band_width: float,
        band_offset: float,
        generator: torch.Generator,
):
    all_results = []
    for _ in range(n_trials):
        generator.manual_seed(42)

        # Stand-in for some process which randomly generates logits.
        # e.g. randomly sampling designs and running them through a classifier.
        logits = (
            torch.rand(n_candidates, device=device, generator=generator) * band_width
            - band_offset
        )
        # convert logits to probabilities
        probs = F.softmax(logits, dim=-1)

        # draw samples
        results = torch.multinomial(
            probs,
            n_draws,
            replacement=replacement,
            generator=generator,
        )
        all_results.append(results.unsqueeze(1))

    # combine the samples, shape is (n_draws, n_trials)
    results = torch.cat(all_results, dim=1)
    errors = []
    for nth_draw in range(n_draws):
        # get the nth draw across all trials
        nth_draw_across_trials = results[nth_draw]
        unique, counts = torch.unique(nth_draw_across_trials, return_counts=True)

        # We expect each draw to be identical across all trials.
        # which would mean a single unique index was selected across all trials.
        if len(unique) > 1:
            errors.append((nth_draw, unique))

    message = (
        f"There are {len(errors)} row(s) in the output that are not identical across all {n_trials} trials."
        " The following samples are misaligned:"
    )
    if len(errors) > 0:
        print(message)

        for nth_draw, unique in errors:
            print(
                f"Draw {nth_draw} was not identical across all trials.  "
                f"{len(unique)} different indices were sampled across "
                f"{n_trials} trials: {unique.cpu().numpy().tolist()}"
            )
    else:
        print("All samples are identical across all trials.")


device = torch.device("cuda")
generator = torch.Generator(device)
band_width = 5
band_offset = -5
n_candidates = 1_000_000
n_draws = 1000
n_trials = 1000
replacement = True

run_test(
    device=device,
    n_candidates=n_candidates,
    n_draws=n_draws,
    n_trials=n_trials,
    replacement=replacement,
    band_width=band_width,
    band_offset=band_offset,
    generator=generator,
)

Example output (will change across runs):

There are 11 row(s) in the output that are not identical across all 1000 trials. The following samples are misaligned:
Draw 65 was not identical across all trials.  2 different indices were sampled across 1000 trials: [197680, 197681]
Draw 111 was not identical across all trials.  2 different indices were sampled across 1000 trials: [128305, 128306]
Draw 115 was not identical across all trials.  2 different indices were sampled across 1000 trials: [181477, 181478]
Draw 306 was not identical across all trials.  2 different indices were sampled across 1000 trials: [66297, 66298]
Draw 465 was not identical across all trials.  2 different indices were sampled across 1000 trials: [131393, 131394]
Draw 550 was not identical across all trials.  2 different indices were sampled across 1000 trials: [90228, 90229]
Draw 585 was not identical across all trials.  2 different indices were sampled across 1000 trials: [129377, 129378]
Draw 802 was not identical across all trials.  2 different indices were sampled across 1000 trials: [154692, 154693]
Draw 846 was not identical across all trials.  2 different indices were sampled across 1000 trials: [175321, 175322]
Draw 856 was not identical across all trials.  2 different indices were sampled across 1000 trials: [135070, 135071]
Draw 954 was not identical across all trials.  2 different indices were sampled across 1000 trials: [184852, 184853]
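
As a stopgap, the same inverse-CDF idea with a float64 cumulative sum may be worth trying. To be clear, this is an assumption-laden workaround, not a fix: it consumes the random stream differently than torch.multinomial (so the actual indices will differ), and I have not verified its determinism against the PyTorch internals or across different GPUs:

import torch


def multinomial_fp64(
    probs: torch.Tensor,
    n_draws: int,
    generator: torch.Generator,
) -> torch.Tensor:
    # Intended as a drop-in replacement for
    # torch.multinomial(probs, n_draws, replacement=True, generator=generator),
    # swapping the builtin kernel for a float64 inverse-CDF lookup.
    cdf = probs.double().cumsum(dim=0)
    u = torch.rand(
        n_draws, dtype=torch.float64, device=probs.device, generator=generator
    ) * cdf[-1]
    return torch.searchsorted(cdf, u, right=True).clamp_(max=probs.numel() - 1)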

This seems to be related to this post, so make sure you are enabling deterministic algorithms and are using a recent PyTorch version.
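
For example:

import torch

# Opt in to deterministic implementations where PyTorch provides them.
# Ops without a deterministic variant will raise an error instead of
# silently running a nondeterministic kernel.
torch.use_deterministic_algorithms(True)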