# Alternative to torch.multinomial

Hi,

I needed to perform a weighted random selection in PyTorch and, at the time, I didn’t know about torch.multinomial, so I came up with my own implementation. I’ve since learned about torch.multinomial, but my implementation seems to be significantly faster. I haven’t done detailed performance testing, but ad-hoc testing shows it to be consistently faster on my machine (CUDA). I’d rather use my own implementation since it’s faster, but I’d like some feedback to make sure it doesn’t contain an error.

```python
import torch

weights = ... # these are calculated somewhere else. Shape = (batchSize, n)

# Cumulative sum along the last dimension.
weights = weights.cumsum(-1)
# One uniform random number in [0, 1) per row, expanded across all n columns.
probabilities = torch.rand((batchSize, 1), device=weights.device).expand(-1, weights.shape[-1])
delta = probabilities - weights

# I suspect there is a faster implementation for this op
# Assumption: Max value of cumsum of probabilities *should* be 1.0.
delta = torch.where(delta < 0, 2.0, delta)

choices = torch.argmin(delta, -1)
```

I will also give an example in plaintext:

Say we have probabilities = [0.3, 0.1, 0.4, 0.1, 0.1]
cumsum(probabilities) = [0.3, 0.4, 0.8, 0.9, 1.0]

Say we choose a random number r = 0.4
delta = r - cumsum(probabilities) = [0.1, 0, -0.4, -0.5, -0.6]

Then we replace the negative entries with 2.0:
delta = [0.1, 0, 2.0, 2.0, 2.0]

Finally, we take the argmin of delta, which is index 1 (where the value is 0).
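
For reference, the arithmetic in this example can be reproduced with a few tensor operations. This is only a minimal sketch: `r` is fixed at 0.4 by hand instead of being drawn with `torch.rand`, and the names `cdf`, `r` and `choice` are illustrative rather than taken from the code above.

```python
import torch

probabilities = torch.tensor([0.3, 0.1, 0.4, 0.1, 0.1])
cdf = probabilities.cumsum(-1)  # approximately [0.3, 0.4, 0.8, 0.9, 1.0]

r = torch.tensor(0.4)  # fixed by hand here; normally drawn with torch.rand
delta = r - cdf

# Replace negative entries with 2.0 so that argmin skips them.
delta = torch.where(delta < 0, torch.full_like(delta, 2.0), delta)

choice = torch.argmin(delta, -1)
print(delta)
print(choice.item())
```

The sketch only mirrors the steps as written; it does not check whether the resulting index follows the intended distribution.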

Hi Parham!

Your algorithm will never generate a `choices` value equal to the largest
desired value of `n - 1`. (In your example case, `n - 1 = 4`.)

In terms of your example, `torch.rand()` never generates a result greater
than or equal to `1.0`. If `rand()` generates a value between `0.9` and `1.0`,
`delta[4]` will initially be negative (say, `0.95 - 1.0`) and will then be replaced
by `2.0`. So the `argmin()` that gives you the `choices` value will be `3` (and
never `4`).
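
To make this concrete, here is a minimal sketch that wraps the posted logic in a function and counts how often each index comes out over many draws. The helper names `sample_as_posted` and `sample_searchsorted` are made up for illustration, and the `torch.searchsorted` variant is just one possible inverse-CDF lookup that I am assuming for comparison, not something from the original post.

```python
import torch

def sample_as_posted(probs):
    """The selection logic from the original post, wrapped in a function."""
    cdf = probs.cumsum(-1)
    r = torch.rand((probs.shape[0], 1))
    delta = r - cdf
    delta = torch.where(delta < 0, torch.full_like(delta, 2.0), delta)
    return torch.argmin(delta, -1)

def sample_searchsorted(probs):
    """Hypothetical alternative: inverse-CDF lookup with torch.searchsorted."""
    cdf = probs.cumsum(-1)
    r = torch.rand((probs.shape[0], 1))
    # right=True returns, per row, the first index i with r < cdf[i].
    idx = torch.searchsorted(cdf, r, right=True).squeeze(-1)
    # Guard against cdf[-1] landing slightly below 1.0 due to rounding.
    return idx.clamp(max=probs.shape[-1] - 1)

torch.manual_seed(0)
probs = torch.tensor([0.3, 0.1, 0.4, 0.1, 0.1]).repeat(100_000, 1)
n = probs.shape[-1]

counts_posted = torch.bincount(sample_as_posted(probs), minlength=n)
counts_sorted = torch.bincount(sample_searchsorted(probs), minlength=n)

print("as posted:   ", (counts_posted / counts_posted.sum()).tolist())
print("searchsorted:", (counts_sorted / counts_sorted.sum()).tolist())
```

With the posted logic the count for index `4` should stay at zero, while the `searchsorted` frequencies should land close to the input probabilities. I have not benchmarked either version, so this says nothing about speed relative to `torch.multinomial`.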

If you wish to follow up, please post a fully-self-contained, runnable script
with your proposed algorithm and a test case or two, together with the output
you get when you run it.

Best.

K. Frank