Hi Yoda!
masteryoda:
P(x_mnij=0) = w_mnij0
P(x_mnij=1) = w_mnij1
P(x_mnij=2) = w_mnij2
As I understand your question, you are asking for the categorical
distribution (which is a special case of the multinomial distribution).
torch.distributions.Categorical takes batches of probs (as does
torch.distributions.Multinomial), so you can just pass your probs
in as-is. The last dimension of probs indexes the three outcomes, and
all leading dimensions are treated as batch dimensions.
Here is an example (where I’ve dropped your J dimension for
convenience):
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> M = 2
>>> N = 5
>>> I = 7
>>> probs = torch.rand (M, N, I, 3)
>>> probs
tensor([[[[0.1304, 0.5134, 0.7426],
[0.7159, 0.5705, 0.1653],
[0.0443, 0.9628, 0.2943],
[0.0992, 0.8096, 0.0169],
[0.8222, 0.1242, 0.7489],
[0.3608, 0.5131, 0.2959],
[0.7834, 0.7405, 0.8050]],
[[0.3036, 0.9942, 0.5025],
[0.3734, 0.0413, 0.8387],
[0.0604, 0.1773, 0.3301],
[0.6857, 0.6960, 0.8303],
[0.5216, 0.7438, 0.8290],
[0.0219, 0.0813, 0.0172],
[0.1464, 0.7492, 0.9450]],
[[0.6737, 0.1135, 0.7421],
[0.7810, 0.9446, 0.0451],
[0.4282, 0.2427, 0.9363],
[0.7784, 0.5605, 0.5312],
[0.7132, 0.1075, 0.4496],
[0.1255, 0.6784, 0.6550],
[0.0650, 0.4786, 0.7321]],
[[0.7520, 0.9717, 0.0492],
[0.3442, 0.2955, 0.0772],
[0.0228, 0.1315, 0.7325],
[0.0736, 0.2712, 0.1656],
[0.1152, 0.9514, 0.4467],
[0.3331, 0.9558, 0.2115],
[0.2883, 0.3373, 0.1650]],
[[0.2607, 0.3795, 0.5373],
[0.2057, 0.4805, 0.9925],
[0.5080, 0.4415, 0.0078],
[0.1357, 0.7547, 0.7994],
[0.0285, 0.6075, 0.6055],
[0.3037, 0.8905, 0.5738],
[0.1186, 0.9209, 0.7311]]],
[[[0.2627, 0.5386, 0.8638],
[0.8822, 0.7373, 0.6365],
[0.4724, 0.5814, 0.1358],
[0.5071, 0.0114, 0.7892],
[0.1523, 0.1646, 0.9006],
[0.3421, 0.1729, 0.0969],
[0.7246, 0.0545, 0.8811]],
[[0.6582, 0.9686, 0.3049],
[0.0658, 0.4522, 0.8983],
[0.2619, 0.0371, 0.8865],
[0.8534, 0.9061, 0.8211],
[0.5210, 0.0960, 0.1545],
[0.2690, 0.1763, 0.1120],
[0.8787, 0.8737, 0.1782]],
[[0.4688, 0.9811, 0.7441],
[0.3973, 0.3600, 0.7546],
[0.3047, 0.1206, 0.3601],
[0.8179, 0.2981, 0.1501],
[0.3135, 0.3079, 0.1304],
[0.4902, 0.7529, 0.9662],
[0.6364, 0.8964, 0.6964]],
[[0.3709, 0.1203, 0.3698],
[0.7380, 0.6052, 0.3663],
[0.4788, 0.6704, 0.4302],
[0.8689, 0.1824, 0.5595],
[0.7591, 0.8993, 0.5691],
[0.5340, 0.9363, 0.1453],
[0.2639, 0.2078, 0.9786]],
[[0.5227, 0.0941, 0.5573],
[0.0276, 0.7732, 0.8916],
[0.7088, 0.3952, 0.2207],
[0.9375, 0.4134, 0.7030],
[0.4679, 0.9109, 0.4135],
[0.8190, 0.1285, 0.0799],
[0.5837, 0.9205, 0.1744]]]])
>>> torch.distributions.Categorical (probs).sample()
tensor([[[2, 0, 1, 1, 2, 0, 2],
[1, 0, 1, 2, 1, 1, 1],
[0, 0, 0, 0, 2, 2, 1],
[0, 1, 2, 2, 0, 1, 1],
[0, 2, 0, 2, 1, 1, 1]],
[[2, 2, 0, 2, 2, 0, 0],
[1, 1, 0, 2, 2, 1, 1],
[1, 2, 2, 0, 0, 2, 0],
[0, 2, 0, 0, 0, 1, 2],
[1, 1, 0, 1, 0, 0, 0]]])
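If you want to keep the J dimension, the same call should work unchanged, since every dimension except the last is treated as a batch dimension. Here is a minimal sketch of that; the sizes for M, N, I, and J are made-up values for illustration, not anything from your problem:

```python
import torch

torch.manual_seed(2021)

# Hypothetical sizes for the M, N, I, J dimensions; substitute your own.
M, N, I, J = 2, 5, 7, 4

# Weights w_mnijk for the outcomes k = 0, 1, 2, stored along the last dimension.
probs = torch.rand(M, N, I, J, 3)

dist = torch.distributions.Categorical(probs=probs)

x = dist.sample()        # one draw of x_mnij for every (m, n, i, j)
xs = dist.sample((10,))  # ten independent draws for every (m, n, i, j)

print(x.shape)   # torch.Size([2, 5, 7, 4])
print(xs.shape)  # torch.Size([10, 2, 5, 7, 4])
```

Each entry of x is an integer in {0, 1, 2}, and any extra sample dimensions are prepended to the batch shape.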
Note that a given row of three probabilities doesn’t have to sum to one;
Categorical normalizes them along the last dimension for you.
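To check the normalization yourself, you can compare the probs attribute of the distribution with the weights divided by their sum over the last dimension. This is just a small sketch with made-up weights:

```python
import torch

# Made-up, unnormalized weights; each row sums to 8, not to 1.
weights = torch.tensor([[1.0, 3.0, 4.0],
                        [4.0, 2.0, 2.0]])

dist = torch.distributions.Categorical(probs=weights)

# Normalize by hand along the last dimension for comparison.
normalized = weights / weights.sum(dim=-1, keepdim=True)

print(dist.probs)
print(torch.allclose(dist.probs, normalized))
```

The probabilities stored on the distribution should match the hand-normalized weights, so each row of dist.probs sums to one.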
Best.
K. Frank