# Backward through Bernoulli sample

I am trying to backpropagate through a Bernoulli sample. A simplified version of my code is:

```
import torch

parameter = torch.nn.Parameter(torch.ones(8), requires_grad=True)  # parameters to learn
prob = torch.sigmoid(parameter)  # sigmoid to make a probability in [0, 1]
bern_sample = torch.distributions.bernoulli.Bernoulli(prob).sample()  # draw a sample

loss = torch.sum(bern_sample)
loss.backward()
```

Intuitively the variable `parameter` should be updated by gradient descent in order to make all probabilities close to 0, i.e. `bern_sample` should contain all zeros.

However, when I check the gradients with respect to `parameter`, i.e. `parameter.grad`, they are `None`.

What am I doing wrong?

Hello Ruggero!

Sampling from a discrete distribution isn’t differentiable, so pytorch’s
autograd can’t backpropagate the gradients of subsequent functions back
through the Bernoulli-sampling step.

The value returned from sampling a Bernoulli distribution (or any
discrete distribution) is a discrete value – `0.0` or `1.0`. So you can’t
differentiate it. That is, you can’t make your sample value a little
smaller, say `0.99` instead of `1.0`, by making your parameter `prob`
a little smaller.
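
To see this concretely, here is a minimal sketch that reuses the names from the question and inspects the sample. It assumes a reasonably recent pytorch, where `sample()` runs without gradient tracking; the exact error message from `backward()` may differ between versions, so it is only printed, not relied upon.

```python
import torch

parameter = torch.nn.Parameter(torch.ones(8))
prob = torch.sigmoid(parameter)
bern_sample = torch.distributions.Bernoulli(prob).sample()

# prob is still attached to the autograd graph ...
print(prob.requires_grad)         # True
# ... but the sample is not: it carries no gradient history at all.
print(bern_sample.requires_grad)  # False
print(bern_sample.grad_fn)        # None

# Nothing connects the loss to `parameter`, so backward() cannot reach it.
try:
    bern_sample.sum().backward()
except RuntimeError as err:
    print("backward failed:", err)

print(parameter.grad)             # None
```

Because the graph is cut at the sampling step, no gradient ever arrives at `parameter`, which is why `parameter.grad` stays `None`.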

Best.

K. Frank

Hello K Frank,
I’m just wondering whether sampling from other distributions is a differentiable operation?

Hi Anurag!

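Pytorch does support differentiating a sample from some (continuous)
distributions with respect to the parameters describing that
distribution: distributions whose `has_rsample` attribute is `True`
(for example `Normal`) provide an `rsample()` method for this, while
plain `sample()` is not differentiable.

Doing so makes mathematical sense whenever you can structure the
calculation so that the random draw itself does not depend on the
parameters.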

Consider a Gaussian distribution with mean `mu` and standard
deviation `sigma`. You can sample from such a distribution by
first sampling a value `z` from a standard normal distribution
(mean = 0, standard deviation = 1), and then transforming that value:

```
x = sigma * z + mu
```

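If `mu` and `sigma` are packaged as `Tensor`s that have
`requires_grad=True`, then calling `x.backward()` (for a scalar `x`)
will compute gradients for `sigma` and `mu`: the gradient with respect
to `mu` is `1`, and the gradient with respect to `sigma` is `z`.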

(Whether you will get the result you expect depends on what you
were expecting, but this scheme is conceptually self-consistent.)
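
The scheme above can be sketched as follows. The concrete values of `mu` and `sigma` are arbitrary choices for illustration. The last lines compare against `torch.distributions.Normal`, whose `rsample()` is, as far as I know, the built-in way to draw a reparameterized sample.

```python
import torch

torch.manual_seed(0)

# Illustrative values; any scalar tensors with requires_grad=True would do.
mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.5, requires_grad=True)

# The randomness lives in z, which does not depend on mu or sigma.
z = torch.randn(())
x = sigma * z + mu

x.backward()
print(mu.grad)     # dx/dmu = 1
print(sigma.grad)  # dx/dsigma = z
print(z)

# Built-in counterpart: rsample() keeps the graph, sample() does not.
dist = torch.distributions.Normal(mu, sigma)
print(dist.rsample().requires_grad)  # True
print(dist.sample().requires_grad)   # False
```

Note that the gradients describe how this particular sample moves when `mu` or `sigma` changes with `z` held fixed.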

Best.

K. Frank

I solved the problem by using the reparametrization trick with Gumbel Softmax. I took inspiration from this: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
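
For anyone landing here later, this is a minimal sketch of how that can look with pytorch's built-in `torch.nn.functional.gumbel_softmax`. It is not necessarily what the linked gist does. The two-class logits construction and the choices `tau=1.0` and `hard=True` are my own assumptions for illustration: with logits `[0, parameter]`, the softmax probability of class 1 equals `sigmoid(parameter)`, so this matches the Bernoulli setup from the question.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

parameter = torch.nn.Parameter(torch.ones(8))

# One pair of logits per element: index 0 is "sample 0", index 1 is "sample 1".
logits = torch.stack([torch.zeros_like(parameter), parameter], dim=-1)  # shape (8, 2)

# hard=True gives one-hot samples in the forward pass, while gradients
# flow through the soft relaxation (straight-through estimator).
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
bern_sample = one_hot[..., 1]  # (approximately) 1.0 where class 1 was drawn, else 0.0

loss = bern_sample.sum()
loss.backward()

print(bern_sample)
print(parameter.grad)  # no longer None
```

The gradients are those of the relaxed sample, so they are a biased estimate. Lower `tau` makes the relaxation closer to a true discrete sample but the gradients noisier.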