Backpropagation through a random sample is tricky. One of the most common methods is to rewrite your distribution as a differentiable function of parameterless random variables. This is the reparameterization trick used for variational autoencoders and similar models (https://arxiv.org/pdf/1312.6114v10.pdf). For example, a sample from the normal distribution N(a, b^2) (mean a, standard deviation b) is just b*N(0,1) + a, where a and b are tunable parameters and N(0,1) represents a sample from a normal distribution with mean 0 and variance 1.
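A minimal sketch of the trick done by hand in PyTorch, assuming a and b are scalar tensors with requires_grad=True, b is the standard deviation, and the squared-mean loss is just an arbitrary example:

```python
import torch

torch.manual_seed(0)

# Tunable parameters: a is the mean, b the standard deviation.
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)

# Parameterless noise: eps ~ N(0, 1) does not depend on a or b.
eps = torch.randn(1000)

# Reparameterized sample: x ~ N(a, b^2), written as a differentiable
# function of a and b.
x = a + b * eps

# Any loss of the samples can now be backpropagated to a and b.
loss = (x ** 2).mean()
loss.backward()

print(a.grad, b.grad)
```

Because all of the randomness lives in eps, the gradients are simply d(loss)/da = mean(2x) and d(loss)/db = mean(2x * eps).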
In order to backpropagate through a sample, the distribution must implement one of these reparameterization methods in the form of distribution.rsample(), as shown in the PyTorch 1.10 documentation page "Probability distributions - torch.distributions".
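For a distribution that does implement it, such as Normal, here is a quick sketch of the difference between rsample() and sample() (the parameter values are arbitrary):

```python
import torch
from torch.distributions import Normal

loc = torch.tensor(0.0, requires_grad=True)
scale = torch.tensor(1.0, requires_grad=True)
dist = Normal(loc, scale)

print(dist.has_rsample)  # True for Normal

# rsample() uses the reparameterization trick, so the samples stay
# connected to loc and scale in the autograd graph.
r = dist.rsample((3,))
print(r.requires_grad)  # True

# sample() is drawn under torch.no_grad(), so it is cut off from the graph.
s = dist.sample((3,))
print(s.requires_grad)  # False

# d(sum of 3 samples)/d(loc) = 3
r.sum().backward()
print(loc.grad)
```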
Unfortunately, the binomial distribution does not seem to have rsample() implemented: in PyTorch 1.10, has_rsample is False and calling rsample() raises NotImplementedError.
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.29.0 -- An enhanced Interactive Python. Type '?' for help.
probs: tensor([0.9526, 0.9526, 0.9526, 0.9526, 0.9526], grad_fn=<SigmoidBackward0>)
mask: tensor([1., 1., 1., 1., 1.])
weight.grad: None
fc.weight.grad: tensor([[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395]])
In [1]: print(binomial.has_rsample)
False
In [2]: rs = binomial.rsample()
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-2-6374ad2cc45e> in <module>
----> 1 rs = binomial.rsample()
~/.conda/envs/pytorch/lib/python3.9/site-packages/torch/distributions/distribution.py in rsample(self, sample_shape)
    152         are batched.
    153         """
--> 154         raise NotImplementedError
    155
    156     def sample_n(self, n):
NotImplementedError:
In [3]: torch.__version__
Out[3]: '1.10.0'
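A minimal sketch that reproduces the same effect as the weight.grad: None line in the output above. The names weight and scale are hypothetical stand-ins, not the original script: weight only feeds the binomial probabilities, while scale is used directly in the loss.

```python
import torch
from torch.distributions import Binomial

torch.manual_seed(0)

# Hypothetical parameters (not from the original script):
# `weight` only feeds the sampling probabilities,
# `scale` is used directly in the loss.
weight = torch.randn(5, requires_grad=True)
scale = torch.ones(5, requires_grad=True)

binomial = Binomial(total_count=1, probs=torch.sigmoid(weight))
print(binomial.has_rsample)  # False

# sample() is drawn under torch.no_grad(), so the mask is a constant
# as far as autograd is concerned.
mask = binomial.sample()
print(mask.requires_grad)  # False

loss = (mask * scale).sum()
loss.backward()

print(weight.grad)  # None: nothing flows back through the sample
print(scale.grad)   # equal to mask
```

Only parameters that reach the loss by some path other than the sample (here scale, in the original output fc.weight) receive a gradient.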
Hope this helps!