Sample a tensor of probability distributions

I have a tensor of probability distributions with shape (N, C, H, W), where dimension 1 (size C) holds a normalized probability distribution over C possibilities. Is there a way to efficiently sample from all the distributions in the tensor in parallel? I only need to sample each distribution once, so the result could be either a one-hot tensor with the same shape or a tensor of indices with shape (N, 1, H, W).

Hi Learned!

Yes, you can use torch.distributions.Categorical, provided you
adjust your distributions tensor so that its last dimension is the distribution
dimension.

Here is an example script:

import torch
print(torch.__version__)

_ = torch.random.manual_seed(2021)

N = 2
C = 3
H = 5
W = 7

# dimension 1 (size C) holds the normalized distributions
probs = torch.randn(N, C, H, W).softmax(1)
print('probs = ...')
print(probs)
print('probs.sum (1) = ...')
print(probs.sum(1))

# move the distribution dimension last, sample, then restore the (N, 1, H, W) layout
sample = torch.distributions.Categorical(
    probs=probs.transpose(1, -1)
).sample().transpose(-1, -2).unsqueeze(1)

print('sample.shape =', sample.shape)
print('sample = ...')
print(sample)

And here is its output:

1.7.1
probs = ...
tensor([[[[0.1498, 0.3152, 0.2946, 0.6541, 0.3106, 0.4475, 0.3918],
          [0.1289, 0.2494, 0.5813, 0.1555, 0.2688, 0.1649, 0.6196],
          [0.1607, 0.7599, 0.2339, 0.3343, 0.6459, 0.7187, 0.5310],
          [0.2014, 0.0938, 0.2341, 0.8172, 0.3617, 0.0953, 0.6246],
          [0.8510, 0.1427, 0.0091, 0.1163, 0.2765, 0.6657, 0.2254]],

         [[0.7174, 0.1177, 0.1747, 0.1609, 0.3015, 0.0444, 0.2602],
          [0.1545, 0.5129, 0.2338, 0.4810, 0.2133, 0.6208, 0.1486],
          [0.3673, 0.0383, 0.2041, 0.4826, 0.0756, 0.1309, 0.2405],
          [0.4219, 0.5621, 0.0419, 0.0825, 0.4854, 0.4959, 0.0707],
          [0.1043, 0.7390, 0.1671, 0.5642, 0.5226, 0.3112, 0.3942]],

         [[0.1329, 0.5671, 0.5306, 0.1850, 0.3879, 0.5082, 0.3480],
          [0.7167, 0.2377, 0.1849, 0.3635, 0.5179, 0.2143, 0.2318],
          [0.4720, 0.2018, 0.5620, 0.1831, 0.2785, 0.1503, 0.2285],
          [0.3767, 0.3441, 0.7239, 0.1003, 0.1529, 0.4088, 0.3047],
          [0.0447, 0.1183, 0.8238, 0.3194, 0.2009, 0.0231, 0.3803]]],


        [[[0.6440, 0.1537, 0.0505, 0.0511, 0.0996, 0.1050, 0.4653],
          [0.1242, 0.2676, 0.6757, 0.1266, 0.6718, 0.2993, 0.0868],
          [0.7833, 0.4048, 0.6902, 0.2550, 0.2607, 0.1759, 0.1606],
          [0.1922, 0.3755, 0.6223, 0.2364, 0.3413, 0.9021, 0.5981],
          [0.2017, 0.5419, 0.5284, 0.3065, 0.4233, 0.1412, 0.2183]],

         [[0.3134, 0.2802, 0.6204, 0.7494, 0.3884, 0.0774, 0.4969],
          [0.1248, 0.6669, 0.1558, 0.2342, 0.0883, 0.0252, 0.8172],
          [0.1465, 0.3188, 0.0329, 0.6245, 0.6833, 0.2322, 0.1315],
          [0.4668, 0.2589, 0.2702, 0.0258, 0.3919, 0.0188, 0.1836],
          [0.3882, 0.3065, 0.2767, 0.0930, 0.1194, 0.4706, 0.0861]],

         [[0.0425, 0.5662, 0.3291, 0.1995, 0.5120, 0.8176, 0.0378],
          [0.7510, 0.0655, 0.1685, 0.6392, 0.2399, 0.6755, 0.0960],
          [0.0702, 0.2764, 0.2768, 0.1205, 0.0560, 0.5918, 0.7079],
          [0.3410, 0.3655, 0.1075, 0.7378, 0.2668, 0.0791, 0.2184],
          [0.4101, 0.1517, 0.1949, 0.6006, 0.4573, 0.3881, 0.6956]]]])
probs.sum (1) = ...
tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])
sample.shape = torch.Size([2, 1, 5, 7])
sample = ...
tensor([[[[1, 0, 2, 0, 0, 0, 2],
          [1, 0, 2, 2, 2, 1, 2],
          [2, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 1, 2, 0],
          [0, 1, 2, 2, 1, 0, 1]]],


        [[[0, 2, 1, 1, 1, 1, 1],
          [0, 1, 0, 2, 0, 0, 1],
          [2, 2, 0, 0, 1, 2, 2],
          [2, 2, 0, 2, 1, 0, 0],
          [1, 1, 2, 2, 0, 0, 2]]]])
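As an aside, since you mentioned that a one-hot tensor of the same shape would also work: you can convert the sampled indices with torch.nn.functional.one_hot. A minimal, self-contained sketch (it recomputes probs and sample the same way as the script above):

```python
import torch

_ = torch.random.manual_seed(2021)
N, C, H, W = 2, 3, 5, 7
probs = torch.randn(N, C, H, W).softmax(1)

# sample indices of shape (N, 1, H, W), as in the script above
sample = torch.distributions.Categorical(
    probs=probs.transpose(1, -1)
).sample().transpose(-1, -2).unsqueeze(1)

# one_hot() appends the class dimension last, so permute it back to dim 1
one_hot = torch.nn.functional.one_hot(sample.squeeze(1), num_classes=C)
one_hot = one_hot.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

print(one_hot.shape)  # torch.Size([2, 3, 5, 7])
```

Each (C,)-slice of one_hot along dimension 1 then contains a single 1 at the sampled index and 0 elsewhere.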

Best.

K. Frank
