# Backward of a Binomial sample

Hi, I’d like to know if it is possible to use the sample of a Binomial (`torch.distributions.binomial.Binomial(probs=probs).sample()`) and get the gradients of the `probs` tensor via a backward operation. Namely:

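``````
import torch
import torch.nn as nn

input = torch.randn(1, 20)

weight = torch.nn.Parameter(torch.ones(5) * 3., requires_grad=True)

fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)
binomial = torch.distributions.binomial.Binomial(probs=probs)
mask = binomial.sample()

# Assumption: the post does not show how `out` is built; masking the layer output is the likely intent
out = fc(input) * mask

out.sum().backward()

``````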

As of now, `weight.grad` is None, while `fc.weight.grad` is not. Also, I’ve noticed that if I rescale `mask`, the gradients are no longer None, e.g.

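``````
import torch
import torch.nn as nn

input = torch.randn(1, 20)

weight = torch.nn.Parameter(torch.ones(5) * 3., requires_grad=True)

fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)
binomial = torch.distributions.binomial.Binomial(probs=probs)
mask = (binomial.sample() * (1.0 / probs))  # Notice the rescaling

# Assumption: the post does not show how `out` is built; masking the layer output is the likely intent
out = fc(input) * mask

out.sum().backward()

``````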

Any idea why this happens? Thank you.

Backpropagation through a random sample is tricky. One of the most common methods is to rewrite your distribution as a differentiable function of parameter-free random variables. This is the reparameterization trick used for variational autoencoders and similar models (https://arxiv.org/pdf/1312.6114v10.pdf). For example, a sample from the normal distribution N(a, b) with mean a and standard deviation b is just b*N(0,1) + a, where a and b are tunable parameters and N(0,1) represents a sample from a normal distribution with mean 0 and variance 1.
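To make the normal example concrete, here is a minimal sketch (the values of `a` and `b` and the sample count are arbitrary choices, not from the post). It writes the sample by hand as `b * eps + a` and then compares `rsample()` with `sample()` on `torch.distributions.Normal`:

```python
import torch

torch.manual_seed(0)

a = torch.tensor(1.5, requires_grad=True)  # mean (arbitrary value)
b = torch.tensor(0.7, requires_grad=True)  # standard deviation (arbitrary value)

# Manual reparameterization: the noise eps does not depend on a or b.
eps = torch.randn(1000)
x_manual = b * eps + a
x_manual.mean().backward()
grad_a_manual = a.grad.clone()  # d(mean of x)/da
grad_b_manual = b.grad.clone()  # d(mean of x)/db, which equals eps.mean()

# The same idea through the distributions API.
normal = torch.distributions.Normal(loc=a, scale=b)
x_rsample = normal.rsample(torch.Size([1000]))  # stays attached to a and b
x_sample = normal.sample(torch.Size([1000]))    # detached from a and b

print(grad_a_manual, grad_b_manual)
print(x_rsample.requires_grad, x_sample.requires_grad)
```

The gradient with respect to `a` comes out as 1 and the gradient with respect to `b` as the mean of the noise, which is what differentiating `b * eps + a` by hand gives.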

In order to backpropagate through a sample, the distribution must implement one of these reparameterization methods in the form of `distribution.rsample()` (see "Probability distributions - torch.distributions" in the PyTorch 1.10 documentation).
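By contrast, plain `sample()` returns a tensor that is detached from the graph, which is consistent with both observations in the question. In the rescaled version, the only path from `mask` back to `weight` is the explicit `1.0 / probs` factor, so the gradient is not None, but it is not a gradient through the sampling step. A small sketch to check this, assuming `total_count=1` (the default) and summing the mask directly rather than going through a linear layer:

```python
import torch

torch.manual_seed(0)

weight = torch.nn.Parameter(torch.ones(5) * 3.0)
probs = torch.sigmoid(weight)
binomial = torch.distributions.Binomial(total_count=1, probs=probs)

sample = binomial.sample()
print(sample.requires_grad)  # the raw sample carries no graph

# Rescaling multiplies a constant (the sample) by a differentiable factor.
mask = sample * (1.0 / probs)
mask.sum().backward()

# Treating the sample s as a constant:
# d/dw [s / sigmoid(w)] = -s * (1 - p) / p, with p = sigmoid(w)
p = probs.detach()
expected = -sample * (1 - p) / p
print(torch.allclose(weight.grad, expected))
```

If the two agree, the gradient seen after rescaling is just the derivative of `1 / probs` scaled by the sampled values.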

Unfortunately, I don’t think the binomial distribution has `rsample()` implemented:

``````
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
IPython 7.29.0 -- An enhanced Interactive Python. Type '?' for help.
probs:  tensor([0.9526, 0.9526, 0.9526, 0.9526, 0.9526], grad_fn=<SigmoidBackward0>)
mask: tensor([1., 1., 1., 1., 1.])
fc.weight.grad: tensor([[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
-0.8983,  0.0185, -0.7355,  0.9395],
[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
-0.8983,  0.0185, -0.7355,  0.9395],
[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
-0.8983,  0.0185, -0.7355,  0.9395],
[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
-0.8983,  0.0185, -0.7355,  0.9395],
[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
-0.8983,  0.0185, -0.7355,  0.9395]])

In [1]: print(binomial.has_rsample)
False

In [2]: rs = binomial.rsample()
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
----> 1 rs = binomial.rsample()

~/.conda/envs/pytorch/lib/python3.9/site-packages/torch/distributions/distribution.py in rsample(self, sample_shape)
152         are batched.
153         """
--> 154         raise NotImplementedError
155
156     def sample_n(self, n):

NotImplementedError:

In [3]: torch.__version__
Out[3]: '1.10.0'

In [4]:
``````
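If a differentiable mask is what you are after, one possible workaround is to swap the binomial for a continuous relaxation that does implement `rsample()`, such as `torch.distributions.RelaxedBernoulli`. Note that this is a different distribution: the mask values lie in (0, 1) rather than being exactly 0 or 1. In the sketch below the temperature of 0.5 is an arbitrary choice, and applying the mask as `fc(input) * mask` is an assumption about the intended use:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input = torch.randn(1, 20)
weight = nn.Parameter(torch.ones(5) * 3.0)
fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)

# Continuous relaxation of a Bernoulli; the temperature is an arbitrary choice.
relaxed = torch.distributions.RelaxedBernoulli(
    temperature=torch.tensor(0.5), probs=probs
)
print(relaxed.has_rsample)

mask = relaxed.rsample()  # soft mask, attached to probs
out = fc(input) * mask    # assumed way of applying the mask
out.sum().backward()

print(weight.grad)
```

Lower temperatures push the soft mask closer to 0 or 1, at the cost of noisier gradients.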

Hope this helps!
