I want to sample a tensor of probability distributions with shape (N, C, H, W), where dimension 1 (size C) contains normalized probability distributions with ‘C’ possibilities. Is there a way to efficiently sample all the distributions in the tensor in parallel? I just need to sample each distribution once, so the result could either be a one-hot tensor with the same shape or a tensor of indices with shape (N, 1, H, W).
Hi Learned!
Yes, you can use torch.distributions.Categorical, provided you adjust your distributions tensor so that its last dimension is the distribution dimension.
Here is an example script:
import torch
print(torch.__version__)

_ = torch.random.manual_seed(2021)

N = 2
C = 3
H = 5
W = 7

# build an (N, C, H, W) tensor of probabilities, normalized along dim 1
probs = torch.randn(N, C, H, W).softmax(1)
print('probs = ...')
print(probs)
print('probs.sum (1) = ...')
print(probs.sum(1))

# Categorical treats the last dimension as the distribution dimension,
# so swap dims 1 and -1 before sampling; then restore the (H, W) layout
# and add back a singleton class dimension to get shape (N, 1, H, W)
sample = torch.distributions.Categorical(
    probs=probs.transpose(1, -1)
).sample().transpose(-1, -2).unsqueeze(1)
print('sample.shape =', sample.shape)
print('sample = ...')
print(sample)
And here is its output:
1.7.1
probs = ...
tensor([[[[0.1498, 0.3152, 0.2946, 0.6541, 0.3106, 0.4475, 0.3918],
          [0.1289, 0.2494, 0.5813, 0.1555, 0.2688, 0.1649, 0.6196],
          [0.1607, 0.7599, 0.2339, 0.3343, 0.6459, 0.7187, 0.5310],
          [0.2014, 0.0938, 0.2341, 0.8172, 0.3617, 0.0953, 0.6246],
          [0.8510, 0.1427, 0.0091, 0.1163, 0.2765, 0.6657, 0.2254]],

         [[0.7174, 0.1177, 0.1747, 0.1609, 0.3015, 0.0444, 0.2602],
          [0.1545, 0.5129, 0.2338, 0.4810, 0.2133, 0.6208, 0.1486],
          [0.3673, 0.0383, 0.2041, 0.4826, 0.0756, 0.1309, 0.2405],
          [0.4219, 0.5621, 0.0419, 0.0825, 0.4854, 0.4959, 0.0707],
          [0.1043, 0.7390, 0.1671, 0.5642, 0.5226, 0.3112, 0.3942]],

         [[0.1329, 0.5671, 0.5306, 0.1850, 0.3879, 0.5082, 0.3480],
          [0.7167, 0.2377, 0.1849, 0.3635, 0.5179, 0.2143, 0.2318],
          [0.4720, 0.2018, 0.5620, 0.1831, 0.2785, 0.1503, 0.2285],
          [0.3767, 0.3441, 0.7239, 0.1003, 0.1529, 0.4088, 0.3047],
          [0.0447, 0.1183, 0.8238, 0.3194, 0.2009, 0.0231, 0.3803]]],


        [[[0.6440, 0.1537, 0.0505, 0.0511, 0.0996, 0.1050, 0.4653],
          [0.1242, 0.2676, 0.6757, 0.1266, 0.6718, 0.2993, 0.0868],
          [0.7833, 0.4048, 0.6902, 0.2550, 0.2607, 0.1759, 0.1606],
          [0.1922, 0.3755, 0.6223, 0.2364, 0.3413, 0.9021, 0.5981],
          [0.2017, 0.5419, 0.5284, 0.3065, 0.4233, 0.1412, 0.2183]],

         [[0.3134, 0.2802, 0.6204, 0.7494, 0.3884, 0.0774, 0.4969],
          [0.1248, 0.6669, 0.1558, 0.2342, 0.0883, 0.0252, 0.8172],
          [0.1465, 0.3188, 0.0329, 0.6245, 0.6833, 0.2322, 0.1315],
          [0.4668, 0.2589, 0.2702, 0.0258, 0.3919, 0.0188, 0.1836],
          [0.3882, 0.3065, 0.2767, 0.0930, 0.1194, 0.4706, 0.0861]],

         [[0.0425, 0.5662, 0.3291, 0.1995, 0.5120, 0.8176, 0.0378],
          [0.7510, 0.0655, 0.1685, 0.6392, 0.2399, 0.6755, 0.0960],
          [0.0702, 0.2764, 0.2768, 0.1205, 0.0560, 0.5918, 0.7079],
          [0.3410, 0.3655, 0.1075, 0.7378, 0.2668, 0.0791, 0.2184],
          [0.4101, 0.1517, 0.1949, 0.6006, 0.4573, 0.3881, 0.6956]]]])
probs.sum (1) = ...
tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])
sample.shape = torch.Size([2, 1, 5, 7])
sample = ...
tensor([[[[1, 0, 2, 0, 0, 0, 2],
          [1, 0, 2, 2, 2, 1, 2],
          [2, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 1, 2, 0],
          [0, 1, 2, 2, 1, 0, 1]]],


        [[[0, 2, 1, 1, 1, 1, 1],
          [0, 1, 0, 2, 0, 0, 1],
          [2, 2, 0, 0, 1, 2, 2],
          [2, 2, 0, 2, 1, 0, 0],
          [1, 1, 2, 2, 0, 0, 2]]]])
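If you want the one-hot form mentioned in the question instead of indices, you can convert the sampled indices with torch.nn.functional.one_hot. A minimal, self-contained sketch (it regenerates probs and sample the same way as the script above; the variable names are just for illustration):

```python
import torch
import torch.nn.functional as F

_ = torch.random.manual_seed(2021)

N, C, H, W = 2, 3, 5, 7
probs = torch.randn(N, C, H, W).softmax(1)

# sample indices of shape (N, 1, H, W), as in the script above
sample = torch.distributions.Categorical(
    probs=probs.transpose(1, -1)
).sample().transpose(-1, -2).unsqueeze(1)

# one_hot appends the class dimension last, giving (N, H, W, C),
# so permute it back to dim 1 to recover (N, C, H, W)
one_hot = F.one_hot(sample.squeeze(1), num_classes=C).permute(0, 3, 1, 2)
print(one_hot.shape)   # torch.Size([2, 3, 5, 7])
print(one_hot.sum(1))  # all ones: exactly one class is set per pixel
```

The resulting tensor contains a single 1 along dim 1 at the sampled class index for every (n, h, w) position.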
Best.
K. Frank