Sample from the multinomial probability distribution using torch.multinomial

I have a tensor of multinomial probabilities. Each vector inside the tensor (1x3 in my example below) represents the multinomial probability distribution of one random variable.
I want to draw samples from this tensor of multinomial probability distributions.

For example, this is the tensor (MxNxIxJx3) of multinomial probabilities:

[[w_00ij0, w_00ij1, w_00ij2], … [w_0Nij0, w_0Nij1, w_0Nij2],
… …
[w_M0ij0, w_M0ij1, w_M0ij2], … [w_MNij0, w_MNij1, w_MNij2]]

This is the tensor (MxNxIxJ) I want to sample from the above distribution:

[[x_00ij], … [x_0Nij],
… …
[x_M0ij], … [x_MNij]]

where (for example):

P(x_mnij=0) = w_mnij0
P(x_mnij=1) = w_mnij1
P(x_mnij=2) = w_mnij2

I couldn’t find a way to do this quickly (without looping over the entire tensor).

Hi Yoda!

As I understand your question, you are asking for the categorical
distribution (which is a special case of the multinomial distribution).

torch.distributions.Categorical takes batches of probs (as does
torch.distributions.Multinomial), so you can just pass your probs
in as-is.

Here is an example (where I’ve dropped your J dimension for
convenience):

>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> M = 2
>>> N = 5
>>> I = 7
>>> probs = torch.rand (M, N, I, 3)
>>> probs
tensor([[[[0.1304, 0.5134, 0.7426],
          [0.7159, 0.5705, 0.1653],
          [0.0443, 0.9628, 0.2943],
          [0.0992, 0.8096, 0.0169],
          [0.8222, 0.1242, 0.7489],
          [0.3608, 0.5131, 0.2959],
          [0.7834, 0.7405, 0.8050]],

         [[0.3036, 0.9942, 0.5025],
          [0.3734, 0.0413, 0.8387],
          [0.0604, 0.1773, 0.3301],
          [0.6857, 0.6960, 0.8303],
          [0.5216, 0.7438, 0.8290],
          [0.0219, 0.0813, 0.0172],
          [0.1464, 0.7492, 0.9450]],

         [[0.6737, 0.1135, 0.7421],
          [0.7810, 0.9446, 0.0451],
          [0.4282, 0.2427, 0.9363],
          [0.7784, 0.5605, 0.5312],
          [0.7132, 0.1075, 0.4496],
          [0.1255, 0.6784, 0.6550],
          [0.0650, 0.4786, 0.7321]],

         [[0.7520, 0.9717, 0.0492],
          [0.3442, 0.2955, 0.0772],
          [0.0228, 0.1315, 0.7325],
          [0.0736, 0.2712, 0.1656],
          [0.1152, 0.9514, 0.4467],
          [0.3331, 0.9558, 0.2115],
          [0.2883, 0.3373, 0.1650]],

         [[0.2607, 0.3795, 0.5373],
          [0.2057, 0.4805, 0.9925],
          [0.5080, 0.4415, 0.0078],
          [0.1357, 0.7547, 0.7994],
          [0.0285, 0.6075, 0.6055],
          [0.3037, 0.8905, 0.5738],
          [0.1186, 0.9209, 0.7311]]],


        [[[0.2627, 0.5386, 0.8638],
          [0.8822, 0.7373, 0.6365],
          [0.4724, 0.5814, 0.1358],
          [0.5071, 0.0114, 0.7892],
          [0.1523, 0.1646, 0.9006],
          [0.3421, 0.1729, 0.0969],
          [0.7246, 0.0545, 0.8811]],

         [[0.6582, 0.9686, 0.3049],
          [0.0658, 0.4522, 0.8983],
          [0.2619, 0.0371, 0.8865],
          [0.8534, 0.9061, 0.8211],
          [0.5210, 0.0960, 0.1545],
          [0.2690, 0.1763, 0.1120],
          [0.8787, 0.8737, 0.1782]],

         [[0.4688, 0.9811, 0.7441],
          [0.3973, 0.3600, 0.7546],
          [0.3047, 0.1206, 0.3601],
          [0.8179, 0.2981, 0.1501],
          [0.3135, 0.3079, 0.1304],
          [0.4902, 0.7529, 0.9662],
          [0.6364, 0.8964, 0.6964]],

         [[0.3709, 0.1203, 0.3698],
          [0.7380, 0.6052, 0.3663],
          [0.4788, 0.6704, 0.4302],
          [0.8689, 0.1824, 0.5595],
          [0.7591, 0.8993, 0.5691],
          [0.5340, 0.9363, 0.1453],
          [0.2639, 0.2078, 0.9786]],

         [[0.5227, 0.0941, 0.5573],
          [0.0276, 0.7732, 0.8916],
          [0.7088, 0.3952, 0.2207],
          [0.9375, 0.4134, 0.7030],
          [0.4679, 0.9109, 0.4135],
          [0.8190, 0.1285, 0.0799],
          [0.5837, 0.9205, 0.1744]]]])
>>> torch.distributions.Categorical (probs).sample()
tensor([[[2, 0, 1, 1, 2, 0, 2],
         [1, 0, 1, 2, 1, 1, 1],
         [0, 0, 0, 0, 2, 2, 1],
         [0, 1, 2, 2, 0, 1, 1],
         [0, 2, 0, 2, 1, 1, 1]],

        [[2, 2, 0, 2, 2, 0, 0],
         [1, 1, 0, 2, 2, 1, 1],
         [1, 2, 2, 0, 0, 2, 0],
         [0, 2, 0, 0, 0, 1, 2],
         [1, 1, 0, 1, 0, 0, 0]]])
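
As a side note, since your title mentions torch.multinomial: here is a rough
sketch of the same sampling done with torch.multinomial directly (reusing the
shapes and the name probs from the example above). torch.multinomial only
accepts 1-D or 2-D input, so you flatten the batch dimensions into rows, draw
one index per row, and reshape back:

import torch

torch.manual_seed(2021)
M, N, I = 2, 5, 7
probs = torch.rand(M, N, I, 3)                   # unnormalized categorical weights

flat = probs.reshape(-1, 3)                      # (M*N*I, 3) rows of weights
draws = torch.multinomial(flat, num_samples=1)   # one index drawn per row
samples = draws.reshape(M, N, I)                 # back to batch shape, values in {0, 1, 2}

Both versions draw each sample independently according to the weights in its
row of three.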

Note that a given row of three probabilities doesn’t have to sum to one;
Categorical normalizes them for you.
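
For instance (a small sketch), you can check this by comparing the
distribution’s probs attribute against a manual normalization of the weights
along the last dimension:

import torch

probs = torch.rand(2, 5, 7, 3)                      # unnormalized weights
dist = torch.distributions.Categorical(probs=probs)
manual = probs / probs.sum(dim=-1, keepdim=True)    # normalize each row of 3
print(torch.allclose(dist.probs, manual))           # prints: True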

Best.

K. Frank
