Hi,
I needed to perform a weighted random selection in PyTorch, and at the time I didn't know about torch.multinomial, so I came up with my own implementation. I've since learned about torch.multinomial, but my implementation seems to be significantly faster. I haven't done detailed performance testing, but some ad-hoc testing shows it to be consistently faster on my machine (CUDA). I'd rather use my own implementation since it's faster, but I'd like some feedback to make sure there isn't an error in it.
weights = ... # these are calculated somewhere else. Shape = (batchSize, n); each row sums to 1.0
weights = weights.cumsum(-1)
# One uniform draw per row, created on the same device as weights to avoid a device mismatch
probabilities = torch.rand((batchSize, 1), device=weights.device).expand(-1, weights.shape[-1])
# Subtract r from the cumulative sums (cumsum - r, not r - cumsum), so the first
# non-negative entry marks the first bucket whose cumulative sum reaches r
delta = weights - probabilities
# I suspect there is a faster implementation for this op
# Assumption: max value of the cumsum *should* be 1.0, so 2.0 is larger than any valid delta
delta = torch.where(delta < 0, 2.0, delta)
choices = torch.argmin(delta, -1)
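For the op flagged in the comment above, torch.searchsorted looks like a natural candidate: it performs the inverse-CDF lookup directly on the cumulative sums and replaces the where/argmin pair in one call. A minimal sketch, assuming weights already holds the row-wise cumulative sums from above (I haven't benchmarked it):

r = torch.rand((batchSize, 1), device=weights.device)
# First index i with weights[..., i] >= r; the clamp guards against r landing past
# the final cumulative sum due to floating-point rounding
choices = torch.searchsorted(weights, r).squeeze(-1).clamp_(max=weights.shape[-1] - 1)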
I will also give an example in plaintext:
Say we have probabilities = [0.3, 0.1, 0.4, 0.1, 0.1]
cumsum(probabilities) = [0.3, 0.4, 0.8, 0.9, 1.0]
Say we choose a random number r = 0.4
delta = cumsum(probabilities) - r = [-0.1, 0, 0.4, 0.5, 0.6]
Then we account for the negatives:
delta = delta < 0 ? 2.0 : delta = [2.0, 0, 0.4, 0.5, 0.6]
Finally, we take the argmin of delta, which is index 1, where the value is 0: the first bucket whose cumulative sum reaches r.
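Since the question is whether the method is correct, here is a self-contained sanity check one could run (weighted_sample is a hypothetical wrapper around the snippet above; it assumes each row of weights sums to 1.0). It compares empirical frequencies against the target weights and against torch.multinomial, then does a rough synchronized CUDA timing:

import time
import torch

torch.manual_seed(0)

# Hypothetical wrapper around the snippet above (cumsum - r subtraction)
def weighted_sample(weights):
    # weights: (batch, n); each row is assumed to sum to 1.0
    cdf = weights.cumsum(-1)
    r = torch.rand((weights.shape[0], 1), device=weights.device)
    delta = cdf - r  # first non-negative entry marks the chosen bucket
    delta = torch.where(delta < 0, 2.0, delta)
    return torch.argmin(delta, -1)

# Distribution check: empirical frequencies should approach the weights
w = torch.tensor([0.3, 0.1, 0.4, 0.1, 0.1]).expand(100_000, -1)
empirical = torch.bincount(weighted_sample(w), minlength=5) / 100_000
reference = torch.multinomial(w[0], num_samples=100_000, replacement=True)
print(empirical)                                         # ~[0.3, 0.1, 0.4, 0.1, 0.1]
print(torch.bincount(reference, minlength=5) / 100_000)  # torch.multinomial, same target

# Rough timing; torch.cuda.synchronize() matters, otherwise only launch time is measured
if torch.cuda.is_available():
    wc = w.contiguous().cuda()
    weighted_sample(wc)                         # warm-up
    torch.multinomial(wc, 1, replacement=True)  # warm-up
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        weighted_sample(wc)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    for _ in range(100):
        torch.multinomial(wc, num_samples=1, replacement=True)
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    print(f"custom: {t1 - t0:.4f}s  multinomial: {t2 - t1:.4f}s")

Both empirical distributions should land within sampling noise of the weights; if the custom sampler's frequencies come out shifted relative to the weights, that points to an indexing error in the cumsum comparison.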