Backward of a Binomial sample

Hi, I’d like to know whether it is possible to take a sample from a Binomial (torch.distributions.binomial.Binomial(probs=probs).sample()) and get gradients with respect to the probs tensor via a backward pass. Namely:

```python
import torch
import torch.nn as nn

input = torch.randn(1, 20)

weight = torch.nn.Parameter(torch.ones(5) * 3., requires_grad=True)

fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)
binomial = torch.distributions.binomial.Binomial(probs=probs)
mask = binomial.sample()

# Scale each of the 5 output columns of fc(input) by the matching mask entry
out = torch.einsum("ij,j->ij", fc(input), mask)

out.sum().backward()

print(weight.grad)
print(fc.weight.grad)
```

Currently, weight.grad is None, while fc.weight.grad is not. I’ve also noticed that if I rescale mask, the gradients of weight are no longer None, e.g.:

```python
# Same setup as the first snippet (reuses `input`, `torch` and `nn` from it)
weight = torch.nn.Parameter(torch.ones(5) * 3., requires_grad=True)

fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)
binomial = torch.distributions.binomial.Binomial(probs=probs)
mask = (binomial.sample() * (1.0 / probs))  # Notice the rescaling

out = torch.einsum("ij,j->ij", fc(input), mask)

out.sum().backward()

print(weight.grad)
print(fc.weight.grad)
```

Any idea why this happens? Thank you.

Backpropagation through a random sample is tricky. One of the most common methods is to rewrite the distribution as a differentiable function of its parameters and a parameter-free random variable. This is the reparameterization trick used for variational autoencoders and similar models (https://arxiv.org/pdf/1312.6114v10.pdf). For example, a sample from the normal distribution N(a, b^2), with mean a and standard deviation b, is just b*N(0,1) + a, where a and b are tunable parameters and N(0,1) represents a sample from a normal distribution with mean 0 and variance 1.
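A minimal sketch of that normal example written out by hand in PyTorch. The names `mu`, `sigma` and `eps` are illustrative (not from the thread), with `sigma` taken to be the standard deviation. Because the noise `eps` does not depend on the parameters, autograd can differentiate the transformed sample with respect to both:

```python
import torch

torch.manual_seed(0)

# Tunable parameters: mean and standard deviation (illustrative values)
mu = torch.tensor(1.0, requires_grad=True)
sigma = torch.tensor(2.0, requires_grad=True)

# Parameter-free noise: eps ~ N(0, 1), independent of mu and sigma
eps = torch.randn(())

# Reparameterized sample: x ~ N(mu, sigma^2), built only from differentiable ops
x = mu + sigma * eps

x.backward()
print(mu.grad)     # dx/dmu = 1
print(sigma.grad)  # dx/dsigma = eps
```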

In order to backpropagate through a sample, the distribution must implement such a reparameterization, exposed as distribution.rsample(); see "Probability distributions - torch.distributions" in the PyTorch 1.10 documentation.
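For a distribution that does implement it, such as `torch.distributions.Normal`, the difference between the two methods is easy to see. A short sketch with arbitrary parameter values: `sample()` returns a tensor with no path back to the parameters, while `rsample()` keeps it in the autograd graph.

```python
import torch
from torch.distributions import Normal

loc = torch.tensor(0.5, requires_grad=True)
scale = torch.tensor(1.5, requires_grad=True)
dist = Normal(loc, scale)

s = dist.sample()   # drawn without tracking gradients: no path back to loc/scale
r = dist.rsample()  # loc + eps * scale: stays in the autograd graph

print(dist.has_rsample)  # True
print(s.requires_grad)   # False
print(r.requires_grad)   # True

r.backward()
print(loc.grad)  # tensor(1.)
```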

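Unfortunately, the Binomial distribution does not implement rsample(): has_rsample is False and calling rsample() raises NotImplementedError, as this session with PyTorch 1.10.0 shows: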

```
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.29.0 -- An enhanced Interactive Python. Type '?' for help.
probs:  tensor([0.9526, 0.9526, 0.9526, 0.9526, 0.9526], grad_fn=<SigmoidBackward0>)
mask: tensor([1., 1., 1., 1., 1.])
weight.grad: None
fc.weight.grad: tensor([[ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395],
        [ 1.3255,  0.0107,  0.8800, -1.5568,  1.7515,  0.6499,  0.1913, -0.1367,
          0.4323,  0.7505,  1.6673,  1.0194,  1.4906,  0.1311, -0.4088,  0.4096,
         -0.8983,  0.0185, -0.7355,  0.9395]])
```

```
In [1]: print(binomial.has_rsample)
False

In [2]: rs = binomial.rsample()
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-2-6374ad2cc45e> in <module>
----> 1 rs = binomial.rsample()

~/.conda/envs/pytorch/lib/python3.9/site-packages/torch/distributions/distribution.py in rsample(self, sample_shape)
    152         are batched.
    153         """
--> 154         raise NotImplementedError
    155 
    156     def sample_n(self, n):

NotImplementedError: 

In [3]: torch.__version__
Out[3]: '1.10.0'
```
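This also accounts for the second snippet in the question. The value returned by binomial.sample() is still a constant as far as autograd is concerned; the only path back to `weight` is the factor `1.0 / probs`, which is an ordinary differentiable function of `weight`. The resulting gradient treats the drawn sample as fixed and says nothing about how changing `probs` changes the distribution of the mask. A sketch that checks this against the closed form, using d(1/sigmoid(w))/dw = -(1 - sigmoid(w)) / sigmoid(w); the seed is arbitrary, and `x` and `h` are illustrative names (`x` plays the role of `input` from the question):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(1, 20)
weight = nn.Parameter(torch.ones(5) * 3.)
fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)
sample = torch.distributions.Binomial(probs=probs).sample()
print(sample.requires_grad)  # False: the draw itself is detached

mask = sample * (1.0 / probs)  # gradient can only flow through 1 / probs
h = fc(x)
out = torch.einsum("ij,j->ij", h, mask)
out.sum().backward()

# Closed form: d/dw_j [h_j * s_j / sigmoid(w_j)] = -h_j * s_j * (1 - p_j) / p_j
p = probs.detach()
expected = -h.detach()[0] * sample * (1 - p) / p
print(torch.allclose(weight.grad, expected))  # True
```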

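If what you need is an unbiased estimate of the gradient of the *expected* loss with respect to `probs`, one common workaround (a different technique from the reparameterization discussed above) is the score-function, or REINFORCE, estimator. It relies on `log_prob()`, which Binomial does implement, via the identity grad E[f(mask)] = E[f(mask) * grad log p(mask)]. A minimal single-sample sketch, assuming the masked-sum loss from the question stands in for f; in practice a single-sample estimate is noisy and is usually averaged over many samples or combined with a baseline:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(1, 20)
weight = nn.Parameter(torch.ones(5) * 3.)
fc = nn.Linear(20, 5)

probs = torch.sigmoid(weight)
binomial = torch.distributions.Binomial(probs=probs)
mask = binomial.sample()  # sample() does not track gradients

# Stand-in for f(mask): differentiable w.r.t. fc, constant w.r.t. weight
loss = torch.einsum("ij,j->ij", fc(x), mask).sum()

# Surrogate whose gradient w.r.t. weight is f(mask) * d log p(mask) / d weight.
# loss.detach() makes the second term send gradient only to weight (via log_prob).
surrogate = loss + loss.detach() * binomial.log_prob(mask).sum()
surrogate.backward()

print(weight.grad)     # single-sample score-function estimate (noisy)
print(fc.weight.grad)  # ordinary gradient of loss w.r.t. fc.weight
```

With the default total_count=1 and probs = sigmoid(weight), d log p(mask) / d weight works out to mask - probs elementwise, so each entry of weight.grad here equals loss * (mask - probs), which is a handy sanity check.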

Hope this helps!
